Skip to content

Commit

Permalink
[XLA] Eliminated constant folding for operations that have large numb…
Browse files Browse the repository at this point in the history
…er of elements in their operands.

PiperOrigin-RevId: 481504087
  • Loading branch information
tensorflower-gardener committed Oct 16, 2022
1 parent 80f0668 commit 736d0db
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tensorflow/compiler/xla/service/hlo_constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -175,8 +176,8 @@ StatusOr<bool> HloConstantFolding::Run(
ShapeUtil::ElementsIn(instruction->shape());

static const int64_t kMaximumConstantSizeElements = 45 * 1000 * 1000;
if (elements_in_constant > elements_in_removed_operands &&
elements_in_constant > kMaximumConstantSizeElements) {
if (std::max(elements_in_constant, elements_in_removed_operands) >
kMaximumConstantSizeElements) {
continue;
}
}
Expand Down
22 changes: 22 additions & 0 deletions tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -370,5 +370,27 @@ TEST_F(HloConstantFoldingTest, BigReduceWindow) {
EXPECT_TRUE(result);
}

TEST_F(HloConstantFoldingTest, TimingConsumingTest) {
constexpr absl::string_view mod_str = R"(
HloModule jit_f, entry_computation_layout={()->f32[]}
region_0.4 {
Arg_0.5 = f32[] parameter(0)
Arg_1.6 = f32[] parameter(1)
ROOT add.7 = f32[] add(Arg_0.5, Arg_1.6)
}
ENTRY main.9 {
constant.1 = f32[] constant(1)
broadcast.2 = f32[32,999,40,512]{3,2,1,0} broadcast(constant.1), dimensions={}
constant.3 = f32[] constant(0)
ROOT reduce.8 = f32[] reduce(broadcast.2, constant.3), dimensions={0,1,2,3}, to_apply=region_0.4
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(mod_str));
HloConstantFolding const_fold;
TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&const_fold, module.get()));
EXPECT_FALSE(result);
}

} // namespace
} // namespace xla

0 comments on commit 736d0db

Please sign in to comment.