From 736d0db2217bae3585857045d76e23d77894db2b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 16 Oct 2022 14:54:23 -0700 Subject: [PATCH] [XLA] Eliminated constant folding for operations that have large number of elements in their operands. PiperOrigin-RevId: 481504087 --- .../xla/service/hlo_constant_folding.cc | 5 +++-- .../xla/service/hlo_constant_folding_test.cc | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index bdf6af017ee..fca6e657948 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" +#include #include #include #include @@ -175,8 +176,8 @@ StatusOr HloConstantFolding::Run( ShapeUtil::ElementsIn(instruction->shape()); static const int64_t kMaximumConstantSizeElements = 45 * 1000 * 1000; - if (elements_in_constant > elements_in_removed_operands && - elements_in_constant > kMaximumConstantSizeElements) { + if (std::max(elements_in_constant, elements_in_removed_operands) > + kMaximumConstantSizeElements) { continue; } } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 6af256a0159..ecc86616289 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -370,5 +370,27 @@ TEST_F(HloConstantFoldingTest, BigReduceWindow) { EXPECT_TRUE(result); } +TEST_F(HloConstantFoldingTest, TimingConsumingTest) { + constexpr absl::string_view mod_str = R"( + HloModule jit_f, entry_computation_layout={()->f32[]} + region_0.4 { + Arg_0.5 = f32[] parameter(0) + Arg_1.6 = f32[] parameter(1) + ROOT add.7 = f32[] add(Arg_0.5, Arg_1.6) + } + + ENTRY main.9 { + constant.1 = f32[] constant(1) + broadcast.2 = f32[32,999,40,512]{3,2,1,0} broadcast(constant.1), dimensions={} + constant.3 = f32[] constant(0) + ROOT reduce.8 = f32[] reduce(broadcast.2, constant.3), dimensions={0,1,2,3}, to_apply=region_0.4 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(mod_str)); + HloConstantFolding const_fold; + TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&const_fold, module.get())); + EXPECT_FALSE(result); +} + } // namespace } // namespace xla