From 3340b8f182e82ce8aecada86d4a875dc453804ea Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 22 Aug 2022 16:35:53 -0700 Subject: [PATCH 01/17] add base address field in tensor index --- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 6 ++++-- torch/csrc/jit/codegen/cuda/kernel_ir.h | 12 +++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index d020144cc7c063..2ea900a4ccd055 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -66,10 +66,12 @@ Predicate::Predicate(IrBuilderPasskey passkey, const Predicate* other) TensorIndex::TensorIndex( IrBuilderPasskey passkey, const TensorView* view, - std::vector indices) + std::vector indices, + Val* base_address) : Val(passkey, ValType::TensorIndex, view->getDataType().value()), view_(view), - indices_(indices) { + indices_(indices), + base_address_(base_address) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index c1db32e0261933..e7f9460e28082f 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -152,7 +152,8 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { TensorIndex( IrBuilderPasskey, const TensorView* view, - std::vector indices); + std::vector indices, + Val* base_address = nullptr); std::vector::size_type nDims() const { return indices_.size(); @@ -169,9 +170,18 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { return const_cast(view_); // NOLINT } + bool hasBaseAddress() const { + return base_address_ != nullptr; + } + + Val* baseAddress() const { + return base_address_; + } + private: const TensorView* view_ = nullptr; std::vector indices_; + Val* base_address_ = nullptr; }; //! 
Allocate is a lower level Node that describes a buffer of memory that From 00a53b4e04ceffb0155be0e23cc9ff6ef7106fa0 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 22 Aug 2022 16:41:11 -0700 Subject: [PATCH 02/17] add pointer data type --- torch/csrc/jit/codegen/cuda/executor.cpp | 1 + torch/csrc/jit/codegen/cuda/type.cpp | 9 +++++++++ torch/csrc/jit/codegen/cuda/type.h | 1 + 3 files changed, 11 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index d2299a0ce54974..6da07ba5fd3488 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -54,6 +54,7 @@ typedef int int32_t; typedef unsigned int uint32_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; +typedef char* Pointer; )"; } diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 243842dcb8e360..721809161c9245 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -28,6 +28,7 @@ bool isFloatingPointType(DataType dtype) { return true; case DataType::Bool: case DataType::Index: + case DataType::Pointer: case DataType::Int: case DataType::Int32: case DataType::ComplexFloat: @@ -73,6 +74,7 @@ bool isIntegralType(DataType dtype) { case DataType::ComplexDouble: return false; case DataType::Index: + case DataType::Pointer: case DataType::Int: case DataType::Int32: return true; @@ -95,6 +97,7 @@ bool isComplexType(DataType dtype) { case DataType::BFloat16: case DataType::Int: case DataType::Index: + case DataType::Pointer: case DataType::Int32: return false; case DataType::Null: @@ -223,6 +226,8 @@ static const char* data_type2string(DataType t) { return "int64_t"; case DataType::Index: return "nvfuser_index_t"; + case DataType::Pointer: + return "DataPointer"; case DataType::Int32: return "int"; case DataType::ComplexFloat: @@ -939,6 +944,7 @@ at::ScalarType data_type_to_aten(const DataType& data_type) { case DataType::Int: return at::ScalarType::Long; case DataType::Index: + case DataType::Pointer: TORCH_INTERNAL_ASSERT( false, "Index is determined at compile time,", @@ -1104,6 +1110,8 @@ std::string typePrefix(const DataType data_type) { case DataType::Int: case DataType::Int32: return "i"; + case DataType::Pointer: + return "p"; case DataType::ComplexFloat: case DataType::ComplexDouble: return "c"; @@ -1155,6 +1163,7 @@ size_t dataTypeSize(DataType type) { case DataType::BFloat16: return sizeof(at::BFloat16); case DataType::Index: + case DataType::Pointer: TORCH_INTERNAL_ASSERT( false, "The actual type of Index is only known at compile time."); case DataType::Int: diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 1b2cf4f10c0491..63d9a46ffd898e 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -66,6 +66,7 @@ enum class DataType { Half, Int, Index, + Pointer, Int32, Bool, BFloat16, From b34958614f1cf7702c3665860ff09d3255660ce5 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 22 Aug 2022 20:05:41 -0700 Subject: [PATCH 03/17] codegen for base address option --- torch/csrc/jit/codegen/cuda/codegen.cpp | 50 +++++-- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 123 +++++++++++++++++- 2 files changed, 159 insertions(+), 14 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 36836748f26661..53972d657023b1 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ 
b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -504,6 +504,28 @@ class CudaKernelGenerator : private OptOutConstDispatch { return index.str(); } + std::string genTensorAddressIndex( + const kir::TensorIndex* ti, + DataType dtype) { + bool first = true; + std::stringstream index; + for (auto* ind : ti->indices()) { + if (!ind->isZeroInt()) { + if (!first) { + index << " + "; + } + index << "(" << genInline(ind) << ")*" << dataTypeSize(dtype); + first = false; + } + } + + if (first) { + index << "0"; + } + + return index.str(); + } + void handle(const kir::TensorIndex* ti) final { bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global && kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID(); @@ -545,20 +567,33 @@ class CudaKernelGenerator : private OptOutConstDispatch { return ss.str(); } + std::string genMaybeHoistedPointer(Val* val) { + auto ti = dynamic_cast(val); + TORCH_INTERNAL_ASSERT(ti != nullptr, "only support tensor index input"); + std::stringstream ss; + + if (ti->hasBaseAddress()) { + ss << genTensorAddressIndex(ti, ti->view()->dtype()) << "," + << gen(ti->baseAddress()); + } else { + ss << "&" << gen(ti) << "\n"; + } + } + // Utility function to emit a cp.async intrinsic void genCpAsync(const LoadStoreOp* ldst, int vec_size) { auto dtype = ldst->in()->getDataType().value(); if (ldst->predicate() == nullptr) { // Out of line predicate variant - indent() << "Ampere::cpAsync(" - << genVectorPointer(ldst->out(), dtype, vec_size) << "," - << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n"; + indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">(" + << genMaybeHoistedPointer(ldst->out()) << "," + << genMaybeHoistedPointer(ldst->in()) << ");\n"; } else { // Inline predicate variant - indent() << "Ampere::cpAsync(" - << genVectorPointer(ldst->out(), dtype, vec_size) << "," - << genVectorPointer(ldst->in(), dtype, vec_size) << "," + indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">(" + << genMaybeHoistedPointer(ldst->out()) << "," + << genMaybeHoistedPointer(ldst->in()) << "," << genInline(ldst->predicate()) << ");\n"; } } @@ -579,8 +614,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { } code_ << " ("; code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size) - << "," - << "&" << gen(ldst->in()) << ");\n"; + << "," << genMaybeHoistedPointer(ldst->in()) << ");\n"; genBankConflictCheck(ldst->in(), 16); } diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index e064a43090fd7e..e0def0eaa9af34 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -105,6 +105,66 @@ DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) { : "r"(addr)); } +// Load Matrix (per warp instruction) is to take data from SMEM to Local Memory. +// Automatically handles vectorized loads/stores in the MMA operation. +// Loads 8x8 matrix into a warp. Thread 0-7 provide the ptr that is the start +// of each row. All other threads can simply point to something valid +// (including 0). +// The x2 modifier on the instruction will actually load 2x8 rows to make a +// 16x8, +// then thread 0-15 will specify the start of each row. +// Finally is an x4 modifier producing a 32x8 using addrs from 0-31 in each +// warp. 
+DEVICE_INLINE void ldMatrix( + Array<__half, 4, 4>& out, + nvfuser_index_t index, + Pointer base_ptr) { + uint2& val = reinterpret_cast(out); + unsigned addr = util::toSmem(base_ptr); + util::adjustPartialLdMatrixAddrInTuring(addr); + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" + : "=r"(val.x), "=r"(val.y) + : "r"(addr + index)); +} + +// Same as previous, 8x8 matrix is vectorized loaded, then scattered (to perform +// transpose) so threads will hold 2 values down a column (instead of the +// previous instruction that's across a row). +DEVICE_INLINE void ldMatrixT( + Array<__half, 4, 4>& out, + nvfuser_index_t index, + Pointer base_ptr) { + uint2& val = reinterpret_cast(out); + unsigned addr = util::toSmem(base_ptr); + util::adjustPartialLdMatrixAddrInTuring(addr); + asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" + : "=r"(val.x), "=r"(val.y) + : "r"(addr + index)); +} + +DEVICE_INLINE void ldMatrix( + Array<__half, 8, 8>& out, + nvfuser_index_t index, + Pointer base_ptr) { + uint4& val = reinterpret_cast(out); + unsigned addr = util::toSmem(base_ptr); + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "r"(addr + index)); +} + +DEVICE_INLINE void ldMatrixT( + Array<__half, 8, 8>& out, + nvfuser_index_t index, + Pointer base_ptr) { + uint4& val = reinterpret_cast(out); + unsigned addr = util::toSmem(ptr); + asm volatile( + "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "r"(addr + index)); +} + } // namespace Turing #endif // Arch 75 @@ -136,10 +196,8 @@ DEVICE_INLINE unsigned toSmem(void* ptr) { // Global to SMEM load that is asynchronous, // not guaranteed to be completed until cpAsyncBarrier() is called. template -DEVICE_INLINE void cpAsync( - Array* smem_ptr, - void const* gmem_ptr) { - unsigned smem_addr = util::toSmem(&(smem_ptr->array[0])); +DEVICE_INLINE void cpAsync(void* smem_ptr, void const* gmem_ptr) { + unsigned smem_addr = util::toSmem(smem_ptr); constexpr int byte_size = sizeof(dtype) * len; static_assert( @@ -156,10 +214,10 @@ DEVICE_INLINE void cpAsync( // not guaranteed to be completed until cpAsyncBarrier() is called. template DEVICE_INLINE void cpAsync( - Array* smem_ptr, + void* smem_ptr, void const* gmem_ptr, bool predicate) { - unsigned smem_addr = util::toSmem(&(smem_ptr->array[0])); + unsigned smem_addr = util::toSmem(smem_ptr); constexpr int byte_size = sizeof(dtype) * len; static_assert( @@ -177,6 +235,59 @@ DEVICE_INLINE void cpAsync( "r"((int)predicate)); } +// Global to SMEM load that is asynchronous, +// not guaranteed to be completed until cpAsyncBarrier() is called. +template +DEVICE_INLINE void cpAsync( + nvfuser_index_t smem_index, + Pointer smem_base_ptr, + nvfuser_index_t gmem_index, + Pointer& gmem_ptr) { + unsigned smem_addr = util::toSmem(smem_base_ptr); + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + gmem_ptr += gmem_index; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"( + smem_addr + smem_index), + "+l"(gmem_ptr), + "n"(byte_size)); + gmem_ptr -= gmem_index; +} + +// Global to SMEM load that is asynchronous, +// not guaranteed to be completed until cpAsyncBarrier() is called. 
+template +DEVICE_INLINE void cpAsync( + nvfuser_index_t smem_index, + Pointer smem_base_ptr, + nvfuser_index_t gmem_index, + Pointer& gmem_ptr, + bool predicate) { + unsigned smem_addr = util::toSmem(smem_base_ptr); + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + gmem_ptr += gmem_index; + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + "@p cp.async.ca.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem_addr + smem_index), + "+l"(gmem_ptr), + "n"(byte_size), + "r"((int)predicate)); + gmem_ptr -= gmem_index; +} + // TODO: Might have a different category of sync if we want to build out this: DEVICE_INLINE void cpAsyncBarrier() { asm volatile("cp.async.wait_all;"); From f995fd12d9cf1a7022fcf7bf335e5225c026ff1e Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 22 Aug 2022 23:02:49 -0700 Subject: [PATCH 04/17] pointer mod take 1 --- torch/csrc/jit/codegen/cuda/codegen.cpp | 22 ++++--- torch/csrc/jit/codegen/cuda/executor.cpp | 2 +- torch/csrc/jit/codegen/cuda/index_compute.cpp | 65 +++++++------------ .../csrc/jit/codegen/cuda/lower_mem_index.cpp | 2 +- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 34 +++++----- .../codegen/cuda/test/test_gpu_tensorcore.cpp | 13 +++- 6 files changed, 68 insertions(+), 70 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 53972d657023b1..88943debce6923 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -532,6 +532,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (is_volatile) { code_ << "*(volatile " << ti->getDataType().value() << "*)&"; } + + if (ti->hasBaseAddress()) { + // WAR path to generate a tensor index with pointer content. 
+ code_ << "reinterpret_cast<" << ti->view()->dtype() << "*>(" + << gen(ti->baseAddress()) << ")" + << "[" << genTensorIndex(ti) << "]"; + return; + } code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]"; } @@ -567,8 +575,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { return ss.str(); } - std::string genMaybeHoistedPointer(Val* val) { - auto ti = dynamic_cast(val); + std::string genMaybeHoistedPointer(const Val* val) { + auto ti = dynamic_cast(val); TORCH_INTERNAL_ASSERT(ti != nullptr, "only support tensor index input"); std::stringstream ss; @@ -578,6 +586,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { } else { ss << "&" << gen(ti) << "\n"; } + + return ss.str(); } // Utility function to emit a cp.async intrinsic @@ -2462,12 +2472,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const kir::AddressCompute* address_compute) final { - indent() << "// Address tensor for indexing " - << varName(address_compute->dataTv()) << "\n"; - indent() << gen(address_compute->addressTv()) << " = " - << genTensorIndex( - address_compute->dataTv()->as()) - << ";\n"; + indent() << gen(address_compute->addressTv()) << " = (DataPointer) &" + << gen(address_compute->dataTv()->as()) << ";\n"; } void handle(const kir::GridSync* sync) final { diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 6da07ba5fd3488..3716e0926c9c62 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -54,7 +54,7 @@ typedef int int32_t; typedef unsigned int uint32_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; -typedef char* Pointer; +typedef char* DataPointer; )"; } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 09fd40883de6b8..8588ac3e2239cf 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1625,16 +1625,6 @@ std::vector Index::getGlobalProducerStridedIndices( } } - if (shouldUseLiftedAddress(producer_tv, consumer_tv, loops)) { - auto maybe_address_record = - GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( - producer_tv, consumer_tv); - - auto address_index = generateAddressTensorIndex( - loops, maybe_address_record.value()->addressTensor()); - strided_inds.push_back(address_index); - } - return strided_inds; } @@ -1943,16 +1933,6 @@ std::vector Index::getNonGlobalProducerStridedIndices( } } - auto maybe_address_record = - GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( - producer_tv, consumer_tv); - - if (should_use_lifted_address) { - auto address_index = generateAddressTensorIndex( - loops, maybe_address_record.value()->addressTensor()); - strided_inds.push_back(address_index); - } - return strided_inds; } @@ -2126,16 +2106,6 @@ std::vector Index::getGlobalConsumerStridedIndices( TORCH_INTERNAL_ASSERT( strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); - if (shouldUseLiftedAddress(consumer_tv, consumer_tv, loops)) { - auto maybe_address_record = - GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( - consumer_tv); - - auto address_index = generateAddressTensorIndex( - loops, maybe_address_record.value()->addressTensor()); - strided_inds.push_back(address_index); - } - return strided_inds; } @@ -2305,17 +2275,6 @@ std::vector Index::getNonGlobalConsumerStridedIndices( } } - // Add pre computed index path: - - if 
(shouldUseLiftedAddress(consumer_tv, consumer_tv, loops)) { - auto maybe_address_record = - GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( - consumer_tv); - auto address_index = generateAddressTensorIndex( - loops, maybe_address_record.value()->addressTensor()); - strided_inds.push_back(address_index); - } - return strided_inds; } @@ -2361,6 +2320,18 @@ kir::TensorIndex* Index::getProducerIndex( const TensorView* consumer, const std::vector& loops) { auto strided_indices = getProducerStridedIndices(producer, consumer, loops); + + if (shouldUseLiftedAddress(producer, consumer, loops)) { + auto maybe_address_record = + GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( + producer, consumer); + + auto address_index = generateAddressTensorIndex( + loops, maybe_address_record.value()->addressTensor()); + return SimplifyingIrBuilder::create( + producer, strided_indices, address_index); + } + return SimplifyingIrBuilder::create( producer, strided_indices); } @@ -2390,6 +2361,18 @@ kir::TensorIndex* Index::getConsumerIndex( const TensorView* consumer, const std::vector& loops) { auto strided_indices = getConsumerStridedIndices(consumer, loops); + + if (shouldUseLiftedAddress(consumer, consumer, loops)) { + auto maybe_address_record = + GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( + consumer); + + auto address_index = generateAddressTensorIndex( + loops, maybe_address_record.value()->addressTensor()); + return SimplifyingIrBuilder::create( + consumer, strided_indices, address_index); + } + return SimplifyingIrBuilder::create( consumer, strided_indices); } diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp index 9bc2dcc491897a..38c359aa66fac6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp @@ -677,7 +677,7 @@ c10::optional AddressComputeInfo::getMaybeLiftedAddress( TensorView* AddressComputeInfo::makeAddressTv( std::vector address_domains, bool is_global_address) { - DataType dtype = is_global_address ? 
DataType::Index : DataType::Int32; + DataType dtype = DataType::Pointer; return IrBuilder::create( IrBuilder::create( address_domains, std::vector(address_domains.size(), true)), diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index e0def0eaa9af34..9ac522dd4600da 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -118,13 +118,13 @@ DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) { DEVICE_INLINE void ldMatrix( Array<__half, 4, 4>& out, nvfuser_index_t index, - Pointer base_ptr) { + DataPointer base_ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(base_ptr); util::adjustPartialLdMatrixAddrInTuring(addr); asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" : "=r"(val.x), "=r"(val.y) - : "r"(addr + index)); + : "r"(addr + (unsigned)index)); } // Same as previous, 8x8 matrix is vectorized loaded, then scattered (to perform @@ -133,36 +133,36 @@ DEVICE_INLINE void ldMatrix( DEVICE_INLINE void ldMatrixT( Array<__half, 4, 4>& out, nvfuser_index_t index, - Pointer base_ptr) { + DataPointer base_ptr) { uint2& val = reinterpret_cast(out); unsigned addr = util::toSmem(base_ptr); util::adjustPartialLdMatrixAddrInTuring(addr); asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" : "=r"(val.x), "=r"(val.y) - : "r"(addr + index)); + : "r"(addr + (unsigned)index)); } DEVICE_INLINE void ldMatrix( Array<__half, 8, 8>& out, nvfuser_index_t index, - Pointer base_ptr) { + DataPointer base_ptr) { uint4& val = reinterpret_cast(out); unsigned addr = util::toSmem(base_ptr); asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) - : "r"(addr + index)); + : "r"(addr + (unsigned)index)); } DEVICE_INLINE void ldMatrixT( Array<__half, 8, 8>& out, nvfuser_index_t index, - Pointer base_ptr) { + DataPointer base_ptr) { uint4& val = reinterpret_cast(out); - unsigned addr = util::toSmem(ptr); + unsigned addr = util::toSmem(base_ptr); asm volatile( "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) - : "r"(addr + index)); + : "r"(addr + (unsigned)index)); } } // namespace Turing @@ -240,9 +240,9 @@ DEVICE_INLINE void cpAsync( template DEVICE_INLINE void cpAsync( nvfuser_index_t smem_index, - Pointer smem_base_ptr, + DataPointer smem_base_ptr, nvfuser_index_t gmem_index, - Pointer& gmem_ptr) { + DataPointer& gmem_ptr) { unsigned smem_addr = util::toSmem(smem_base_ptr); constexpr int byte_size = sizeof(dtype) * len; @@ -253,8 +253,8 @@ DEVICE_INLINE void cpAsync( gmem_ptr += gmem_index; asm volatile( "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"( - smem_addr + smem_index), - "+l"(gmem_ptr), + smem_addr + (unsigned)smem_index), + "l"(gmem_ptr), "n"(byte_size)); gmem_ptr -= gmem_index; } @@ -264,9 +264,9 @@ DEVICE_INLINE void cpAsync( template DEVICE_INLINE void cpAsync( nvfuser_index_t smem_index, - Pointer smem_base_ptr, + DataPointer smem_base_ptr, nvfuser_index_t gmem_index, - Pointer& gmem_ptr, + DataPointer& gmem_ptr, bool predicate) { unsigned smem_addr = util::toSmem(smem_base_ptr); constexpr int byte_size = sizeof(dtype) * len; @@ -281,8 +281,8 @@ DEVICE_INLINE void cpAsync( " .reg .pred p;\n" " setp.ne.b32 p, %3, 0;\n" "@p cp.async.ca.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem_addr + smem_index), - "+l"(gmem_ptr), + "}\n" ::"r"(smem_addr + 
(unsigned)smem_index), + "l"(gmem_ptr), "n"(byte_size), "r"((int)predicate)); gmem_ptr -= gmem_index; diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 54659812fb0de1..984a129970fc06 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2771,7 +2771,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248; - for (auto layout : kAllSupportedLayout) { + for (auto layout : {kAllSupportedLayout[2]}) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -2803,9 +2803,18 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); + fusion.printKernel(); + // return; + CompileOptions co; + co.index_mode = KernelIndexMode::INT32; + FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + 8, + 0, + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), co)); + // return; auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); From 84d8f046e66e34d66e14b2f1a87533ec2df331de Mon Sep 17 00:00:00 2001 From: shmsong Date: Tue, 23 Aug 2022 12:45:42 -0700 Subject: [PATCH 05/17] minor update --- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 5 +++-- torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index da926a63cac1f9..55412e1816d972 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -1404,8 +1404,9 @@ void IterDomain::parallelize(ParallelType t) { // they are swizzled. 
TORCH_CHECK( t == ParallelType::Vectorize || t == ParallelType::TIDx || - t == ParallelType::Serial, - "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids"); + t == ParallelType::Serial || t == ParallelType::Mma, + "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids", + t); } parallel_type_ = t; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index 31b30abc24b9d1..7d117e520d0d85 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -530,6 +530,8 @@ void scheduleMatmul( .propagateParallelType() .propagateToBoundary()); + c->axis(-1)->parallelize(ParallelType::Vectorize); + if (params.index_lift_options.lift_gmem_read_address) { a->liftReadAddress(); b->liftReadAddress(); From 250e46bb050f49253beca4dfed841ddefc4e6577 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 26 Aug 2022 09:27:58 -0700 Subject: [PATCH 06/17] (wip) increment mode --- torch/csrc/jit/codegen/cuda/kernel_ir.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index e7f9460e28082f..da3eacb481ba1c 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -297,7 +297,7 @@ class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { class TORCH_CUDA_CU_API AddressCompute final : public Expr { public: - enum class AddressComputeOpType { BASE_ADDRESS }; + enum class AddressComputeOpType { BASE_ADDRESS, INCREMENT }; explicit AddressCompute( IrBuilderPasskey passkey, @@ -329,6 +329,9 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { // Tensor that stores pre-computed address for the // data tensor. Val* address_tensor_ = nullptr; + + // Tensor that holds the value to increment (INCREMENT MODE only). 
+ Val* inc_value_ = nullptr; }; // Synchronize all blocks in device, implies cooperative group launch is From ce3f1e1f6969b6de64e52fd8ec89c7eab756c1fa Mon Sep 17 00:00:00 2001 From: shmsong Date: Thu, 1 Sep 2022 22:32:15 -0700 Subject: [PATCH 07/17] lift read db index --- torch/csrc/jit/codegen/cuda/codegen.cpp | 28 +++++++- torch/csrc/jit/codegen/cuda/index_compute.cpp | 68 +++++++++++-------- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 6 ++ torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 27 +++++++- torch/csrc/jit/codegen/cuda/kernel_ir.h | 53 ++++++++++++++- .../jit/codegen/cuda/lower_double_buffer.cpp | 55 +++++++++++++++ .../jit/codegen/cuda/lower_double_buffer.h | 17 +++++ torch/csrc/jit/codegen/cuda/lower_index.cpp | 21 ++++++ torch/csrc/jit/codegen/cuda/runtime/memory.cu | 45 ++++++++++++ .../codegen/cuda/test/test_gpu_tensorcore.cpp | 5 +- 10 files changed, 288 insertions(+), 37 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 88943debce6923..8b43afe0c08cac 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -519,6 +519,13 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } + if (ti->uniformAddress() != nullptr) { + if (!first) { + index << " + "; + } + index << genInline(ti->uniformAddress()); + } + if (first) { index << "0"; } @@ -2395,7 +2402,11 @@ class CudaKernelGenerator : private OptOutConstDispatch { alloc_map_.emplace(alloc->buffer(), alloc); if (!alloc->buffer()->isA()) { - indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n"; + indent() << buffer_dtype << " " << gen(alloc->buffer()); + if (alloc->zeroInit()) { + code_ << " = 0"; + } + code_ << ";\n"; return; } @@ -2472,8 +2483,19 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const kir::AddressCompute* address_compute) final { - indent() << gen(address_compute->addressTv()) << " = (DataPointer) &" - << gen(address_compute->dataTv()->as()) << ";\n"; + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH) { + indent() << "doubleBufferSwitch<" << address_compute->stageNumber() << "," + << address_compute->loopOffset() << ">(" + << gen(address_compute->doubleBufferSwitchIndex()) << "," + << gen(address_compute->loopIndex()) << "," + << gen(address_compute->doubleBufferByteSize()) << ");\n"; + } else { + indent() << "//Base Address:::\n"; + indent() << gen(address_compute->addressTv()) << " = (DataPointer) &" + << gen(address_compute->dataTv()->as()) + << ";\n"; + } } void handle(const kir::GridSync* sync) final { diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 229745f5675cb7..f5f686871b6b6a 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1918,37 +1918,43 @@ std::vector Index::getNonGlobalProducerStridedIndices( // No need to compute double buffer index in the address compute loop // as they have been handled with addtional offsets. if (!db_loop->isBaseIndexLoop()) { - auto loop_index = - db_loop->isTrivial() ? 
db_loop->start() : db_loop->index(); - - // Need to add the producer outer main loop index by 1 - // in the case of lower prolog, see the example in - // [Skew Double Buffer Loop Transformation] - auto consumer_db_loop = - gpu_lower->doubleBufferInfo().getDoubleBufferLoop( - consumer_tv, loops); - - if (consumer_db_loop != nullptr) { - if (gpu_lower->doubleBufferInfo().isLowerPrologWithin( - consumer_db_loop->iter_domain(), db_loop->iter_domain())) { - if (consumer_db_loop->doubleBufferLoopStage() == - DoubleBufferLoopStage::LowerProlog) { - loop_index = SimplifyingIrBuilder::addExpr( - loop_index, gpu_lower->kernel()->oneVal()); + auto maybe_read_offset = + GpuLower::current()->doubleBufferInfo().getReadSwitchIndex( + producer_tv); + + if (!maybe_read_offset.has_value()) { + auto loop_index = + db_loop->isTrivial() ? db_loop->start() : db_loop->index(); + + // Need to add the producer outer main loop index by 1 + // in the case of lower prolog, see the example in + // [Skew Double Buffer Loop Transformation] + auto consumer_db_loop = + gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops); + + if (consumer_db_loop != nullptr) { + if (gpu_lower->doubleBufferInfo().isLowerPrologWithin( + consumer_db_loop->iter_domain(), db_loop->iter_domain())) { + if (consumer_db_loop->doubleBufferLoopStage() == + DoubleBufferLoopStage::LowerProlog) { + loop_index = SimplifyingIrBuilder::addExpr( + loop_index, gpu_lower->kernel()->oneVal()); + } } } - } - auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor( - db_loop->iter_domain()); - auto db_switch_index = SimplifyingIrBuilder::modExpr( - loop_index, SimplifyingIrBuilder::create(stage_depth)); + auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor( + db_loop->iter_domain()); + auto db_switch_index = SimplifyingIrBuilder::modExpr( + loop_index, SimplifyingIrBuilder::create(stage_depth)); - auto original_alloc_size = - gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv); - auto db_strided_index = - SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size); - strided_inds.push_back(db_strided_index); + auto original_alloc_size = + gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv); + auto db_strided_index = SimplifyingIrBuilder::mulExpr( + db_switch_index, original_alloc_size); + strided_inds.push_back(db_strided_index); + } } } } @@ -2350,10 +2356,16 @@ kir::TensorIndex* Index::getProducerIndex( GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( producer, consumer); + auto maybe_read_offset = + GpuLower::current()->doubleBufferInfo().getReadSwitchIndex(producer); + Val* uniform_address = nullptr; + if (maybe_read_offset.has_value()) { + uniform_address = maybe_read_offset.value(); + } auto address_index = generateAddressTensorIndex( loops, maybe_address_record.value()->addressTensor()); return SimplifyingIrBuilder::create( - producer, strided_indices, address_index); + producer, strided_indices, address_index, uniform_address); } return SimplifyingIrBuilder::create( diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index f59d7d7deaa0ef..53969924dd8398 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -341,6 +341,12 @@ class TORCH_CUDA_CU_API Val : public Statement { void resolveIndexDtype(); + // Provide a way to instantiate a 32b integer scalar + void to32b() { + TORCH_INTERNAL_ASSERT(vtype_ == ValType::Scalar && dtype_ == DataType::Int); + dtype_ = 
DataType::Int32; + } + protected: friend Fusion; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 2ea900a4ccd055..c103f81620b016 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -67,11 +67,13 @@ TensorIndex::TensorIndex( IrBuilderPasskey passkey, const TensorView* view, std::vector indices, - Val* base_address) + Val* base_address, + Val* uniform_address) : Val(passkey, ValType::TensorIndex, view->getDataType().value()), view_(view), indices_(indices), - base_address_(base_address) { + base_address_(base_address), + uniform_address_(uniform_address) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); @@ -136,6 +138,27 @@ AddressCompute::AddressCompute( "IR type only valid for Kernel container."); } +AddressCompute::AddressCompute( + IrBuilderPasskey passkey, + TensorView* data_tv, + Val* double_buffer_switch_index, + Val* buffer_size_in_byte, + int loop_offset, + int stage_number, + Val* loop_index) + : Expr(passkey, ExprType::AddressCompute), + op_type_(AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH), + data_tensor_(data_tv), + double_buffer_switch_index_(double_buffer_switch_index), + buffer_size_in_byte_(buffer_size_in_byte), + loop_offset_(loop_offset), + stage_number_(stage_number), + loop_index_(loop_index) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey, ExprType::InitMagicZero) { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 230273040a3a61..cb2777b279b92d 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -153,7 +153,8 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { IrBuilderPasskey, const TensorView* view, std::vector indices, - Val* base_address = nullptr); + Val* base_address = nullptr, + Val* uniform_address = nullptr); std::vector::size_type nDims() const { return indices_.size(); @@ -178,10 +179,15 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { return base_address_; } + auto uniformAddress() const { + return uniform_address_; + } + private: const TensorView* view_ = nullptr; std::vector indices_; Val* base_address_ = nullptr; + Val* uniform_address_ = nullptr; }; //! Allocate is a lower level Node that describes a buffer of memory that @@ -299,7 +305,12 @@ class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { //! that are not inlined. 
class TORCH_CUDA_CU_API AddressCompute final : public Expr { public: - enum class AddressComputeOpType { BASE_ADDRESS, INCREMENT }; + enum class AddressComputeOpType { + BASE_ADDRESS, + INCREMENT, + DOUBLE_BUFFER_SWITCH, + DOUBLE_BUFFER_UPDATE + }; explicit AddressCompute( IrBuilderPasskey passkey, @@ -307,6 +318,17 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { Val* address_tensor, Val* data_tensor); + // Interface for double buffer offset + // calculation: + explicit AddressCompute( + IrBuilderPasskey passkey, + TensorView* data_tv, + Val* double_buffer_switch_index, + Val* buffer_size_in_byte, + int loop_offset, + int stage_number, + Val* loop_index = nullptr); + auto dataTv() const { return data_tensor_; } @@ -319,6 +341,26 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { return op_type_; } + auto doubleBufferSwitchIndex() const { + return double_buffer_switch_index_; + } + + auto doubleBufferByteSize() const { + return buffer_size_in_byte_; + } + + auto loopOffset() const { + return loop_offset_; + } + + auto stageNumber() const { + return stage_number_; + } + + auto loopIndex() const { + return loop_index_; + } + private: // The type of computation this op computes, // currently only do compute address. @@ -334,6 +376,13 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { // Tensor that holds the value to increment (INCREMENT MODE only). Val* inc_value_ = nullptr; + + // Double buffer switch and update parameters: + Val* double_buffer_switch_index_ = nullptr; + Val* buffer_size_in_byte_ = nullptr; + int loop_offset_ = 0; + int stage_number_ = 0; + Val* loop_index_ = nullptr; }; // Synchronize all blocks in device, implies cooperative group launch is diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index dc0e4b9a505a81..fca0e05b23cada 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -230,6 +230,37 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { cloned_top_level_loop_->body().push_back( IrBuilder::create()); } + + // insert double buffer switching for the read offset: + if (loop_type_ == DoubleBufferLoopStage::Main) { + auto& db_info = GpuLower::current()->doubleBufferInfo(); + + for (auto load : double_buffer_load_exprs_) { + if (auto tv_out = ir_utils::getTvOutput(load)) { + auto maybe_read_index = db_info.getReadSwitchIndex(tv_out); + if (maybe_read_index.has_value()) { + // insert db switch: + auto switch_size = db_info.getOriginalAllocSize(tv_out); + auto switch_size_in_byte = SimplifyingIrBuilder::mulExpr( + switch_size, + SimplifyingIrBuilder::create( + dataTypeSize(tv_out->dtype()))); + + auto address_compute = + SimplifyingIrBuilder::create( + tv_out, + maybe_read_index.value(), + switch_size_in_byte, + 0, // assume this path only supports read + // so offset is 0 + db_info.getStageDepthFor( + double_buffer_loop_->iter_domain())); + + cloned_top_level_loop_->body().push_back(address_compute); + } + } + } + } } void handle(kir::ForLoop* fl) final { @@ -494,6 +525,30 @@ class DoubleBufferInserter : private kir::ExprMutator { void insert( kir::ForLoop* double_buffer_loop, const std::vector& loads) { + // Insert double buffer read switch index + for (auto load : loads) { + if (auto load_output = dynamic_cast(load->output(0))) { + if (load_output->getMemoryType() == MemoryType::Shared && + (load_output->isDoubleBuffered() || + load_output->isCircularBuffered()) && + 
load_output->shouldLiftReadAddress()) { + auto switch_val = IrBuilder::create(); + switch_val->to32b(); + + // TODO: maybe want to do this in id graph instead + GpuLower::current()->doubleBufferInfo().setReadSwitchIndex( + load_output, switch_val); + + auto index_alloc = IrBuilder::create( + switch_val, + MemoryType::Local, + GpuLower::current()->kernel()->oneVal(), + true); + registerInsertBefore(double_buffer_loop, index_alloc); + } + } + } + auto prologue_loop = DoubleBufferLoopCloner::clone( double_buffer_loop, loads, DoubleBufferLoopStage::Prolog); registerInsertBefore(double_buffer_loop, prologue_loop); diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h index 29d27778bccc13..36538993506ef5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -226,6 +226,20 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { //! skew double buffer transform within the given outer loop. bool isLowerPrologWithin(IterDomain* db_loop, IterDomain* outer_loop); + void setReadSwitchIndex(TensorView* db_tv, Val* switch_index) { + TORCH_INTERNAL_ASSERT( + read_switch_index_map_.insert(std::make_pair(db_tv, switch_index)) + .second); + } + + c10::optional getReadSwitchIndex(TensorView* db_tv) { + auto val_it = read_switch_index_map_.find(db_tv); + if (val_it == read_switch_index_map_.end()) { + return c10::nullopt; + } + return val_it->second; + } + private: TvInfo& getTvInfo(const TensorView* tv); void buildSkewInfo(const TensorView* tv, const TvInfo& tv_info); @@ -258,6 +272,9 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { //! mapping from inner loop to outer loop. std::unordered_map concrete_skewed_double_buffer_loop_map_; + + //! Keep track of read switch index + std::unordered_map read_switch_index_map_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 885e9bafd703cb..16763153293d6c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -870,6 +870,27 @@ void IndexLowering::handle(const kir::CpAsyncCommit* commit) { } void IndexLowering::handle(const kir::AddressCompute* address_compute) { + // Logic for double buffer switching: + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH) { + // no indexing is needed, just forward through. + auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + address_compute->dataTv()->as(), for_loops_, false); + TORCH_INTERNAL_ASSERT(db_loop != nullptr); + auto db_index = db_loop->isTrivial() ? 
db_loop->start() : db_loop->index(); + + pushBack(IrBuilder::create( + address_compute->dataTv()->as(), + address_compute->doubleBufferSwitchIndex(), + address_compute->doubleBufferByteSize(), + address_compute->loopOffset(), + address_compute->stageNumber(), + db_index)); + return; + } + + // Logic for double buffer updating: + // Logic for base address computation auto address_tv = address_compute->addressTv(); diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 9ac522dd4600da..f77f6f5549fa01 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -306,4 +306,49 @@ DEVICE_INLINE void cpAsyncPartialBarrier() { #endif // Arch 80 +// Double buffer calculation utilities: + +// In place update of double buffer index that has been accumulated to the data +// buffer. +template +DEVICE_INLINE void doubleBufferUpdate( + DataPointer& data_buffer, + const nvfuser_index_t& loop_index, + nvfuser_index_t buffer_size) { + // static_assert( + // loop_offset < number_of_stage && loop_offset > -number_of_stage); + + // convert offset to [0, number_of_stage) + constexpr nvfuser_index_t offset = + loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset; + + // Rewind back at number_of_stage-1, otherwise increment by 1. + nvfuser_index_t increment = + (loop_index % number_of_stage) == (number_of_stage - 1 - offset) + ? buffer_size * (-number_of_stage + 1) + : buffer_size; + data_buffer += increment; +} + +// Update double buffer offset value for smem double buffered tensors. +template +DEVICE_INLINE void doubleBufferSwitch( + nvfuser_index_t& buffer_offset, + const nvfuser_index_t& loop_index, + nvfuser_index_t buffer_size) { + // static_assert( + // loop_offset < number_of_stage && loop_offset > -number_of_stage); + + // convert offset to [0, number_of_stage) + constexpr nvfuser_index_t offset = + loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset; + + // Rewind back at number_of_stage-1, otherwise increment by 1. + nvfuser_index_t increment = + (loop_index % number_of_stage) == (number_of_stage - 1 - offset) + ? buffer_size * (-number_of_stage + 1) + : buffer_size; + buffer_offset += increment; +} + #undef DEVICE_INLINE diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 3c2405e575e5b4..05df3189b1487c 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -2771,7 +2771,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { // Keep multiples of 8 to keep vectorizable. 
int M = 504, N = 136, K = 248; - for (auto layout : {kAllSupportedLayout[2]}) { + for (auto layout : kAllSupportedLayout) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeContigTensor(2, DataType::Half); @@ -2797,7 +2797,8 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) { params.tile_sizes = gemm_tile; params.async_gmem_load_operands = true; params.double_buffer_options.double_buffer_smem_write = true; - params.double_buffer_options.smem_double_buffer_stage = 4; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = 3; scheduleMatmul(tv2, tv0, tv1, params); at::manual_seed(0); From 16e9c4a86e792a876a0c589c6ddf776ed3d7eaf4 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 2 Sep 2022 11:07:37 -0700 Subject: [PATCH 08/17] inplace write double buffer update --- torch/csrc/jit/codegen/cuda/codegen.cpp | 8 +++++ torch/csrc/jit/codegen/cuda/index_compute.cpp | 5 ++- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 21 ++++++++++++ torch/csrc/jit/codegen/cuda/kernel_ir.h | 11 +++++++ .../jit/codegen/cuda/lower_double_buffer.cpp | 20 +++++++++--- torch/csrc/jit/codegen/cuda/lower_index.cpp | 23 +++++++++++++ .../csrc/jit/codegen/cuda/lower_mem_index.cpp | 32 +++++++++++++++++++ 7 files changed, 114 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 8b43afe0c08cac..67246b877e4109 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -2490,6 +2490,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { << gen(address_compute->doubleBufferSwitchIndex()) << "," << gen(address_compute->loopIndex()) << "," << gen(address_compute->doubleBufferByteSize()) << ");\n"; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { + indent() << "doubleBufferUpdate<" << address_compute->stageNumber() << "," + << address_compute->loopOffset() << ">(" + << gen(address_compute->addressTv()) << "," + << gen(address_compute->loopIndex()) << "," + << gen(address_compute->doubleBufferByteSize()) << ");\n"; } else { indent() << "//Base Address:::\n"; indent() << gen(address_compute->addressTv()) << " = (DataPointer) &" diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index f5f686871b6b6a..6cdcc5be36a89e 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -2261,7 +2261,10 @@ std::vector Index::getNonGlobalConsumerStridedIndices( TORCH_INTERNAL_ASSERT( strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); - if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) { + if ((consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) + // Lifted address case the double buffer offset is + // computed inplace into the write address buffer. 
+ && !consumer_tv->shouldLiftWriteAddress()) { auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops); TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index c103f81620b016..0ee3cd339173de 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -159,6 +159,27 @@ AddressCompute::AddressCompute( "IR type only valid for Kernel container."); } +AddressCompute::AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* buffer_size_in_byte, + int stage_number, + int loop_offset, + TensorView* data_tensor, + Val* loop_index) + : Expr(passkey, ExprType::AddressCompute), + op_type_(AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE), + data_tensor_(data_tensor), + address_tensor_(address_tensor), + buffer_size_in_byte_(buffer_size_in_byte), + loop_offset_(loop_offset), + stage_number_(stage_number), + loop_index_(loop_index) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey, ExprType::InitMagicZero) { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index cb2777b279b92d..feed43c1298799 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -329,6 +329,17 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { int stage_number, Val* loop_index = nullptr); + // Interface for double buffer offset + // inplace update: + explicit AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* buffer_size_in_byte, + int stage_number, + int loop_offset, + TensorView* data_tensor, + Val* loop_index = nullptr); + auto dataTv() const { return data_tensor_; } diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index fca0e05b23cada..4cdee398e5d897 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -226,11 +226,6 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { handle(double_buffer_loop_); - if (stage_depth > 2) { - cloned_top_level_loop_->body().push_back( - IrBuilder::create()); - } - // insert double buffer switching for the read offset: if (loop_type_ == DoubleBufferLoopStage::Main) { auto& db_info = GpuLower::current()->doubleBufferInfo(); @@ -261,6 +256,11 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { } } } + + if (stage_depth > 2) { + cloned_top_level_loop_->body().push_back( + IrBuilder::create()); + } } void handle(kir::ForLoop* fl) final { @@ -333,6 +333,16 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { cloned_scopes_.back()->push_back(expr); } } + + // Need the double buffer update expr in prologs too. + if (loop_type_ == DoubleBufferLoopStage::Prolog) { + if (auto address_compute = dynamic_cast(expr)) { + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { + cloned_scopes_.back()->push_back(expr); + } + } + } } //! 
Returns true if the expression is an initialization expr that diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 16763153293d6c..6d31b586f19fe4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -887,6 +887,29 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { address_compute->stageNumber(), db_index)); return; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { + auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + address_compute->dataTv()->as(), for_loops_, false); + TORCH_INTERNAL_ASSERT(db_loop != nullptr); + auto db_index = db_loop->isTrivial() ? db_loop->start() : db_loop->index(); + auto loop_offset = + db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Main + ? address_compute->stageNumber() - 1 + : 0; + + auto indexed_address_tv = Index::generateAddressTensorIndex( + for_loops_, address_compute->addressTv()->as()); + + pushBack(IrBuilder::create( + indexed_address_tv, + address_compute->doubleBufferByteSize(), + address_compute->stageNumber(), + loop_offset, + address_compute->dataTv()->as(), + db_index)); + return; } // Logic for double buffer updating: diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp index c56a62df3d2e41..11bd02c841ec36 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp @@ -1059,6 +1059,38 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { // put the new loopnest before the hoisted loop registerInsertBefore(loop, outermost_innermost.first); + + auto data_tensor = insertion_info.address_compute_record->dataTensor(); + if ((data_tensor->isDoubleBuffered() || + data_tensor->isCircularBuffered()) && + insertion_info.address_compute_record->isWrite()) { + // Insert double buffer index update if it is a double buffered write: + // The insertion info loop nest starts with the serial loop, + // in the double buffer update we need to insert into the original + // serial loop itself, so remove the outermost level. 
+ auto db_loop_nest = std::vector( + std::next(insertion_info.loop_nest.begin()), + insertion_info.loop_nest.end()); + + auto db_outer_inner = scope_utils::makeLoopNest(db_loop_nest); + + auto& db_info = GpuLower::current()->doubleBufferInfo(); + + auto db_size_in_byte = SimplifyingIrBuilder::mulExpr( + db_info.getOriginalAllocSize(data_tensor), + SimplifyingIrBuilder::create( + dataTypeSize(data_tensor->dtype()))); + + auto update_expr = SimplifyingIrBuilder::create( + insertion_info.address_compute_record->addressTensor(), + db_size_in_byte, + data_tensor->circularBufferDepth(), + 0, + data_tensor); + + db_outer_inner.second->body().push_back(update_expr); + loop->body().push_back(db_outer_inner.first); + } } } From 5bdeea28064a95ee3ec3b03f0434f8055509f70a Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 2 Sep 2022 12:46:18 -0700 Subject: [PATCH 09/17] lift cvta out of main loop --- torch/csrc/jit/codegen/cuda/codegen.cpp | 25 +++++- torch/csrc/jit/codegen/cuda/executor.cpp | 4 +- torch/csrc/jit/codegen/cuda/kernel_ir.h | 10 +++ torch/csrc/jit/codegen/cuda/lower_index.cpp | 8 ++ .../csrc/jit/codegen/cuda/lower_mem_index.cpp | 26 ++++-- torch/csrc/jit/codegen/cuda/lower_mem_index.h | 4 +- torch/csrc/jit/codegen/cuda/runtime/memory.cu | 86 +++++++++++++++++++ torch/csrc/jit/codegen/cuda/type.cpp | 2 + torch/csrc/jit/codegen/cuda/type.h | 1 + 9 files changed, 156 insertions(+), 10 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 67246b877e4109..a3750898b83d25 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -708,6 +708,16 @@ class CudaKernelGenerator : private OptOutConstDispatch { !(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) { // Vectorized initialization indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n"; + } else if ( + uop->out()->isA() && + uop->out()->as()->useSmemAddress()) { + auto ti = uop->out()->as(); + // Special case branch for smem reset + // FIXME: only support filling zero at the moment: + indent() << "smemReset<" << ti->view()->dtype() << "," + << vector_word_size << ">(" << gen(ti->baseAddress()) + << "+" << genTensorAddressIndex(ti, ti->view()->dtype()) + << ");\n"; } else { // Note: currently arraySet option is not vectorized, so it will // rely on auto vectorization pass of cuda compiler. 
@@ -2500,9 +2510,18 @@ class CudaKernelGenerator : private OptOutConstDispatch { << gen(address_compute->doubleBufferByteSize()) << ");\n"; } else { indent() << "//Base Address:::\n"; - indent() << gen(address_compute->addressTv()) << " = (DataPointer) &" - << gen(address_compute->dataTv()->as()) - << ";\n"; + indent() << gen(address_compute->addressTv()); + + if (address_compute->addressTv()->dtype() == DataType::Pointer) { + code_ << " = (DataPointer) &" + << gen(address_compute->dataTv()->as()) + << ";\n"; + } else if ( + address_compute->addressTv()->dtype() == DataType::SmemAddress) { + code_ << " = Turing::util::toSmem(&" + << gen(address_compute->dataTv()->as()) + << ");\n"; + } } } diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 3716e0926c9c62..eb47038a440828 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -55,6 +55,7 @@ typedef unsigned int uint32_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; typedef char* DataPointer; +typedef unsigned SmemAddress; )"; } @@ -1100,6 +1101,8 @@ void FusionExecutor::compileRtc( const std::string& code, const std::string& name, bool structured) { + options_ = CompileOptions(); + options_.index_mode = KernelIndexMode::INT32; FUSER_PERF_SCOPE("ExecutorRunFusion::compileRtc"); std::string scode; if (!structured) { @@ -1108,7 +1111,6 @@ void FusionExecutor::compileRtc( scode = code; } fusion_id_ = 1; - options_ = CompileOptions(); std::tie(compiled_kernel_, last_compiler_log_) = executor_utils::nvrtcCompile(scode, name, fusion_id_); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index feed43c1298799..4abc14902793e4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -183,11 +183,21 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { return uniform_address_; } + bool useSmemAddress() const { + return use_smem_address_; + } + + TensorIndex* toSmemAddress() { + use_smem_address_ = true; + return this; + } + private: const TensorView* view_ = nullptr; std::vector indices_; Val* base_address_ = nullptr; Val* uniform_address_ = nullptr; + bool use_smem_address_ = false; }; //! 
Allocate is a lower level Node that describes a buffer of memory that diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 6d31b586f19fe4..8b7c538b8c4b03 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -121,6 +121,14 @@ void IndexLowering::handle(const UnaryOp* uop) { } const auto in = lowerSrcIndex(uop->in(), uop->out()); const auto out = lowerDstIndex(uop->out()); + + if (ir_utils::isCpAsyncInit(uop)) { + auto out_tv = ir_utils::getTvOutput(uop); + if (out_tv->shouldLiftWriteAddress()) { + out->as()->toSmemAddress(); + } + } + pushBack(IrBuilder::create( uop->getUnaryOpType(), out, in, uop->getRNGOffset())); GpuLower::current()->propagateExprInfo(uop, back()); diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp index 11bd02c841ec36..670d6507d24b1d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp @@ -900,9 +900,6 @@ void AddressComputeInfo::makeAddressRecord( isSeparable(reference_tv, serial_id, contig_merged_ids), "The serial id is required to be separable for the index lifting to work."); - // Create address record: - auto address_tv = makeAddressTv(alloc_ids_vec, !is_shared_mem_access); - // Assuming we are only having two scenarios, // either accessing a consumer in the consumer's loop, // or accessing the producer in producer's loop. @@ -910,6 +907,20 @@ void AddressComputeInfo::makeAddressRecord( ? AddressRecord::ReadWrite::WRITE : AddressRecord::ReadWrite::READ; + bool is_cp_async_write = + access_direction == AddressRecord::ReadWrite::WRITE && + ir_utils::isCpAsyncOp(data_tv->definition()); + + // Place holder for predicate lifting PR. + bool is_predicate_record = false; + + // Create address record: + auto address_tv = makeAddressTv( + alloc_ids_vec, + !is_shared_mem_access, + is_predicate_record, + is_cp_async_write); + TORCH_INTERNAL_ASSERT( serial_id != nullptr, "no support yet for global scope hoisting"); @@ -953,8 +964,13 @@ c10::optional AddressComputeInfo::getMaybeLiftedAddress( TensorView* AddressComputeInfo::makeAddressTv( std::vector address_domains, - bool is_global_address) { - DataType dtype = DataType::Pointer; + bool is_global_address, + bool is_predicate_index, + bool is_cpasync_write) { + DataType dtype = is_predicate_index ? DataType::Index : DataType::Pointer; + if (is_cpasync_write) { + dtype = DataType::SmemAddress; + } return IrBuilder::create( IrBuilder::create( address_domains, std::vector(address_domains.size(), true)), diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.h b/torch/csrc/jit/codegen/cuda/lower_mem_index.h index 4c5e5449f12af5..d17e28043a119a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.h @@ -166,7 +166,9 @@ class AddressComputeInfo { // Utility to help allocate space for saving pre-computed address. 
TensorView* makeAddressTv( std::vector address_domains, - bool is_global_address); + bool is_global_address, + bool is_predicate_index, + bool is_cpasync_write = false); void makeAddressRecord(TensorView* data_tv, TensorView* reference_tv); diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index f77f6f5549fa01..395adab9868fed 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -288,6 +288,34 @@ DEVICE_INLINE void cpAsync( gmem_ptr -= gmem_index; } +// Global to SMEM load that is asynchronous, +// not guaranteed to be completed until cpAsyncBarrier() is called. +template +DEVICE_INLINE void cpAsync( + nvfuser_index_t smem_index, + unsigned smem_addr, + nvfuser_index_t gmem_index, + DataPointer& gmem_ptr, + bool predicate) { + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + gmem_ptr += gmem_index; + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + "@p cp.async.ca.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem_addr + (unsigned)smem_index), + "l"(gmem_ptr), + "n"(byte_size), + "r"((int)predicate)); + gmem_ptr -= gmem_index; +} + // TODO: Might have a different category of sync if we want to build out this: DEVICE_INLINE void cpAsyncBarrier() { asm volatile("cp.async.wait_all;"); @@ -330,6 +358,26 @@ DEVICE_INLINE void doubleBufferUpdate( data_buffer += increment; } +template +DEVICE_INLINE void doubleBufferUpdate( + unsigned& data_buffer, + const nvfuser_index_t& loop_index, + nvfuser_index_t buffer_size) { + // static_assert( + // loop_offset < number_of_stage && loop_offset > -number_of_stage); + + // convert offset to [0, number_of_stage) + constexpr nvfuser_index_t offset = + loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset; + + // Rewind back at number_of_stage-1, otherwise increment by 1. + nvfuser_index_t increment = + (loop_index % number_of_stage) == (number_of_stage - 1 - offset) + ? buffer_size * (-number_of_stage + 1) + : buffer_size; + data_buffer += increment; +} + // Update double buffer offset value for smem double buffered tensors. template DEVICE_INLINE void doubleBufferSwitch( @@ -351,4 +399,42 @@ DEVICE_INLINE void doubleBufferSwitch( buffer_offset += increment; } +// Reset smem space to zero +// TODO: try cp.async.ignore-source ? 
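// Illustrative host-side sketch (not from the patch) of the rotation that the
// doubleBufferUpdate helper above performs on a byte offset, assuming
// loop_offset == 0, an offset starting at slot 0, and hypothetical sizes:
#include <cassert>

int main() {
  constexpr long long stages = 3;
  constexpr long long buffer_size = 1024; // bytes per double buffer stage
  long long offset = 0;                   // starts at slot 0
  for (long long i = 0; i < 12; ++i) {
    // Same update rule as the device helper: rewind when the last stage is
    // reached, otherwise advance by one stage.
    offset += (i % stages) == (stages - 1) ? buffer_size * (-stages + 1)
                                           : buffer_size;
    // After the update at iteration i, the offset points at slot (i+1)%stages.
    assert(offset == ((i + 1) % stages) * buffer_size);
  }
  return 0;
}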
+template +DEVICE_INLINE void smemReset(SmemAddress smem_addr) { + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + switch (byte_size) { + case 4: + asm volatile( + "{\n" + "st.shared.u32 [%0], {%1};\n" + "}\n" + : + : "r"(smem_addr), "r"(0)); + break; + case 8: + asm volatile( + "{\n" + "st.shared.v2.u32 [%0], {%1, %2};\n" + "}\n" + : + : "r"(smem_addr), "r"(0), "r"(0)); + break; + case 16: + asm volatile( + "{\n" + "st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + "}\n" + : + : "r"(smem_addr), "r"(0), "r"(0), "r"(0), "r"(0)); + break; + } +} + #undef DEVICE_INLINE diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 351244e66ff313..339b9825adaa29 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -228,6 +228,8 @@ static const char* data_type2string(DataType t) { return "nvfuser_index_t"; case DataType::Pointer: return "DataPointer"; + case DataType::SmemAddress: + return "SmemAddress"; case DataType::Int32: return "int"; case DataType::ComplexFloat: diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index d79c85f4b0113d..29d689559a7698 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -67,6 +67,7 @@ enum class DataType { Int, Index, Pointer, + SmemAddress, Int32, Bool, BFloat16, From 578dcfe96de989e02ce47f564c96f95c812d92f6 Mon Sep 17 00:00:00 2001 From: shmsong Date: Fri, 2 Sep 2022 20:31:02 -0700 Subject: [PATCH 10/17] increment gmem load --- torch/csrc/jit/codegen/cuda/codegen.cpp | 8 ++++ torch/csrc/jit/codegen/cuda/index_compute.cpp | 15 ++++++- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 15 +++++++ torch/csrc/jit/codegen/cuda/kernel_ir.h | 20 +++++++-- .../jit/codegen/cuda/lower_double_buffer.cpp | 9 ++++ torch/csrc/jit/codegen/cuda/lower_index.cpp | 13 ++++++ .../jit/codegen/cuda/lower_index_compute.cpp | 45 ++++++++++++++++++- .../csrc/jit/codegen/cuda/lower_mem_index.cpp | 41 +++++++++++++++++ torch/csrc/jit/codegen/cuda/runtime/memory.cu | 4 +- 9 files changed, 162 insertions(+), 8 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index a3750898b83d25..019b41f9c08290 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -2508,6 +2508,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { << gen(address_compute->addressTv()) << "," << gen(address_compute->loopIndex()) << "," << gen(address_compute->doubleBufferByteSize()) << ");\n"; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + indent() << gen(address_compute->addressTv()) << "+=" + << genTensorAddressIndex( + address_compute->incrementValue(), + address_compute->dataTv()->dtype()) + << ";\n"; } else { indent() << "//Base Address:::\n"; indent() << gen(address_compute->addressTv()); diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 6cdcc5be36a89e..e333d7aabefc41 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1623,7 +1623,20 @@ std::vector Index::getGlobalProducerStridedIndices( loops, root_dom[i])) { // Add the "predicate peeling offset", see [Predicate Peeling] // to the tensor index if this root domain is predicate peeled. 
- if (tile_entry.value().peel_stage != PredicatePeelStage::Prolog && + + // Incremental mode should add offset at prolog, + // inline mode should be all instances except prolog. + bool is_increment = + std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { + return fl->loopTransformInfo().is_increment_loop; + }); + + bool should_add_offset = + (tile_entry.value().peel_stage != PredicatePeelStage::Prolog && + !producer_tv->shouldLiftReadAddress()) || + (tile_entry.value().peel_stage == PredicatePeelStage::Prolog && + producer_tv->shouldLiftReadAddress() && is_increment); + if (should_add_offset && !tile_entry.value() .for_loop->loopTransformInfo() .is_base_index_loop) { diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 0ee3cd339173de..99f4d310f6bb19 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -138,6 +138,21 @@ AddressCompute::AddressCompute( "IR type only valid for Kernel container."); } +AddressCompute::AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* data_tensor, + TensorIndex* increment_value) + : Expr(passkey, ExprType::AddressCompute), + op_type_(AddressCompute::AddressComputeOpType::GMEM_INCREMENT), + data_tensor_(data_tensor), + address_tensor_(address_tensor), + increment_value_(increment_value) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + AddressCompute::AddressCompute( IrBuilderPasskey passkey, TensorView* data_tv, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 4abc14902793e4..1966022332e724 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -319,7 +319,8 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { BASE_ADDRESS, INCREMENT, DOUBLE_BUFFER_SWITCH, - DOUBLE_BUFFER_UPDATE + DOUBLE_BUFFER_UPDATE, + GMEM_INCREMENT }; explicit AddressCompute( @@ -328,6 +329,13 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { Val* address_tensor, Val* data_tensor); + // Interface for gmem increment + explicit AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* data_tensor, + TensorIndex* increment_value = nullptr); + // Interface for double buffer offset // calculation: explicit AddressCompute( @@ -382,6 +390,10 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { return loop_index_; } + auto incrementValue() const { + return increment_value_; + } + private: // The type of computation this op computes, // currently only do compute address. @@ -395,15 +407,15 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { // data tensor. Val* address_tensor_ = nullptr; - // Tensor that holds the value to increment (INCREMENT MODE only). 
- Val* inc_value_ = nullptr; - // Double buffer switch and update parameters: Val* double_buffer_switch_index_ = nullptr; Val* buffer_size_in_byte_ = nullptr; int loop_offset_ = 0; int stage_number_ = 0; Val* loop_index_ = nullptr; + + // Gmem increment parameters + kir::TensorIndex* increment_value_ = nullptr; }; // Synchronize all blocks in device, implies cooperative group launch is diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 4cdee398e5d897..e34ff746d9887f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -343,6 +343,15 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { } } } + + if (loop_type_ != DoubleBufferLoopStage::CircularInitProlog) { + if (auto address_compute = dynamic_cast(expr)) { + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + cloned_scopes_.back()->push_back(expr); + } + } + } } //! Returns true if the expression is an initialization expr that diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 8b7c538b8c4b03..a8966121f1592b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -933,6 +933,19 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { auto address_record = maybe_address_record.value(); + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + pushBack(IrBuilder::create( + Index::generateAddressTensorIndex( + for_loops_, address_compute->addressTv()->as()), + address_compute->dataTv(), + lowerSrcIndex( + address_record->dataTensor(), + address_record->indexReferenceTensor()) + ->as())); + return; + } + Val* lowered_data_index = nullptr; if (address_record->isRead()) { diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp index 5f75c10c2876cf..978dca52c8a3e0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp @@ -123,9 +123,24 @@ std::unordered_set getZeroIdSetsForAddressCompute( // Checks if this loop nest is calculating base address. bool is_address_tv_calculation = serial_loop->isBaseIndexLoop(); + bool is_increment = + std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { + return fl->loopTransformInfo().is_increment_loop; + }); std::unordered_set zero_ids; + if (is_increment) { + for (auto fl : loops) { + if (fl != serial_loop) { + zero_ids.insert(fl->iter_domain()); + } + } + // Zero everything except the serial loop + // in the case of increment gmem iterator. + return zero_ids; + } + for (auto outer_loop_it = loops.begin(); outer_loop_it != loop_it; outer_loop_it++) { auto outer_loop = *outer_loop_it; @@ -189,6 +204,13 @@ std::unordered_set getZeroIdSetsForAddressCompute( loop_it++; } + if (address_record->isRead() && + address_record->dataTensor()->getMemoryType() == MemoryType::Global) { + // The serial loop id is incremented by the address compute, + // so it could be zeroed here. 
+ zero_ids.insert(address_record->getConcreteSerialLoopId()); + } + return zero_ids; } @@ -230,6 +252,13 @@ IndexingParameters getGlobalIndexParameters( maybe_address_record.value(), loop_indexing.loops()); } + bool is_increment = std::any_of( + loop_indexing.loops().begin(), + loop_indexing.loops().end(), + [](kir::ForLoop* fl) { + return fl->loopTransformInfo().is_increment_loop; + }); + auto& loops = loop_indexing.loops(); auto& loop_domain = loop_indexing.loopDomains(); auto& loop_index_map = index_parameters.initial_concrete_id_index; @@ -253,6 +282,20 @@ IndexingParameters getGlobalIndexParameters( // Default use pre-allocated integers for index loop_index_map[index_domain] = loop->index(); } + + if (is_increment) { + TORCH_INTERNAL_ASSERT(maybe_address_record.has_value()); + if (GpuLower::current()->caMap()->areMapped( + concrete_loop_domain, + maybe_address_record.value()->getConcreteSerialLoopId(), + IdMappingMode::LOOP)) { + // TODO: + // The current restriction on the serial loop makes this ok + // but should eventually use the f(i+1) - f(i) instead + // of a one for the increment calculation. + loop_index_map[index_domain] = GpuLower::current()->kernel()->oneVal(); + } + } } // Derive the halo extents from the loop indexing result. @@ -267,7 +310,7 @@ IndexingParameters getGlobalIndexParameters( // Setup double buffer increment for producer case: // TODO: could unify these double buffer index calculation // in follow ups. - if (index_producer) { + if (index_producer && !maybe_address_record.has_value()) { auto double_buffer_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( loop_indexing.consumerTv(), loops, true); diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp index 670d6507d24b1d..722bc908b5e0a5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp @@ -1077,6 +1077,8 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { registerInsertBefore(loop, outermost_innermost.first); auto data_tensor = insertion_info.address_compute_record->dataTensor(); + + // Insert double buffer increment if ((data_tensor->isDoubleBuffered() || data_tensor->isCircularBuffered()) && insertion_info.address_compute_record->isWrite()) { @@ -1107,6 +1109,22 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { db_outer_inner.second->body().push_back(update_expr); loop->body().push_back(db_outer_inner.first); } + + // Insert gmem increment: + if (data_tensor->getMemoryType() == MemoryType::Global && + insertion_info.address_compute_record->isRead()) { + auto increment_loop_vector = + createIncrementLoop(insertion_info.loop_nest); + auto increment_loop_outer_inner = + scope_utils::makeLoopNest(increment_loop_vector); + + auto inc_expr = SimplifyingIrBuilder::create( + insertion_info.address_compute_record->addressTensor(), + insertion_info.address_compute_record->dataTensor()); + + increment_loop_outer_inner.second->body().push_back(inc_expr); + loop->body().push_back(increment_loop_outer_inner.first); + } } } @@ -1140,6 +1158,29 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { original_loop->loopTransformInfo().baseIndexLoop()); } + std::vector createIncrementLoop( + std::vector address_compute_loop_vector) { + std::vector loop_nest_to_clone( + std::next(address_compute_loop_vector.begin()), + address_compute_loop_vector.end()); + + std::vector cloned_loop_nest; + for (auto fl : loop_nest_to_clone) { + 
cloned_loop_nest.push_back(IrBuilder::create( + fl->iter_domain(), + fl->index(), + fl->start(), + fl->stop(), + fl->step(), + fl->vectorize(), + fl->vectorize_shift(), + fl->isUnrollRequired(), + fl->loopTransformInfo().incrementLoop())); + } + + return cloned_loop_nest; + } + std::vector createAddressComputeLoop( AddressRecord* address_record) { // Find the loop in the current loop nest that maps the concrete serial loop diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 395adab9868fed..d6f9a6d173c3fb 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -303,7 +303,7 @@ DEVICE_INLINE void cpAsync( byte_size == 4 || byte_size == 8 || byte_size == 16, "cp_async : unsupported byte size"); - gmem_ptr += gmem_index; + // gmem_ptr += gmem_index; asm volatile( "{\n" " .reg .pred p;\n" @@ -313,7 +313,7 @@ DEVICE_INLINE void cpAsync( "l"(gmem_ptr), "n"(byte_size), "r"((int)predicate)); - gmem_ptr -= gmem_index; + // gmem_ptr -= gmem_index; } // TODO: Might have a different category of sync if we want to build out this: From 9f731a3c9cb28836c097583f39a109b59f7e7666 Mon Sep 17 00:00:00 2001 From: shmsong Date: Sun, 11 Sep 2022 16:14:44 -0700 Subject: [PATCH 11/17] [hack] decrement index --- torch/csrc/jit/codegen/cuda/codegen.cpp | 8 +++ torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 6 +- torch/csrc/jit/codegen/cuda/kernel_ir.h | 10 +++- .../jit/codegen/cuda/lower_double_buffer.cpp | 55 +++++++++++++++++++ torch/csrc/jit/codegen/cuda/lower_index.cpp | 7 ++- 5 files changed, 81 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 019b41f9c08290..d948165b54068e 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -2516,6 +2516,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { address_compute->incrementValue(), address_compute->dataTv()->dtype()) << ";\n"; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_DECREMENT) { + indent() << gen(address_compute->addressTv()) << "-=" + << genTensorAddressIndex( + address_compute->incrementValue(), + address_compute->dataTv()->dtype()) + << ";\n"; } else { indent() << "//Base Address:::\n"; indent() << gen(address_compute->addressTv()); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 99f4d310f6bb19..87506ab2e1f4e3 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -142,12 +142,16 @@ AddressCompute::AddressCompute( IrBuilderPasskey passkey, Val* address_tensor, Val* data_tensor, - TensorIndex* increment_value) + TensorIndex* increment_value, + bool is_decrement) : Expr(passkey, ExprType::AddressCompute), op_type_(AddressCompute::AddressComputeOpType::GMEM_INCREMENT), data_tensor_(data_tensor), address_tensor_(address_tensor), increment_value_(increment_value) { + if (is_decrement) { + op_type_ = AddressCompute::AddressComputeOpType::GMEM_DECREMENT; + } TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 1966022332e724..269c7bfb396529 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -320,7 +320,8 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { 
INCREMENT, DOUBLE_BUFFER_SWITCH, DOUBLE_BUFFER_UPDATE, - GMEM_INCREMENT + GMEM_INCREMENT, + GMEM_DECREMENT }; explicit AddressCompute( @@ -334,7 +335,8 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { IrBuilderPasskey passkey, Val* address_tensor, Val* data_tensor, - TensorIndex* increment_value = nullptr); + TensorIndex* increment_value = nullptr, + bool is_decrement = false); // Interface for double buffer offset // calculation: @@ -394,6 +396,10 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { return increment_value_; } + bool isDecrement() const { + return op_type_ == AddressComputeOpType::GMEM_DECREMENT; + } + private: // The type of computation this op computes, // currently only do compute address. diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index e34ff746d9887f..64aadd9fd911ff 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -150,6 +150,39 @@ bool requireEpilogue(const std::vector& exprs) { }); } +bool isGmemIncrement(Expr* expr) { + if (auto loop = dynamic_cast(expr)) { + if (loop->body().exprs().size() != 1) { + return false; + } + return isGmemIncrement(loop->body().exprs()[0]); + } else if (auto address_compute = dynamic_cast(expr)) { + return address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT; + } + return false; +} + +kir::ForLoop* hoistGmemIncrement(kir::ForLoop* fl) { + auto hoisted_loop = IrBuilder::create(fl); + + // insert all gmem increment exprs + for (auto expr : fl->body().exprs()) { + if (isGmemIncrement(expr)) { + hoisted_loop->body().push_back(expr); + } + } + + // insert all non gmem increment exprs + for (auto expr : fl->body().exprs()) { + if (!isGmemIncrement(expr)) { + hoisted_loop->body().push_back(expr); + } + } + + return hoisted_loop; +} + // Replicates double buffer loops for Prologue, Main, and // Epilogue. Prologue only copies the load expressions of double // buffered tensors, whereas Epilogue does any expression other than @@ -261,6 +294,14 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { cloned_top_level_loop_->body().push_back( IrBuilder::create()); } + + if (loop_type_ == DoubleBufferLoopStage::Main && + std::any_of( + double_buffer_loop_->body().exprs().begin(), + double_buffer_loop_->body().exprs().end(), + isGmemIncrement)) { + cloned_top_level_loop_ = hoistGmemIncrement(cloned_top_level_loop_); + } } void handle(kir::ForLoop* fl) final { @@ -334,6 +375,20 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { } } + if (loop_type_ == DoubleBufferLoopStage::CircularInitProlog) { + if (auto address_compute = dynamic_cast(expr)) { + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + cloned_scopes_.back()->push_back( + IrBuilder::create( + address_compute->addressTv(), + address_compute->dataTv(), + address_compute->incrementValue(), + true /* is_decrement */)); + } + } + } + // Need the double buffer update expr in prologs too. 
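// Illustrative host-side sketch (not from the patch) of why hoisting the
// pointer increment preserves the addresses touched: incrementing at the top
// of each iteration, compensated by a single decrement issued once before the
// main loop (roughly what the CircularInitProlog decrement provides), visits
// the same addresses as incrementing at the bottom. Step and trip count here
// are hypothetical.
#include <cassert>

int main() {
  constexpr long long step = 128; // bytes advanced per iteration
  constexpr int iters = 8;

  // Reference: increment at the bottom of each iteration.
  long long ptr_a = 0;
  long long visited_a[iters];
  for (int i = 0; i < iters; ++i) {
    visited_a[i] = ptr_a; // address the load would use
    ptr_a += step;
  }

  // Hoisted: decrement once up front, then increment at the top.
  long long ptr_b = 0 - step; // compensate for the first iteration
  long long visited_b[iters];
  for (int i = 0; i < iters; ++i) {
    ptr_b += step;
    visited_b[i] = ptr_b;
  }

  for (int i = 0; i < iters; ++i) {
    assert(visited_a[i] == visited_b[i]);
  }
  return 0;
}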
if (loop_type_ == DoubleBufferLoopStage::Prolog) { if (auto address_compute = dynamic_cast(expr)) { diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index a8966121f1592b..a247bc1e55b285 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -934,7 +934,9 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { auto address_record = maybe_address_record.value(); if (address_compute->opType() == - kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT || + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_DECREMENT) { pushBack(IrBuilder::create( Index::generateAddressTensorIndex( for_loops_, address_compute->addressTv()->as()), @@ -942,7 +944,8 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { lowerSrcIndex( address_record->dataTensor(), address_record->indexReferenceTensor()) - ->as())); + ->as(), + address_compute->isDecrement())); return; } From 334e81a0138aec4a033c2069bba5ad3f295f2e56 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 19 Sep 2022 14:33:03 -0700 Subject: [PATCH 12/17] rebase fix --- torch/csrc/jit/codegen/cuda/codegen.cpp | 1 + torch/csrc/jit/codegen/cuda/index_compute.cpp | 2 +- torch/csrc/jit/codegen/cuda/kernel_ir.h | 10 ++++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index d948165b54068e..4f2bb90a51e4ac 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -524,6 +524,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { index << " + "; } index << genInline(ti->uniformAddress()); + first = false; } if (first) { diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index e333d7aabefc41..6f68c45424fb5e 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1623,7 +1623,7 @@ std::vector Index::getGlobalProducerStridedIndices( loops, root_dom[i])) { // Add the "predicate peeling offset", see [Predicate Peeling] // to the tensor index if this root domain is predicate peeled. - + // Incremental mode should add offset at prolog, // inline mode should be all instances except prolog. bool is_increment = diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 269c7bfb396529..d2d5e643a0a8a4 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -542,6 +542,10 @@ struct LoopTransformInfo { //! lifted memory address. bool is_base_index_loop = false; + //! Tracks if this for loop is for calculating inductive variable + //! increments. + bool is_increment_loop = false; + //! Setter API LoopTransformInfo& doubleBufferStage(DoubleBufferLoopStage stage) { double_buffer_loop_stage = stage; @@ -559,6 +563,12 @@ struct LoopTransformInfo { predicate_peel_stage = stage; return *this; } + + // ! Setter API + LoopTransformInfo& incrementLoop() { + is_increment_loop = true; + return *this; + } }; //! ForLoop provides scoping around an int iterator from 0 to range. 
Exprs From 139368b5abdfea7bca2b89792a890fb68cc3c890 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 19 Sep 2022 16:02:56 -0700 Subject: [PATCH 13/17] clean up --- torch/csrc/jit/codegen/cuda/codegen.cpp | 11 ++++------- torch/csrc/jit/codegen/cuda/executor.cpp | 3 +-- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 7 +++++++ torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 3 +++ 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 4f2bb90a51e4ac..bc028c8483e65b 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -541,13 +541,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { code_ << "*(volatile " << ti->getDataType().value() << "*)&"; } - if (ti->hasBaseAddress()) { - // WAR path to generate a tensor index with pointer content. - code_ << "reinterpret_cast<" << ti->view()->dtype() << "*>(" - << gen(ti->baseAddress()) << ")" - << "[" << genTensorIndex(ti) << "]"; - return; - } code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]"; } @@ -715,6 +708,10 @@ class CudaKernelGenerator : private OptOutConstDispatch { auto ti = uop->out()->as(); // Special case branch for smem reset // FIXME: only support filling zero at the moment: + // could possibly extend. + TORCH_INTERNAL_ASSERT( + uop->in()->isZero(), "only support filling zero in smem reset"); + indent() << "smemReset<" << ti->view()->dtype() << "," << vector_word_size << ">(" << gen(ti->baseAddress()) << "+" << genTensorAddressIndex(ti, ti->view()->dtype()) diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index eb47038a440828..9b2f5bf30bf951 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -1101,8 +1101,6 @@ void FusionExecutor::compileRtc( const std::string& code, const std::string& name, bool structured) { - options_ = CompileOptions(); - options_.index_mode = KernelIndexMode::INT32; FUSER_PERF_SCOPE("ExecutorRunFusion::compileRtc"); std::string scode; if (!structured) { @@ -1111,6 +1109,7 @@ void FusionExecutor::compileRtc( scode = code; } fusion_id_ = 1; + options_ = CompileOptions(); std::tie(compiled_kernel_, last_compiler_log_) = executor_utils::nvrtcCompile(scode, name, fusion_id_); diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 5e4b0a6f2c0f65..be34dedd6b9e98 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -257,6 +257,13 @@ bool Val::isZeroInt() const { return int_val.has_value() && int_val.value() == 0; } +bool Val::isZero() const { + auto int_val = getInt(); + auto double_val = getDouble(); + return (int_val.has_value() && int_val.value() == 0) || + (double_val.has_value() && double_val.value() == 0); +} + bool Val::isOneInt() const { auto int_val = getInt(); return int_val.has_value() && int_val.value() == 1; diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 53969924dd8398..048aa71b0f94aa 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -284,6 +284,9 @@ class TORCH_CUDA_CU_API Val : public Statement { bool isZeroInt() const; bool isOneInt() const; + // Check zero supporting both int or double. 
+ bool isZero() const; + // Returns the Expr that this value is an output of, returns nullptr if none // was found Expr* definition() const { From 211cc5c508426e51bf8bb428bda2a948d330a1e3 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 19 Sep 2022 19:46:19 -0700 Subject: [PATCH 14/17] comment ; cleanup --- torch/csrc/jit/codegen/cuda/codegen.cpp | 35 ++++++ torch/csrc/jit/codegen/cuda/index_compute.cpp | 23 +++- torch/csrc/jit/codegen/cuda/kernel_ir.h | 36 +++++- .../jit/codegen/cuda/lower_double_buffer.cpp | 105 ++++++++++++++++-- .../jit/codegen/cuda/lower_double_buffer.h | 4 + torch/csrc/jit/codegen/cuda/lower_index.cpp | 18 ++- .../jit/codegen/cuda/lower_index_compute.cpp | 21 +++- .../csrc/jit/codegen/cuda/lower_mem_index.cpp | 89 ++++++++++++++- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 15 +++ torch/csrc/jit/codegen/cuda/lower_utils.h | 4 + torch/csrc/jit/codegen/cuda/runtime/memory.cu | 68 ++---------- 11 files changed, 331 insertions(+), 87 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index bc028c8483e65b..e0222313e2534b 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -504,6 +504,9 @@ class CudaKernelGenerator : private OptOutConstDispatch { return index.str(); } + // Generate the tensor index that are directly added + // to a base address pointer. So all the components + // are computed in units of bytes. std::string genTensorAddressIndex( const kir::TensorIndex* ti, DataType dtype) { @@ -514,11 +517,17 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (!first) { index << " + "; } + + // Multiply all the components here by the size of the data + // type to get byte offset. index << "(" << genInline(ind) << ")*" << dataTypeSize(dtype); first = false; } } + // If there is a uniform component in this tensor index, + // just add them too. + // See also, [Double Buffer Uniform Offset]. if (ti->uniformAddress() != nullptr) { if (!first) { index << " + "; @@ -541,6 +550,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { code_ << "*(volatile " << ti->getDataType().value() << "*)&"; } + if (ti->hasBaseAddress()) { + // WAR path to generate a tensor index with pointer content. + code_ << "reinterpret_cast<" << ti->view()->dtype() << "*>(" + << gen(ti->baseAddress()) << ")" + << "[" << genTensorIndex(ti) << "]"; + return; + } + code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]"; } @@ -576,6 +593,12 @@ class CudaKernelGenerator : private OptOutConstDispatch { return ss.str(); } + //! Generates the given value as a pointer address as + //! either: + //! 1. hosted_base_ptr + address_index + //! 2. &Tensor[index] + //! depending on if the given index value carries + //! a hoisted component or not. std::string genMaybeHoistedPointer(const Val* val) { auto ti = dynamic_cast(val); TORCH_INTERNAL_ASSERT(ti != nullptr, "only support tensor index input"); @@ -705,6 +728,11 @@ class CudaKernelGenerator : private OptOutConstDispatch { } else if ( uop->out()->isA() && uop->out()->as()->useSmemAddress()) { + // A special resource string "smemReset" is used if + // the unary op writes to the shared memory using + // the lifted 32b shared mem pointer. + // This mode is reserved for resetting shared memory + // space at the moment currently. 
auto ti = uop->out()->as(); // Special case branch for smem reset // FIXME: only support filling zero at the moment: @@ -2491,6 +2519,13 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const kir::AddressCompute* address_compute) final { + // FIXME: + // All the global/shared memory address/offset manipulations + // to reduce register usage are currently lumped into this single + // kernel IR operator. + // + // If there's any need to commit to the current codegen tweaks + // longer, could consider separating them into more IR nodes. if (address_compute->opType() == kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH) { indent() << "doubleBufferSwitch<" << address_compute->stageNumber() << "," diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 6f68c45424fb5e..4eb73e2005e976 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1625,7 +1625,7 @@ std::vector Index::getGlobalProducerStridedIndices( // to the tensor index if this root domain is predicate peeled. // Incremental mode should add offset at prolog, - // inline mode should be all instances except prolog. + // See Note [Predicate Peeing interaction with Incremental Offset] bool is_increment = std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { return fl->loopTransformInfo().is_increment_loop; @@ -1935,6 +1935,16 @@ std::vector Index::getNonGlobalProducerStridedIndices( GpuLower::current()->doubleBufferInfo().getReadSwitchIndex( producer_tv); + // The double buffer switching indices are now applied in two + // different ways, depending on if the index is lifted or not. + // + // When lifted, the double buffer switching index is computed + // separately as a "double buffer offset" and added to the + // uniform section of the tensor index. + // When not lifted, the behavior stays the same as before + // i.e. they are computed inline. + // See also: + // [Double Buffer Uniform Offset]. if (!maybe_read_offset.has_value()) { auto loop_index = db_loop->isTrivial() ? db_loop->start() : db_loop->index(); @@ -2277,7 +2287,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( if ((consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) // Lifted address case the double buffer offset is // computed inplace into the write address buffer. - && !consumer_tv->shouldLiftWriteAddress()) { + // See [Inplace double buffer update] + && !useDirectSmemAddress(consumer_tv)) { auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops); TORCH_INTERNAL_ASSERT( @@ -2367,6 +2378,10 @@ kir::TensorIndex* Index::getProducerIndex( const std::vector& loops) { auto strided_indices = getProducerStridedIndices(producer, consumer, loops); + // Insert base address and uniform components into the tensor + // index object directly to support separating them on the + // code gen interface. + // See also: [Pointer Addressing In Lifted Indices] if (shouldUseLiftedAddress(producer, consumer, loops)) { auto maybe_address_record = GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( @@ -2414,6 +2429,10 @@ kir::TensorIndex* Index::getConsumerIndex( const std::vector& loops) { auto strided_indices = getConsumerStridedIndices(consumer, loops); + // Insert base address and uniform components into the tensor + // index object directly to support separating them on the + // code gen interface. 
+ // See also: [Pointer Addressing In Lifted Indices] if (shouldUseLiftedAddress(consumer, consumer, loops)) { auto maybe_address_record = GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index d2d5e643a0a8a4..fa6470589caff1 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -316,21 +316,31 @@ class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { class TORCH_CUDA_CU_API AddressCompute final : public Expr { public: enum class AddressComputeOpType { + // Calculate base address for lifted memory index BASE_ADDRESS, - INCREMENT, + // Switch a double buffer index register, + // see [Uniform Double Buffer Offset] DOUBLE_BUFFER_SWITCH, + // Inplace update a double buffered address + // see [Inplace Double Buffer Update] DOUBLE_BUFFER_UPDATE, + // Inplace increment a global address, see + // see [Gmem address increment] GMEM_INCREMENT, + // Inplace increment a global address, see + // see [Gmem Increment Hoisting] GMEM_DECREMENT }; + // Constructor for BASE_ADDRESS mode calculation + // (Default). explicit AddressCompute( IrBuilderPasskey passkey, AddressComputeOpType op_type, Val* address_tensor, Val* data_tensor); - // Interface for gmem increment + // Constructor for gmem increment explicit AddressCompute( IrBuilderPasskey passkey, Val* address_tensor, @@ -338,7 +348,7 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { TensorIndex* increment_value = nullptr, bool is_decrement = false); - // Interface for double buffer offset + // Constructor for double buffer offset // calculation: explicit AddressCompute( IrBuilderPasskey passkey, @@ -349,7 +359,7 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { int stage_number, Val* loop_index = nullptr); - // Interface for double buffer offset + // Constructor for double buffer offset // inplace update: explicit AddressCompute( IrBuilderPasskey passkey, @@ -413,14 +423,28 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { // data tensor. Val* address_tensor_ = nullptr; - // Double buffer switch and update parameters: + // Double buffer switch and update parameters below: + + // The switching index that this op is updating. Val* double_buffer_switch_index_ = nullptr; + + // The original buffer alloc size used for double buffer + // update calculation. Val* buffer_size_in_byte_ = nullptr; + + // The double buffer loop offset that is used for + // computing the double buffer size update. int loop_offset_ = 0; + + // The double buffer loop offset that is used for + // computing the double buffer size update. int stage_number_ = 0; + + // The double buffer loop index. Val* loop_index_ = nullptr; - // Gmem increment parameters + // Gmem increment parameters below: + // The increment value to apply to the pointer. kir::TensorIndex* increment_value_ = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 64aadd9fd911ff..efc695b8b4f3fd 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -163,6 +163,36 @@ bool isGmemIncrement(Expr* expr) { return false; } +//! Hoists the gmem increment ops to the beginning of the loop +//! within the scope of the given loop. +//! Note: [Gmem Increment Hoisting] +//! +//! This optimization is very useful when inplace increment +//! is used on the global memory pointers. +//! 
Before this optimization, the code would look like: +//! +//! for i in ... // main loop +//! load.global ... [ptr] +//! // Here we actually have an anti-dependency (WAR) on +//! // the register holding ptr and could result in +//! // non-ideal performance when we do not have enough +//! // instructions to put between the load and the increment. +//! // depending on how many other instructions we have +//! // within this loop. +//! ptr += increment_value +//! +//! After this transformation, the code looks like: +//! ptr -=increment_value // a naive way to compensate +//! // for the first iter. +//! for i in ... // main loop +//! ptr += increment_value +//! // This is actually ok as integer instructions +//! // are usually much faster than memory. +//! load.global ... [ptr] +//! +//! This function hoists the pointer increments, in the given +//! loop, assuming that the decrements have been inserted +//! on the CircularInitProlog stage. kir::ForLoop* hoistGmemIncrement(kir::ForLoop* fl) { auto hoisted_loop = IrBuilder::create(fl); @@ -265,15 +295,32 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { for (auto load : double_buffer_load_exprs_) { if (auto tv_out = ir_utils::getTvOutput(load)) { + // calculate the switching size + auto switch_size = db_info.getOriginalAllocSize(tv_out); + auto switch_size_in_byte = SimplifyingIrBuilder::mulExpr( + switch_size, + SimplifyingIrBuilder::create(dataTypeSize(tv_out->dtype()))); + + // insert db switch expressions: + // Note:[Uniform Double Buffer Offset] + // This modification is to encourage usage of uniform registers on + // sm75+ when + // accessing shared memory double buffered tensors. + // The code before transformation: + // for i in ... // double buffer loop + // ... = ld.shared [... + (i%5) * double_buffer_size] + // The above code doesn't explictly specify that the double buffer + // switch + // component is uniform. The following transformed code makes it + // explicit: + // for i in ... // double buffer loop + // ... = ld.shared [... + switch_index] + // doubleBufferSwitch(switch_index); + // So that the double buffer indices are all placed in uniform reg. + auto maybe_read_index = db_info.getReadSwitchIndex(tv_out); if (maybe_read_index.has_value()) { - // insert db switch: - auto switch_size = db_info.getOriginalAllocSize(tv_out); - auto switch_size_in_byte = SimplifyingIrBuilder::mulExpr( - switch_size, - SimplifyingIrBuilder::create( - dataTypeSize(tv_out->dtype()))); - + // Instantiate and insert the update operator. auto address_compute = SimplifyingIrBuilder::create( tv_out, @@ -295,11 +342,23 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { IrBuilder::create()); } + // Hoist the address increment in the double buffer main + // loop, see also [Gmem Increment Hoisting] if (loop_type_ == DoubleBufferLoopStage::Main && std::any_of( double_buffer_loop_->body().exprs().begin(), double_buffer_loop_->body().exprs().end(), - isGmemIncrement)) { + isGmemIncrement) && + // FIXME: + // Below is current condition that is required for gmem increment + // hoisting because the gmem decrement is currently placed in + // CircularInitProlog which requires predicate peeling to + // be generated. + // To fix this should probably dedicate another double buffer + // loop stage, maybe GmemPointerDecrement, that is reserved + // for placing the gmem decrement before the main loop stage. 
+ GpuLower::current()->predicatePeelingInfo().shouldPeelLoop( + double_buffer_loop_)) { cloned_top_level_loop_ = hoistGmemIncrement(cloned_top_level_loop_); } } @@ -376,6 +435,8 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { } if (loop_type_ == DoubleBufferLoopStage::CircularInitProlog) { + // Convert the address compute ops to decrement in the circular + // buffer init prolog, see [Gmem Increment Hoisting]. if (auto address_compute = dynamic_cast(expr)) { if (address_compute->opType() == kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { @@ -389,12 +450,21 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { } } - // Need the double buffer update expr in prologs too. + // Include the double buffer update expressions in prologs too as + // prolog does write into the double buffered space. if (loop_type_ == DoubleBufferLoopStage::Prolog) { if (auto address_compute = dynamic_cast(expr)) { if (address_compute->opType() == kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { - cloned_scopes_.back()->push_back(expr); + if (std::any_of( + double_buffer_load_exprs_.begin(), + double_buffer_load_exprs_.end(), + [address_compute](Expr* expr) { + return ir_utils::getTvOutput(expr)->sameAs( + address_compute->dataTv()); + })) { + cloned_scopes_.back()->push_back(expr); + } } } } @@ -599,20 +669,31 @@ class DoubleBufferInserter : private kir::ExprMutator { void insert( kir::ForLoop* double_buffer_loop, const std::vector& loads) { - // Insert double buffer read switch index + // Allocate read switching index if they need to be updated + // independently. see [Uniform Double Buffer Offset] for (auto load : loads) { if (auto load_output = dynamic_cast(load->output(0))) { + auto uses = load_output->fusion()->unordered_uses(load_output); if (load_output->getMemoryType() == MemoryType::Shared && (load_output->isDoubleBuffered() || load_output->isCircularBuffered()) && - load_output->shouldLiftReadAddress()) { + load_output->shouldLiftReadAddress() && + // TODO: read switch index is only enabled for ldmatrix + // at the moment. + // Would need to extend the ld.shared usage to directly + // take pointers to use this in other cases. + std::all_of(uses.begin(), uses.end(), ir_utils::isLdMatrixOp)) { auto switch_val = IrBuilder::create(); switch_val->to32b(); + // Record the read switch indexing variable so it can be + // used in the indexing pass. // TODO: maybe want to do this in id graph instead GpuLower::current()->doubleBufferInfo().setReadSwitchIndex( load_output, switch_val); + // Place allocation for the switching variable before the + // double buffer loop. auto index_alloc = IrBuilder::create( switch_val, MemoryType::Local, diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h index 36538993506ef5..5f49b3b75a4726 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -226,12 +226,16 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { //! skew double buffer transform within the given outer loop. bool isLowerPrologWithin(IterDomain* db_loop, IterDomain* outer_loop); + //! Record the allocated double buffer switching index, + //! see [Uniform Double Buffer Offset] void setReadSwitchIndex(TensorView* db_tv, Val* switch_index) { TORCH_INTERNAL_ASSERT( read_switch_index_map_.insert(std::make_pair(db_tv, switch_index)) .second); } + //! Returns the double buffer switching index if one has been + //! allocated and recorded for the given tv. 
c10::optional getReadSwitchIndex(TensorView* db_tv) { auto val_it = read_switch_index_map_.find(db_tv); if (val_it == read_switch_index_map_.end()) { diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index a247bc1e55b285..419dae1f7d394d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -122,6 +122,12 @@ void IndexLowering::handle(const UnaryOp* uop) { const auto in = lowerSrcIndex(uop->in(), uop->out()); const auto out = lowerDstIndex(uop->out()); + // Convert the output index to direct shared memory + // address usage if this unary op is initialization + // for cp.async op. see [Lifting smem address decoding for cp.async] + // In order to use the same register for indexing the init + // expression as well, the init expr also needs to + // directly use the shared memory address. if (ir_utils::isCpAsyncInit(uop)) { auto out_tv = ir_utils::getTvOutput(uop); if (out_tv->shouldLiftWriteAddress()) { @@ -881,7 +887,8 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { // Logic for double buffer switching: if (address_compute->opType() == kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH) { - // no indexing is needed, just forward through. + // no indexing is needed, just forward through the expression and + // attach the loop index corresponding to the double buffer loop. auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( address_compute->dataTv()->as(), for_loops_, false); TORCH_INTERNAL_ASSERT(db_loop != nullptr); @@ -898,6 +905,7 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { } else if ( address_compute->opType() == kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { + // Unpack the double buffer loop and double buffer allocation component auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( address_compute->dataTv()->as(), for_loops_, false); TORCH_INTERNAL_ASSERT(db_loop != nullptr); @@ -907,6 +915,7 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { ? address_compute->stageNumber() - 1 : 0; + // Generate index into the address tensor to update. auto indexed_address_tv = Index::generateAddressTensorIndex( for_loops_, address_compute->addressTv()->as()); @@ -920,8 +929,6 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { return; } - // Logic for double buffer updating: - // Logic for base address computation auto address_tv = address_compute->addressTv(); @@ -937,6 +944,9 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT || address_compute->opType() == kir::AddressCompute::AddressComputeOpType::GMEM_DECREMENT) { + // GMEM_INCREMENT/DECREMENT is only used on global producer tv + // currently, so only lowering source index for the address tensor + // to compute the amount of increment. pushBack(IrBuilder::create( Index::generateAddressTensorIndex( for_loops_, address_compute->addressTv()->as()), @@ -951,6 +961,8 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { Val* lowered_data_index = nullptr; + // This is the base address generation logic, lowering src/dst indexing + // math based on if this record is read or write. 
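// Illustrative sketch (not from the patch): the base addresses lowered here
// are consumed as raw byte pointers, with element indices pre-scaled by the
// data type size (as genTensorAddressIndex does), so typed indexing and
// byte-offset addressing reach the same element. A minimal host-side check
// of that equivalence, with hypothetical data:
#include <cassert>

int main() {
  float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  // A "DataPointer"-style raw byte pointer to the buffer.
  char* base = reinterpret_cast<char*>(data);
  for (int i = 0; i < 8; ++i) {
    // Element indexing vs. a pre-scaled byte offset added to the base pointer.
    float via_elements = reinterpret_cast<float*>(base)[i];
    float via_bytes =
        *reinterpret_cast<float*>(base + i * static_cast<long>(sizeof(float)));
    assert(via_elements == via_bytes);
  }
  return 0;
}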
if (address_record->isRead()) { lowered_data_index = lowerSrcIndex( address_record->dataTensor(), address_record->indexReferenceTensor()); diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp index 978dca52c8a3e0..d5af147c46c429 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp @@ -123,6 +123,9 @@ std::unordered_set getZeroIdSetsForAddressCompute( // Checks if this loop nest is calculating base address. bool is_address_tv_calculation = serial_loop->isBaseIndexLoop(); + + // Check if this loop nest is incrementing a gmem address, + // see [Gmem address increment]; bool is_increment = std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { return fl->loopTransformInfo().is_increment_loop; @@ -131,6 +134,8 @@ std::unordered_set getZeroIdSetsForAddressCompute( std::unordered_set zero_ids; if (is_increment) { + // In the case of increment calculation, just zero + // every loop except the serial loop from the address record. for (auto fl : loops) { if (fl != serial_loop) { zero_ids.insert(fl->iter_domain()); @@ -206,8 +211,10 @@ std::unordered_set getZeroIdSetsForAddressCompute( if (address_record->isRead() && address_record->dataTensor()->getMemoryType() == MemoryType::Global) { - // The serial loop id is incremented by the address compute, - // so it could be zeroed here. + // The serial loop is converted to increment mode, see [Gmem address + // increment] + // so it can be zeroed always. + // See also [Separability Analysis] on conditions when this is enabled. zero_ids.insert(address_record->getConcreteSerialLoopId()); } @@ -289,10 +296,16 @@ IndexingParameters getGlobalIndexParameters( concrete_loop_domain, maybe_address_record.value()->getConcreteSerialLoopId(), IdMappingMode::LOOP)) { + // For the increment calculation, the current implementation + // inserts a one for the loop index corresponding to the serial + // loop. This is valid if [Separability Analysis] checks ok + // on the serial id. // TODO: - // The current restriction on the serial loop makes this ok + // The current Separability restriction on the serial loop makes this + // ok // but should eventually use the f(i+1) - f(i) instead - // of a one for the increment calculation. + // of a one for the increment calculation to enable more complex + // increment patterns. loop_index_map[index_domain] = GpuLower::current()->kernel()->oneVal(); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp index 722bc908b5e0a5..7dd6121b931ceb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp @@ -131,6 +131,22 @@ void AddressComputeInfo::build(Fusion* fusion) { if (!ir_utils::isTvOp(expr)) { continue; } + + if (ir_utils::isCpAsyncOp(expr)) { + auto in_tv = ir_utils::getTvInput(expr); + auto out_tv = ir_utils::getTvOutput(expr); + + // FIXME: + // It'd take 2 more variants of the resource string for cp.async + // to support lifting one of the producer/consumer indices. As + // the eventual goal of these analysis is to be turned on generically, + // the use case for lifting one of the components is limited so + // not prioritizing. 
+ TORCH_INTERNAL_ASSERT( + in_tv->shouldLiftReadAddress() == out_tv->shouldLiftWriteAddress(), + "For cp.async op only support either lifting both producer and consumer indexing or neither."); + } + for (auto consumer_tv : ir_utils::filterByType(expr->outputs())) { if (consumer_tv->shouldLiftWriteAddress()) { @@ -968,6 +984,17 @@ TensorView* AddressComputeInfo::makeAddressTv( bool is_predicate_index, bool is_cpasync_write) { DataType dtype = is_predicate_index ? DataType::Index : DataType::Pointer; + + // Note: [Lifting smem address decoding for cp.async] + // A trick that saves register usage. + // Before: + // char* smem_ptr; + // for i in ... // main loop + // cp.async [smem_ptr + 123], ... + // After: + // unsigned smem_address = toSmem(smem_ptr); + // for i in ... // main loop + // cp.async [smem_addres+123], ... if (is_cpasync_write) { dtype = DataType::SmemAddress; } @@ -1078,10 +1105,44 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { auto data_tensor = insertion_info.address_compute_record->dataTensor(); - // Insert double buffer increment + // Insert double buffer increment: + // Note: [Inplace Double Buffer Update]: + // + // The trick used in [Uniform Double Buffer Offset] should be the default + // method of handling double buffer switching index when trying to save + // general purpose registers. But there are 2 exceptions: + // 1. On sm70 or below, there are no unifrom regs to use. + // 2. With cp.async, the consumer shared memory buffer currently + // does not provide access to the uniform reg operand so we could not use + // it. (will be actively discussed internally) + // + // To still avoid using too many registers on double buffered access, + // another code gen trick is used here, to enable near term progress: + // The code before transformation: + // for i in ... // double buffer loop + // ... = ld.shared [... + (i%5) * double_buffer_size] + // The code after transformation: + // R0 = ... + // for i in ... // double buffer loop + // ... = ld.shared [R0] + // doubleBufferUpdate(R0); + // This way essentially the double buffer offset is calculated inplace + // into R0 in each double buffer loop iteration. Note that comparing with + // [Uniform Double Buffer Offset] this method uses more instructions as + // all of the pointers will need to be updated, while using uniform regs + // will only need to update the uniform switch index. + + // FIXME: should move this logic into lower_double_buffer.cpp. + // will need to formulate into a separate pass as it needs to + // clone the loop nest. if ((data_tensor->isDoubleBuffered() || data_tensor->isCircularBuffered()) && - insertion_info.address_compute_record->isWrite()) { + insertion_info.address_compute_record->isWrite() && + // Only have support doubleBufferUpdate for + // direct smem access for now. + // FIXME: + // Would need to extend to use this on Volta. + useDirectSmemAddress(data_tensor)) { // Insert double buffer index update if it is a double buffered write: // The insertion info loop nest starts with the serial loop, // in the double buffer update we need to insert into the original @@ -1090,8 +1151,11 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { std::next(insertion_info.loop_nest.begin()), insertion_info.loop_nest.end()); + // Clone the loop nest containing the double buffered write + // expression for the consumer index update. auto db_outer_inner = scope_utils::makeLoopNest(db_loop_nest); + // Calculate the double buffer size. 
auto& db_info = GpuLower::current()->doubleBufferInfo(); auto db_size_in_byte = SimplifyingIrBuilder::mulExpr( @@ -1099,6 +1163,8 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { SimplifyingIrBuilder::create( dataTypeSize(data_tensor->dtype()))); + // Create the double buffer update expression and insert + // them at the end of the double buffer loop. auto update_expr = SimplifyingIrBuilder::create( insertion_info.address_compute_record->addressTensor(), db_size_in_byte, @@ -1111,13 +1177,30 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { } // Insert gmem increment: + // Note [Gmem address increment]: + // This is a trick that helps lifting some instructions out of main + // loop. + // The code before this transformation: + // R0 = ... + // for i in ... // The serial loop on index lifting record + // x = ld.global [i*123 + R0] + // The code after transformation: + // R0 = ... + // for i in ... // The serial loop on index lifting record + // x = ld.global [R0] + // R0+=123; + // Note that [Separability Analysis] will be checked on the serial + // loop when creating these address records so doing this transformation + // on the serial loop index variable is safe. if (data_tensor->getMemoryType() == MemoryType::Global && insertion_info.address_compute_record->isRead()) { + // Create the loopnest to contain the increment expression. auto increment_loop_vector = createIncrementLoop(insertion_info.loop_nest); auto increment_loop_outer_inner = scope_utils::makeLoopNest(increment_loop_vector); + // Create the increment expression. auto inc_expr = SimplifyingIrBuilder::create( insertion_info.address_compute_record->addressTensor(), insertion_info.address_compute_record->dataTensor()); @@ -1158,6 +1241,8 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { original_loop->loopTransformInfo().baseIndexLoop()); } + // Utility to create the loop nest for gmem increment, + // see [Gmem address increment]. std::vector createIncrementLoop( std::vector address_compute_loop_vector) { std::vector loop_nest_to_clone( diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 1c7886cd6fc548..deebf5e67ff2ce 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -750,6 +750,21 @@ bool supportInlinePredicate(Expr* expr) { return false; } +bool useDirectSmemAddress(const TensorView* tv) { + // Not applicable for any indexing that's not + // lifted. + if (!tv->shouldLiftWriteAddress() || + tv->getMemoryType() != MemoryType::Shared) { + return false; + } + + auto expr = tv->definition(); + // Direct usage of smem address should be avoided at all cost, + // so only allowing this very specific case where this is the + // necessary step to take to get efficient indexing code. + return expr != nullptr && ir_utils::isCpAsyncOp(expr); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index ee21df9eb2fe8e..0aeda432913bc3 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -277,6 +277,10 @@ struct TORCH_CUDA_CU_API IterDomainDependencySorter { //! as an inline argument. bool supportInlinePredicate(Expr* expr); +//! Returns true if the consumer indexing of this tensor directly +//! uses shared mem address. 
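//!
//! In this series the condition only holds for a shared memory tensor whose
//! write address is lifted and whose definition is a cp.async op. A sketch
//! of the intended use at a call site (hypothetical caller):
//!
//!   if (useDirectSmemAddress(consumer_tv)) {
//!     // index through the raw 32-bit smem address; the double buffer
//!     // offset is folded into that address register inplace.
//!   } else {
//!     // generic path: base address (or full index) plus computed offset.
//!   }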
+bool useDirectSmemAddress(const TensorView* tv); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index d6f9a6d173c3fb..0e418bdc832f89 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -105,16 +105,8 @@ DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) { : "r"(addr)); } -// Load Matrix (per warp instruction) is to take data from SMEM to Local Memory. -// Automatically handles vectorized loads/stores in the MMA operation. -// Loads 8x8 matrix into a warp. Thread 0-7 provide the ptr that is the start -// of each row. All other threads can simply point to something valid -// (including 0). -// The x2 modifier on the instruction will actually load 2x8 rows to make a -// 16x8, -// then thread 0-15 will specify the start of each row. -// Finally is an x4 modifier producing a 32x8 using addrs from 0-31 in each -// warp. +// Below are the variants of ldmatrix wrapper that supports lifted +// memory indexing. DEVICE_INLINE void ldMatrix( Array<__half, 4, 4>& out, nvfuser_index_t index, @@ -127,9 +119,6 @@ DEVICE_INLINE void ldMatrix( : "r"(addr + (unsigned)index)); } -// Same as previous, 8x8 matrix is vectorized loaded, then scattered (to perform -// transpose) so threads will hold 2 values down a column (instead of the -// previous instruction that's across a row). DEVICE_INLINE void ldMatrixT( Array<__half, 4, 4>& out, nvfuser_index_t index, @@ -235,61 +224,29 @@ DEVICE_INLINE void cpAsync( "r"((int)predicate)); } -// Global to SMEM load that is asynchronous, -// not guaranteed to be completed until cpAsyncBarrier() is called. +// cp.async +// This is the variant that supports lifted indexing template DEVICE_INLINE void cpAsync( nvfuser_index_t smem_index, - DataPointer smem_base_ptr, + unsigned smem_addr, nvfuser_index_t gmem_index, DataPointer& gmem_ptr) { - unsigned smem_addr = util::toSmem(smem_base_ptr); constexpr int byte_size = sizeof(dtype) * len; static_assert( byte_size == 4 || byte_size == 8 || byte_size == 16, "cp_async : unsupported byte size"); - gmem_ptr += gmem_index; asm volatile( "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"( smem_addr + (unsigned)smem_index), - "l"(gmem_ptr), + "l"(gmem_ptr + gmem_index), "n"(byte_size)); - gmem_ptr -= gmem_index; } -// Global to SMEM load that is asynchronous, -// not guaranteed to be completed until cpAsyncBarrier() is called. -template -DEVICE_INLINE void cpAsync( - nvfuser_index_t smem_index, - DataPointer smem_base_ptr, - nvfuser_index_t gmem_index, - DataPointer& gmem_ptr, - bool predicate) { - unsigned smem_addr = util::toSmem(smem_base_ptr); - constexpr int byte_size = sizeof(dtype) * len; - - static_assert( - byte_size == 4 || byte_size == 8 || byte_size == 16, - "cp_async : unsupported byte size"); - - gmem_ptr += gmem_index; - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - "@p cp.async.ca.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem_addr + (unsigned)smem_index), - "l"(gmem_ptr), - "n"(byte_size), - "r"((int)predicate)); - gmem_ptr -= gmem_index; -} - -// Global to SMEM load that is asynchronous, -// not guaranteed to be completed until cpAsyncBarrier() is called. +// cp.async +// This is the variant that supports lifted indexing, with predicate inlined. 
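//
// A sketch of the call pattern codegen is expected to emit for the lifted
// variants (identifiers and constants below are illustrative, not actual
// codegen output):
//
//   unsigned smem_addr = util::toSmem(T_shared);  // hoisted once, see
//                                                 // [Lifting smem address
//                                                 //  decoding for cp.async]
//   DataPointer gmem_ptr = /* lifted base address */;
//   for (nvfuser_index_t i = 0; i < n; ++i) {
//     cpAsync<dtype, len>(smem_offset, smem_addr, /*gmem_index=*/0, gmem_ptr);
//     gmem_ptr += gmem_stride;  // [Gmem address increment]
//   }
//
// The overload below takes the same arguments plus an inlined predicate.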
template DEVICE_INLINE void cpAsync( nvfuser_index_t smem_index, @@ -303,17 +260,15 @@ DEVICE_INLINE void cpAsync( byte_size == 4 || byte_size == 8 || byte_size == 16, "cp_async : unsupported byte size"); - // gmem_ptr += gmem_index; asm volatile( "{\n" " .reg .pred p;\n" " setp.ne.b32 p, %3, 0;\n" "@p cp.async.ca.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem_addr + (unsigned)smem_index), - "l"(gmem_ptr), + "l"(gmem_ptr + gmem_index), "n"(byte_size), "r"((int)predicate)); - // gmem_ptr -= gmem_index; } // TODO: Might have a different category of sync if we want to build out this: @@ -379,15 +334,12 @@ DEVICE_INLINE void doubleBufferUpdate( } // Update double buffer offset value for smem double buffered tensors. +// See [Uniform Double Buffer Offset] template DEVICE_INLINE void doubleBufferSwitch( nvfuser_index_t& buffer_offset, const nvfuser_index_t& loop_index, nvfuser_index_t buffer_size) { - // static_assert( - // loop_offset < number_of_stage && loop_offset > -number_of_stage); - - // convert offset to [0, number_of_stage) constexpr nvfuser_index_t offset = loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset; From b62351468aaa37bbb5e391cccaedd28f41778d74 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 19 Sep 2022 21:34:34 -0700 Subject: [PATCH 15/17] minor fix --- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index deebf5e67ff2ce..ff90efa59e8339 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -148,6 +148,11 @@ bool isCpAsyncOp(const Expr* expr) { } bool isTensorScalarFillOp(const Expr* expr) { + // Check that this expression outputs to tensor + if (getTvOutput(expr) == nullptr) { + return false; + } + // Check that the input is a single scalar. if (expr->inputs().size() == 1 && expr->input(0)->isScalar()) { // All load store op with a single scalar input @@ -333,7 +338,9 @@ std::unordered_map getParallelDomains( } bool isCpAsyncInit(const Expr* expr) { - return isTensorScalarFillOp(expr) && + return + + isTensorScalarFillOp(expr) && // FIXME: // We'd need to add a flag to all the init // exprs so we could robustly detect initialization From add6fec8f51a532a080b65fe6139a8fbd2de7d38 Mon Sep 17 00:00:00 2001 From: shmsong Date: Mon, 19 Sep 2022 21:59:44 -0700 Subject: [PATCH 16/17] minor fix --- torch/csrc/jit/codegen/cuda/ir_interface_nodes.h | 4 ++++ torch/csrc/jit/codegen/cuda/runtime/memory.cu | 6 +++--- torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index c8bac72d1f3909..a01272689728f0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -469,6 +469,10 @@ class TORCH_CUDA_CU_API TensorView : public Val { // Returns the depth of circular buffering if applicable. unsigned int circularBufferDepth() const { + if (is_double_buffered_) { + // Double buffering is circular buffering with stage 2. 
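      // This lets callers treat double buffering as the two-stage case of
      // circular buffering, e.g. a hypothetical caller
      //   auto stages = tv->isDoubleBuffered() || tv->isCircularBuffered()
      //       ? tv->circularBufferDepth()
      //       : 1;
      // no longer needs to special-case the double buffered path.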
+ return 2; + } TORCH_INTERNAL_ASSERT( is_circular_buffered_, toString(), "not circular buffered"); return circular_buffer_stage_; diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index 0e418bdc832f89..38bacb3d6a1fc3 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -330,14 +330,14 @@ DEVICE_INLINE void doubleBufferUpdate( (loop_index % number_of_stage) == (number_of_stage - 1 - offset) ? buffer_size * (-number_of_stage + 1) : buffer_size; - data_buffer += increment; + data_buffer += (unsigned)increment; } // Update double buffer offset value for smem double buffered tensors. // See [Uniform Double Buffer Offset] template DEVICE_INLINE void doubleBufferSwitch( - nvfuser_index_t& buffer_offset, + int& buffer_offset, const nvfuser_index_t& loop_index, nvfuser_index_t buffer_size) { constexpr nvfuser_index_t offset = @@ -348,7 +348,7 @@ DEVICE_INLINE void doubleBufferSwitch( (loop_index % number_of_stage) == (number_of_stage - 1 - offset) ? buffer_size * (-number_of_stage + 1) : buffer_size; - buffer_offset += increment; + buffer_offset += (int)increment; } // Reset smem space to zero diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 05df3189b1487c..bd5d89d778dc3a 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -842,12 +842,18 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { params.double_buffer_options.double_buffer_smem_read = true; scheduleMatmul(tv2, tv0, tv1, params); + CompileOptions co; + co.index_mode = KernelIndexMode::INT32; + at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); FusionExecutor fe; NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( - 8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second})); + 8, + 0, + fe.compileFusion( + &fusion, {inputs.first, inputs.second}, LaunchParams(), co)); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout); From 78a80f744bba316fe2bf95c8af315702fc787067 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Thu, 29 Sep 2022 21:59:19 -0700 Subject: [PATCH 17/17] fix --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 2 +- torch/csrc/jit/codegen/cuda/lower_mem_index.cpp | 2 +- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 2 ++ torch/csrc/jit/codegen/cuda/lower_utils.h | 2 ++ torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp | 5 +---- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index ebdd595d1e3ba8..2d62a645319547 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -2334,7 +2334,7 @@ std::vector Index::getNonGlobalConsumerStridedIndices( // Lifted address case the double buffer offset is // computed inplace into the write address buffer. 
// See [Inplace double buffer update] - && !useDirectSmemAddress(consumer_tv)) { + && !lower_utils::useDirectSmemAddress(consumer_tv)) { auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops); TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp index fec7ade7150890..a6ad6dcb8f5e7b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp @@ -1146,7 +1146,7 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { // direct smem access for now. // FIXME: // Would need to extend to use this on Volta. - useDirectSmemAddress(data_tensor)) { + lower_utils::useDirectSmemAddress(data_tensor)) { // Insert double buffer index update if it is a double buffered write: // The insertion info loop nest starts with the serial loop, // in the double buffer update we need to insert into the original diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index d8816187adbb3e..d7c47b60c9d0f0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -798,6 +798,8 @@ bool useDirectSmemAddress(const TensorView* tv) { return expr != nullptr && ir_utils::isCpAsyncOp(expr); } +} // namespace lower_utils + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 6b531af9b022c9..f6ebdddbec04a5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -281,6 +281,8 @@ bool supportInlinePredicate(Expr* expr); //! uses shared mem address. bool useDirectSmemAddress(const TensorView* tv); +} // namespace lower_utils + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index f3895dc4e364e3..d29e0bbec29cce 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -842,9 +842,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { params.double_buffer_options.double_buffer_smem_read = true; scheduleMatmul(tv2, tv0, tv1, params); - CompileOptions co; - co.index_mode = KernelIndexMode::INT32; - at::manual_seed(0); auto inputs = fp16MatmulAtInput(M, N, K, layout); @@ -853,7 +850,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) { 8, 0, fe.compileFusion( - &fusion, {inputs.first, inputs.second}, LaunchParams(), co)); + &fusion, {inputs.first, inputs.second}, LaunchParams())); auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); auto tref = atMatmul( inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
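For reference, the inplace double buffer update used throughout this series
reduces to a small piece of wrap-around arithmetic. The snippet below is a
standalone host-side sketch (illustrative only, with made-up stage count and
buffer size, and specialized to a zero loop offset) that checks the
incrementally maintained offset against the closed form (i % stages) * size
used by the non-lifted indexing path.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t stages = 3;   // illustrative stage count
  const int64_t bytes = 4096; // illustrative per-stage size in bytes
  int64_t offset = 0;         // the inplace-updated address register ("R0")
  for (int64_t i = 0; i < 12; ++i) {
    assert(offset == (i % stages) * bytes);
    // Advance one stage per iteration; jump back (stages - 1) stages at the
    // wrap, mirroring doubleBufferUpdate / doubleBufferSwitch above.
    offset += (i % stages) == (stages - 1) ? bytes * (-stages + 1) : bytes;
  }
  return 0;
}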