From 77f831beedb3f2b90fb2ed36ae4554e88ce7738c Mon Sep 17 00:00:00 2001
From: shmsong
Date: Fri, 9 Sep 2022 11:01:51 -0700
Subject: [PATCH] minor clean up

---
 torch/csrc/jit/codegen/cuda/ir_nodes.cpp      |   4 +-
 .../jit/codegen/cuda/lower_double_buffer.cpp  |   8 +-
 torch/csrc/jit/codegen/cuda/lower_utils.cpp   |   2 +-
 .../jit/codegen/cuda/scheduler/matmul.cpp     |   3 +-
 .../codegen/cuda/test/test_gpu_tensorcore.cpp | 383 +++++-------------
 5 files changed, 116 insertions(+), 284 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index be2cd10372cdb..19ce379fdd832 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -1405,8 +1405,8 @@ void IterDomain::parallelize(ParallelType t) {
     //   TORCH_CHECK(
     //       t == ParallelType::Vectorize || t == ParallelType::TIDx ||
     //           t == ParallelType::Serial || t == ParallelType::Mma,
-    //       "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids",
-    //       t);
+    //       "Parallel type other than serial, tidx, vectorize not allowed for mma
+    //       swizzled ids", t);
   }
 
   parallel_type_ = t;
diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp
index 98345636c7d3d..86fa5d11291ac 100644
--- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp
@@ -94,7 +94,8 @@ void validateDoubleBufferedTensor(const TensorView* tv) {
   const auto c_mem_type = tv->getMemoryType();
   // TORCH_INTERNAL_ASSERT(
   //     (p_mem_type == MemoryType::Global &&
-  //      (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) ||
+  //      (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local))
+  //      ||
   //         (c_mem_type == MemoryType::Local),
   //     "Invalid tensor to double-buffer: ",
   //     tv->toString(),
@@ -146,9 +147,8 @@ class DoubleBufferFusionInspector : private IterVisitor {
 bool requireEpilogue(const std::vector<Expr*>& exprs) {
   return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) {
     return expr->input(0)->as<TensorView>()->getMemoryType() ==
-           MemoryType::Shared ||
-        expr->input(0)->as<TensorView>()->getMemoryType() ==
-        MemoryType::Local;
+               MemoryType::Shared ||
+        expr->input(0)->as<TensorView>()->getMemoryType() == MemoryType::Local;
   });
 }
 
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index 910b0910c1450..337b4547b1e5a 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -501,7 +501,7 @@ BasicAllocInfo getAllocInformation(
       outer_alloc_found = true;
     }
 
-    if(tv->getMemoryType()==MemoryType::Shared && !fl_id->isThread()){
+    if (tv->getMemoryType() == MemoryType::Shared && !fl_id->isThread()) {
       outer_alloc_found = true;
     }
 
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
index 25aebb3c6d243..8e7ec987cbca0 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
@@ -552,7 +552,7 @@ void scheduleMatmul(
           .propagateToBoundary());
 
   c_smem->computeAt(c, 3);
-  c->reorder({{-1,-2}, {-2,-1}});
+  c->reorder({{-1, -2}, {-2, -1}});
 
   // 16 x 128, with half of the warps:
   // Output vectorize by 4:
@@ -566,7 +566,6 @@ void scheduleMatmul(
   c_smem->axis(-1)->parallelize(ParallelType::Vectorize);
 
   c_smem->doubleBuffer();
-
   if (params.index_lift_options.lift_gmem_read_address) {
     a->liftReadAddress();
     b->liftReadAddress();
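The matmul.cpp hunks above touch the epilogue schedule: the output tile is staged through shared memory (c_smem), vectorized by 4, and double-buffered so one tile drains while the next is produced. For readers unfamiliar with the pattern, here is a minimal hand-written sketch of the two-stage (ping-pong) shared-memory pipeline that c_smem->doubleBuffer() automates. This is not nvFuser output; loadTile and consumeTile are hypothetical stand-ins for the real tile copy and tensor-core math.

#include <cuda_fp16.h>

constexpr int kTile = 64;

// Hypothetical stand-ins for the real tile copy and MMA work.
__device__ void loadTile(__half* dst, const __half* src) {
  for (int i = threadIdx.x; i < kTile * kTile; i += blockDim.x)
    dst[i] = src[i];
}

__device__ float consumeTile(const __half* buf) {
  float acc = 0.f;
  for (int i = threadIdx.x; i < kTile * kTile; i += blockDim.x)
    acc += __half2float(buf[i]);
  return acc;
}

__global__ void pingPong(const __half* gmem, float* out, int num_tiles) {
  __shared__ __half smem[2][kTile * kTile]; // two rotating stages

  loadTile(smem[0], gmem); // prologue: stage 0 holds tile 0
  __syncthreads();

  float acc = 0.f;
  for (int k = 0; k < num_tiles; ++k) {
    int cur = k & 1;       // stage holding tile k
    int nxt = (k + 1) & 1; // stage being refilled for tile k + 1
    if (k + 1 < num_tiles)
      loadTile(smem[nxt], gmem + (k + 1) * kTile * kTile);
    acc += consumeTile(smem[cur]);
    __syncthreads(); // nxt fully written, and cur fully read, before reuse
  }
  if (threadIdx.x == 0) out[blockIdx.x] = acc;
}

With plain synchronous loads the overlap is limited; the kernels this patch exercises instead issue Ampere cp.async copies so the next tile's loads genuinely run in the background (see the pipeline sketch at the end of this patch).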
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
index 6b8003faab190..cc435f627d084 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
@@ -2801,7 +2801,6 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) {
 
     params.double_buffer_options.smem_double_buffer_stage = 3;
     scheduleMatmul(tv2, tv0, tv1, params);
-
     at::manual_seed(0);
    auto inputs = fp16MatmulAtInput(M, N, K, layout);
 
@@ -2823,6 +2822,114 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) {
   }
 }
 
+TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity1_CUDA) {
+  // Keep multiples of 8 to keep vectorizable.
+  int M = 2048, N = 3456, K = 1024;
+  for (auto layout : kAllSupportedLayout) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+    auto tv0 = makeContigTensor(2, DataType::Half);
+    auto tv1 = makeContigTensor(2, DataType::Half);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    auto tv2 = matmul(tv0, tv1, layout);
+
+    fusion.addOutput(tv2);
+
+    MatMulTileOptions gemm_tile;
+    gemm_tile.cta_tile = GemmTile(128, 128, 64);
+    gemm_tile.warp_tile = GemmTile(64, 64, 64);
+    gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+    auto mma_builder =
+        MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
+            .layout(layout);
+
+    MatmulParam params(mma_builder);
+    params.tile_sizes = gemm_tile;
+    params.async_gmem_load_operands = true;
+    params.double_buffer_options.double_buffer_smem_write = true;
+    params.double_buffer_options.double_buffer_smem_read = true;
+    params.double_buffer_options.smem_double_buffer_stage = 3;
+    scheduleMatmul(tv2, tv0, tv1, params);
+
+    at::manual_seed(0);
+    auto inputs = fp16MatmulAtInput(M, N, K, layout);
+
+    CompileOptions co;
+    co.index_mode = KernelIndexMode::INT32;
+
+    FusionExecutor fe;
+    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+        8,
+        0,
+        fe.compileFusion(
+            &fusion, {inputs.first, inputs.second}, LaunchParams(), co));
+
+    // return;
+    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+    auto tref = atMatmul(
+        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
+    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001));
+  }
+}
+
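The two quick-sanity variants (this one and QuickSanity2 just below) differ in CTA tile and stage count but land on essentially the same operand staging footprint in shared memory. A back-of-envelope estimate (an assumption, not nvFuser's actual allocator: it counts half-precision operand tiles only, ignoring epilogue staging and alignment padding) is stages * (cta_M + cta_N) * cta_K * sizeof(half):

#include <cstdio>

// Rough operand staging size in bytes for a multi-stage matmul tile:
// 'stages' copies of an M x K and an N x K half-precision operand tile.
constexpr long smemOperandBytes(long m, long n, long k, long stages) {
  return stages * (m * k + n * k) * 2; // sizeof(__half) == 2
}

int main() {
  std::printf("%ld\n", smemOperandBytes(128, 128, 64, 3)); // Sanity1: 98304
  std::printf("%ld\n", smemOperandBytes(256, 128, 64, 2)); // Sanity2: 98304
}

Both come out to 98304 bytes, in the same ballpark as the 98384-byte floor that the neighboring TileCheck4warp test requires.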
+TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadQuickSanity2_CUDA) {
+  // Keep multiples of 8 to keep vectorizable.
+  int M = 2048, N = 3456, K = 1024;
+  for (auto layout : kAllSupportedLayout) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+    auto tv0 = makeContigTensor(2, DataType::Half);
+    auto tv1 = makeContigTensor(2, DataType::Half);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    auto tv2 = matmul(tv0, tv1, layout);
+
+    fusion.addOutput(tv2);
+
+    MatMulTileOptions gemm_tile;
+    gemm_tile.cta_tile = GemmTile(256, 128, 64);
+    gemm_tile.warp_tile = GemmTile(64, 64, 64);
+    gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+    auto mma_builder =
+        MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
+            .layout(layout);
+
+    MatmulParam params(mma_builder);
+    params.tile_sizes = gemm_tile;
+    params.async_gmem_load_operands = true;
+    params.double_buffer_options.double_buffer_smem_write = true;
+    params.double_buffer_options.double_buffer_smem_read = true;
+    params.double_buffer_options.smem_double_buffer_stage = 2;
+    scheduleMatmul(tv2, tv0, tv1, params);
+
+    at::manual_seed(0);
+    auto inputs = fp16MatmulAtInput(M, N, K, layout);
+
+    CompileOptions co;
+    co.index_mode = KernelIndexMode::INT32;
+
+    FusionExecutor fe;
+    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+        8,
+        0,
+        fe.compileFusion(
+            &fusion, {inputs.first, inputs.second}, LaunchParams(), co));
+
+    // return;
+    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+    auto tref = atMatmul(
+        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
+    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001));
+  }
+}
+
 // Tile layout check for symmetric 4-warp recipes
 TEST_F(NVFuserTest, FusionAmpereMatmulTileCheck4warp_CUDA) {
   REQUIRE_DEVICE_SMEM_SIZE(98384, 0);
@@ -3085,280 +3192,6 @@ TEST_F(NVFuserTest, FusionSimpleSkewDoubleBuffer_CUDA) {
   testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
 }
 
-TEST_F(NVFuserTest, FusionTestCompileRtc_CUDA) {
-  FusionExecutor fe;
-  std::string kernel = R"(
-__global__ void kernel1(Tensor<__half, 2> T0, Tensor<__half, 2> T1, Tensor T4) {
-  alignas(16) extern __shared__ char array[];
-  unsigned offset = 0;
-  NVFUSER_DEFINE_MAGIC_ZERO
-  offset = alignBufferSize(offset, 16);
-  __half* T7 = reinterpret_cast<__half*>(array + offset);
-  offset += (((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 3) * sizeof(__half));
-  offset = alignBufferSize(offset, 16);
-  __half* T6 = reinterpret_cast<__half*>(array + offset);
-  offset += (((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 3) * sizeof(__half));
-  Array T5;
-  #pragma unroll
-  for(nvfuser_index_t i165 = 0; i165 < (ceilDiv(64, 16)); ++i165) {
-    #pragma unroll
-    for(nvfuser_index_t i166 = 0; i166 < (ceilDiv(64, 16)); ++i166) {
-      Ampere::initM16N16K16TN<16>(reinterpret_cast*>(&T5[((i165 * (ceilDiv(16, 8))) * ((ceilDiv(64, 16)) * ((ceilDiv(16, 8)) * 2))) + (i166 * ((ceilDiv(16, 8)) * 2))]));
-    }
-  }
-  NVFUSER_UPDATE_MAGIC_ZERO
-  DataPointer T12[1];
-  unsigned T12s[1];
-
-  //Base Address:::
-  T12[0] = (DataPointer) &T7[(((((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + ((Xor({((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) *
(ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))),((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8)))} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).x) + (((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + (((Xor({((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))),((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8)))} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).y) * 8) + (((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))]; - T12s[0] = Turing::util::toSmem(T12[0]); - - nvfuser_index_t T13[1]; - //Predicate Compute Index::: - T13[0] = ((((nvfuser_index_t)blockIdx.y) * 128) + ((((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))); - DataPointer T14[(ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024))]; - #pragma unroll - for(nvfuser_index_t i154 = 0; i154 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i154) { - //Base Address::: - T14[i154] = (DataPointer) &T1[(((((nvfuser_index_t)blockIdx.y) * 128) + ((((((((i154 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i154 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i154 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))) * T0.size[1]) + ((((((((i154 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((((i154 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) % 
(((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((((i154 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))]; - } - NVFUSER_UPDATE_MAGIC_ZERO - DataPointer T15[1]; - unsigned T15s[1]; - - //Base Address::: - T15[0] = (DataPointer) &T6[(((((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + ((Xor({((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))),((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8)))} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).x) + (((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + (((Xor({((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))),((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8)))} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).y) * 8) + (((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))]; - T15s[1] = Turing::util::toSmem(T15[0]); - - nvfuser_index_t T16[1]; - //Predicate Compute Index::: - T16[0] = ((((nvfuser_index_t)blockIdx.x) * 128) + ((((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))); - DataPointer T17[(ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024))]; - #pragma unroll - for(nvfuser_index_t i143 = 0; i143 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i143) { - //Base Address::: - T17[i143] = (DataPointer) &T0[(((((nvfuser_index_t)blockIdx.x) * 128) + ((((((((i143 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i143 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + 
((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i143 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))) * T0.size[1]) + ((((((((i143 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((((i143 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((((i143 + nvfuser_zero) * 1024) + (((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))]; - } - NVFUSER_UPDATE_MAGIC_ZERO - DataPointer T11[(ceilDiv(64, 16))]; - #pragma unroll - for(nvfuser_index_t i164 = 0; i164 < (ceilDiv(64, 16)); ++i164) { - //Base Address::: - T11[i164] = (DataPointer) &T6[((((((((nvfuser_index_t)threadIdx.z) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) / 8) * 8) + ((Xor({((((((nvfuser_index_t)threadIdx.z) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) % 8) / 1),((((i164 * 16) + (((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8)) % 64) / 8)} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).x) + ((((((nvfuser_index_t)threadIdx.z) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) % 8) % 1))) * 64) + (((((i164 * 16) + (((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8)) / 64) * 64) + (((Xor({((((((nvfuser_index_t)threadIdx.z) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) % 8) / 1),((((i164 * 16) + (((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8)) % 64) / 8)} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).y) * 8) + ((((i164 * 16) + (((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8)) % 64) % 8)))]; - } - NVFUSER_UPDATE_MAGIC_ZERO - DataPointer T10[(ceilDiv(64, 16))]; - #pragma unroll - for(nvfuser_index_t i164 = 0; i164 < (ceilDiv(64, 16)); ++i164) { - //Base Address::: - T10[i164] = (DataPointer) &T7[((((((((nvfuser_index_t)threadIdx.y) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) / 8) * 8) + ((Xor({((((((nvfuser_index_t)threadIdx.y) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) % 8) / 1),((((i164 * 16) + (((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8)) % 64) / 8)} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).x) + ((((((nvfuser_index_t)threadIdx.y) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) % 8) % 1))) * 64) + (((((i164 * 16) + (((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8)) / 64) * 64) + (((Xor({((((((nvfuser_index_t)threadIdx.y) * 64) + ((((((nvfuser_index_t)threadIdx.x) / 8) / (ceilDiv(16, 8))) * 8) + (((nvfuser_index_t)threadIdx.x) % 8))) % 8) / 1),((((i164 * 16) + 
(((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8)) % 64) / 8)} , {(ceilDiv(8, 1)),(ceilDiv(64, 8))}).y) * 8) + ((((i164 * 16) + (((((nvfuser_index_t)threadIdx.x) / 8) % (ceilDiv(16, 8))) * 8)) % 64) % 8)))]; - } - NVFUSER_UPDATE_MAGIC_ZERO - int i211 = 0; - int i212 = 0; - #pragma unroll - for(nvfuser_index_t i154 = 0; i154 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i154) { - if ((!((((((nvfuser_index_t)blockIdx.y) * 128) + ((((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))) < T1.size[0]) && (((((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))) < ((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64))))))) { - arraySet<__half, 8>(&reinterpret_cast<__half*>(T12[0])[(((((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))], (__half)0); - } - } - NVFUSER_UPDATE_MAGIC_ZERO - #pragma unroll - for(nvfuser_index_t i154 = 0; i154 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i154) { - Ampere::cpAsync<__half,8>(((((((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64))*2 + (((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))))*2,T12[0],0,T14[i154],((((((nvfuser_index_t)blockIdx.y) * 128) + ((((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) 
+ ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))) < T1.size[0]) && (((((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((((i154 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))) < ((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64)))))); - } - NVFUSER_UPDATE_MAGIC_ZERO - #pragma unroll - for(nvfuser_index_t i143 = 0; i143 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i143) { - if ((!((((((nvfuser_index_t)blockIdx.x) * 128) + ((((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))) < T0.size[0]) && (((((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))) < ((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64))))))) { - arraySet<__half, 8>(&reinterpret_cast<__half*>(T15[0])[(((((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * 
(ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))], (__half)0); - } - } - NVFUSER_UPDATE_MAGIC_ZERO - #pragma unroll - for(nvfuser_index_t i143 = 0; i143 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i143) { - Ampere::cpAsync<__half,8>(((((((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64))*2 + (((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))))*2,T15[0],0,T17[i143],((((((nvfuser_index_t)blockIdx.x) * 128) + ((((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))) < T0.size[0]) && (((((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((((i143 + nvfuser_zero) * 1024) + ((((((((nvfuser_index_t)threadIdx.z) * 2) + ((nvfuser_index_t)threadIdx.y)) * 32) + ((nvfuser_index_t)threadIdx.x)) * 8) + 7)) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))) < ((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64)))))); - } - NVFUSER_UPDATE_MAGIC_ZERO - doubleBufferUpdate<3,0>(T12[0],0,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,0>(T15[0],0,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,0>(T12s[0],0,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,0>(T15s[0],0,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - - Ampere::cpAsyncCommit(); - 
#pragma unroll - for(nvfuser_index_t i160 = 1; i160 < 2; ++i160) { - #pragma unroll - for(nvfuser_index_t i154 = 0; i154 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i154) { - if ((!(T13[0] < (T1.size[0] + (-((((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))))))) { - arraySet<__half, 8>(&reinterpret_cast<__half*>(T12[0])[(((((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))], (__half)0); - } - } - #pragma unroll - for(nvfuser_index_t i154 = 0; i154 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i154) { - Ampere::cpAsync<__half,8>(((((((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64))*2 + (((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))))*2,T12[0],((((i160 + nvfuser_zero) * 64) + (-((-((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64)))) + 64))))*2,T14[i154],(T13[0] < (T1.size[0] + (-((((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))))))); - } - #pragma unroll - for(nvfuser_index_t i143 = 0; i143 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i143) { - if ((!(T16[0] < (T0.size[0] + (-((((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))))))) { - arraySet<__half, 8>(&reinterpret_cast<__half*>(T15[0])[(((((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + 
(((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))], (__half)0); - } - } - #pragma unroll - for(nvfuser_index_t i143 = 0; i143 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i143) { - Ampere::cpAsync<__half,8>(((((((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64))*2 + (((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))))*2,T15[0],((((i160 + nvfuser_zero) * 64) + (-((-((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64)))) + 64))))*2,T17[i143],(T16[0] < (T0.size[0] + (-((((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))))))); - } - doubleBufferUpdate<3,0>(T12[0],i160,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,0>(T15[0],i160,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,0>(T12s[0],i160,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,0>(T15s[0],i160,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - - Ampere::cpAsyncCommit(); - } - NVFUSER_UPDATE_MAGIC_ZERO - #pragma unroll - for(nvfuser_index_t i161 = 2; i161 < 3; ++i161) { - #pragma unroll - for(nvfuser_index_t i154 = 0; i154 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i154) { - if ((!(T13[0] < (T1.size[0] + (-((((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))))))) { - arraySet<__half, 8>(&reinterpret_cast<__half*>(T12[0])[(((((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))], (__half)0); - } - } - #pragma 
unroll - for(nvfuser_index_t i143 = 0; i143 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i143) { - if ((!(T16[0] < (T0.size[0] + (-((((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))))))) { - arraySet<__half, 8>(&reinterpret_cast<__half*>(T15[0])[(((((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8)))], (__half)0); - } - } - Ampere::cpAsyncCommit(); - } - NVFUSER_UPDATE_MAGIC_ZERO - Ampere::cpAsyncPartialBarrier<1>(); - __barrier_sync(0); - #pragma unroll 1 - for(nvfuser_index_t i162 = 0; i162 < (ceilDiv(T0.size[1], 64)); ++i162) { - #pragma unroll - for(nvfuser_index_t i154 = 0; i154 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i154) { - Ampere::cpAsync<__half,8>(((((((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i154 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64))*2 + (((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i154 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))))*2,T12s[0],((((i162 + 2) * 64) + (-((-((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64)))) + 64))))*2,T14[i154],((T13[0] < (T1.size[0] + (-((((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i154 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))))) && ((i162 + 2) < (ceilDiv(T0.size[1], 64))))); - } - NVFUSER_UPDATE_MAGIC_ZERO - #pragma unroll - for(nvfuser_index_t i143 = 0; i143 < (ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)); ++i143) { - Ampere::cpAsync<__half,8>(((((((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((i143 * 1024) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1))) * 64))*2 + (((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) / (ceilDiv(64, 8))) * 64) + ((((((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 8) % (ceilDiv(64, 8))) * 8) + (((i143 * 1024) % (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 8))))*2,T15s[0],((((i162 + 2) * 64) + (-((-((T0.size[1] + 64) + (-((ceilDiv(T0.size[1], 64)) * 64)))) + 64))))*2,T17[i143],((T16[0] 
< (T0.size[0] + (-((((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) / (ceilDiv(8, 1))) * 8) + (((((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) / 1) % (ceilDiv(8, 1))) + (((((i143 + nvfuser_zero) * 1024) + 7) / (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)) % 1)))))) && ((i162 + 2) < (ceilDiv(T0.size[1], 64))))); - } - NVFUSER_UPDATE_MAGIC_ZERO - #pragma unroll - for(nvfuser_index_t i164 = 0; i164 < (ceilDiv(64, 16)); ++i164) { - Array<__half, ((ceilDiv(64, 16)) * 8), 8> T8; - #pragma unroll - for(nvfuser_index_t i145 = 0; i145 < (ceilDiv(64, 16)); ++i145) { - Turing::ldMatrix (*reinterpret_cast*>(&T8[(i145 * 8)]),(((i145 * 16) * 64))*2 + i212,T11[i164]); - } - __half T2[((ceilDiv(64, 16)) * 8)]; - #pragma unroll - for(nvfuser_index_t i147 = 0; i147 < (ceilDiv(64, 16)); ++i147) { - #pragma unroll - for(nvfuser_index_t i149 = 0; i149 < 8; ++i149) { - T2[(i147 * 8) + i149] = 0.00000000000000000e+00; - } - } - #pragma unroll - for(nvfuser_index_t i147 = 0; i147 < (ceilDiv(64, 16)); ++i147) { - #pragma unroll - for(nvfuser_index_t i149 = 0; i149 < 8; ++i149) { - T2[(i147 * 8) + i149] - = T8[(i147 * 8) + i149]; - } - } - // Alias Allocation - register - auto& T9 = T8; - #pragma unroll - for(nvfuser_index_t i156 = 0; i156 < (ceilDiv(64, 16)); ++i156) { - Turing::ldMatrix (*reinterpret_cast*>(&T9[(i156 * 8)]),(((i156 * 16) * 64))*2 + i211,T10[i164]); - } - __half T3[((ceilDiv(64, 16)) * 8)]; - #pragma unroll - for(nvfuser_index_t i158 = 0; i158 < (ceilDiv(64, 16)); ++i158) { - #pragma unroll - for(nvfuser_index_t i159 = 0; i159 < 8; ++i159) { - T3[(i158 * 8) + i159] = 0.00000000000000000e+00; - } - } - #pragma unroll - for(nvfuser_index_t i158 = 0; i158 < (ceilDiv(64, 16)); ++i158) { - #pragma unroll - for(nvfuser_index_t i159 = 0; i159 < 8; ++i159) { - T3[(i158 * 8) + i159] - = T9[(i158 * 8) + i159]; - } - } - #pragma unroll - for(nvfuser_index_t i165 = 0; i165 < (ceilDiv(64, 16)); ++i165) { - #pragma unroll - for(nvfuser_index_t i166 = 0; i166 < (ceilDiv(64, 16)); ++i166) { - Ampere::M16N16K16TN<16>( - reinterpret_cast*>(&T5[((i165 * (ceilDiv(16, 8))) * ((ceilDiv(64, 16)) * ((ceilDiv(16, 8)) * 2))) + (i166 * ((ceilDiv(16, 8)) * 2))]), - &(reinterpret_cast*>(&T2)[i165]), - &(reinterpret_cast*>(&T3)[i166])); - } - } - } - NVFUSER_UPDATE_MAGIC_ZERO - Ampere::cpAsyncPartialBarrier<1>(); - __barrier_sync(0); - doubleBufferUpdate<3,2>(T12[0],i162,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,2>(T15[0],i162,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,2>(T12s[0],i162,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - doubleBufferUpdate<3,2>(T15s[0],i162,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - - doubleBufferSwitch<3,0>(i211,i162,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 
2)); - doubleBufferSwitch<3,0>(i212,i162,((((((ceilDiv(((((ceilDiv(128, 8)) * (ceilDiv(8, 1))) * 1) * (((ceilDiv(64, 64)) * (ceilDiv(64, 8))) * 8)), 1024)) * (ceilDiv((ceilDiv((ceilDiv(1024, 8)), 32)), 2))) * 2) * 32) * 8) * 2)); - Ampere::cpAsyncCommit(); - } - #pragma unroll - for(nvfuser_index_t i171 = 0; i171 < (ceilDiv(64, 16)); ++i171) { - #pragma unroll - for(nvfuser_index_t i172 = 0; i172 < (ceilDiv(64, 16)); ++i172) { - #pragma unroll - for(nvfuser_index_t i173 = 0; i173 < (ceilDiv(16, 8)); ++i173) { - #pragma unroll - for(nvfuser_index_t i174 = 0; i174 < (ceilDiv(16, 8)); ++i174) { - int i1506; - i1506 = (((nvfuser_index_t)blockIdx.x) * 128) + ((((nvfuser_index_t)threadIdx.z) * 64) + ((i171 * 16) + (((i174 + nvfuser_zero) * 8) + (((nvfuser_index_t)threadIdx.x) / (ceilDiv(8, 2)))))); - if (((i1506 < T0.size[0]) && (((((nvfuser_index_t)blockIdx.y) * 128) + ((((nvfuser_index_t)threadIdx.y) * 64) + ((i172 * 16) + (((i173 + nvfuser_zero) * 8) + (((((nvfuser_index_t)threadIdx.x) % (ceilDiv(8, 2))) * 2) + 1))))) < T1.size[0]))) { - loadLocalToGlobal( &T4[(i1506 * T1.size[0]) + ((((nvfuser_index_t)blockIdx.y) * 128) + ((((nvfuser_index_t)threadIdx.y) * 64) + ((i172 * 16) + ((i173 * 8) + ((((nvfuser_index_t)threadIdx.x) % (ceilDiv(8, 2))) * 2)))))], &T5[(((i171 * (ceilDiv(16, 8))) + i174) * ((ceilDiv(64, 16)) * ((ceilDiv(16, 8)) * 2))) + ((i172 * ((ceilDiv(16, 8)) * 2)) + (i173 * 2))]); - } - } - } - } - } - NVFUSER_UPDATE_MAGIC_ZERO -} - )"; - fe.compileRtc(kernel, "CudaCodeGen::kernel1"); - LaunchParams lp( - 256, // gdimx - 1, // gdimy - 1, // gdimz - 1, // bdimx - 1, // bdimy - 1 // bdimz - ); - - return; - lp.setSmem(0); - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const std::vector tensor_dims = {8}; - auto in0 = at::randn(tensor_dims, options); - auto out0 = at::empty_like(in0); - fe.runRtc(lp, {in0, out0}); - - auto out_ref = in0 * 2; - TORCH_CHECK(out_ref.allclose(out0)); -} - #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace jit
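The kernel deleted above is one frozen snapshot of what the scheduler generates for this matmul: operands stream through a three-stage circular shared-memory buffer fed by cp.async, with Ampere::cpAsyncCommit closing one commit group per stage, Ampere::cpAsyncPartialBarrier<1> apparently corresponding to cp.async.wait_group 1 (at most one group left in flight), and doubleBufferUpdate/doubleBufferSwitch rotating the buffer offsets. Stripped of the swizzled indexing, the pipeline skeleton looks roughly like the following hand-written sketch. It is not nvFuser output; it uses CUDA's __pipeline_* primitives (sm_80+), kStages/kTileElems and the plain reduction are stand-ins for the real ldmatrix/mma work, and gmem is assumed 16-byte aligned.

#include <cuda_fp16.h>
#include <cuda_pipeline.h>

constexpr int kStages = 3;          // matches smem_double_buffer_stage = 3
constexpr int kTileElems = 64 * 64; // illustrative tile size

__global__ void multiStagePrefetch(
    const __half* gmem, float* out, int tiles) {
  // Circular buffer of kStages shared-memory tiles. 16B alignment is
  // required by the 16-byte cp.async transfers below.
  __shared__ __align__(16) __half smem[kStages][kTileElems];

  // Prologue: put the first kStages - 1 tile copies in flight, one
  // commit group per stage (a group may be empty near the tail).
  for (int s = 0; s < kStages - 1; ++s) {
    if (s < tiles) {
      for (int i = threadIdx.x * 8; i < kTileElems; i += blockDim.x * 8)
        __pipeline_memcpy_async(
            &smem[s][i], &gmem[s * kTileElems + i], 8 * sizeof(__half));
    }
    __pipeline_commit();
  }

  float acc = 0.f;
  for (int k = 0; k < tiles; ++k) {
    // Allow at most kStages - 2 groups in flight, i.e. the copy into
    // stage k % kStages has landed; the barrier makes it visible to
    // every thread in the block.
    __pipeline_wait_prior(kStages - 2);
    __syncthreads();

    // Prefetch the tile kStages - 1 iterations ahead into the stage
    // consumed last iteration (safe after the barriers above).
    int nxt = (k + kStages - 1) % kStages;
    if (k + kStages - 1 < tiles) {
      for (int i = threadIdx.x * 8; i < kTileElems; i += blockDim.x * 8)
        __pipeline_memcpy_async(
            &smem[nxt][i],
            &gmem[(k + kStages - 1) * kTileElems + i],
            8 * sizeof(__half));
    }
    __pipeline_commit();

    // Stand-in for the ldmatrix + mma work on the landed stage.
    const __half* tile = smem[k % kStages];
    for (int i = threadIdx.x; i < kTileElems; i += blockDim.x)
      acc += __half2float(tile[i]);
    __syncthreads(); // all reads done before this stage is refilled
  }
  if (threadIdx.x == 0) out[blockIdx.x] = acc;
}

The payoff is that the copy for tile k + kStages - 1 is in flight while tile k is consumed, hiding global-memory latency behind tensor-core work; that is what params.async_gmem_load_operands plus smem_double_buffer_stage request in the tests above.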
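The deletion also shows why hard-coded kernel strings rot: the text must be regenerated by hand whenever codegen changes, and this copy returned before its validation block ever ran (note the unconditional return before lp.setSmem(0)). If RTC coverage is reintroduced, a kernel small enough to maintain by hand keeps the round trip honest. Below is a minimal sketch along those lines; it reuses only the entry points the deleted test already exercised (FusionExecutor::compileRtc, runRtc, LaunchParams), the test name is hypothetical, and the assumption that nvFuser's runtime Tensor<float, 1> supports operator[] should be checked against the runtime headers.

// Hypothetical minimal RTC round-trip test; validates against in0 * 2
// the way the deleted test intended to.
TEST_F(NVFuserTest, FusionTestCompileRtcDouble_CUDA) {
  FusionExecutor fe;
  std::string kernel = R"(
__global__ void kernel1(Tensor<float, 1> T0, Tensor<float, 1> T1) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < T0.size[0]) {
    T1[i] = T0[i] * 2.0f; // assumes Tensor::operator[] from runtime headers
  }
}
)";
  fe.compileRtc(kernel, "CudaCodeGen::kernel1");
  LaunchParams lp(
      1, // gdimx
      1, // gdimy
      1, // gdimz
      8, // bdimx
      1, // bdimy
      1 // bdimz
  );
  lp.setSmem(0);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto in0 = at::randn({8}, options);
  auto out0 = at::empty_like(in0);
  fe.runRtc(lp, {in0, out0});

  TORCH_CHECK((in0 * 2).allclose(out0));
}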