From 612cba809452c1dd9887640771d5e1ed382e7b21 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 26 Aug 2024 12:28:07 +0400 Subject: [PATCH] [TF FE][Tokenizers] Optimize TF FE extensions (#232) Avoid extra StringTensorPack and StringTensorUnpack operations in a graph Signed-off-by: Kazantsev, Roman --- src/ov_extension.cpp | 3 +- src/tensorflow_translators.cpp | 69 ++++++++++++++++++++++++++++++++-- src/tensorflow_translators.hpp | 1 + 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/src/ov_extension.cpp b/src/ov_extension.cpp index a599c547c..60ea5ca3e 100644 --- a/src/ov_extension.cpp +++ b/src/ov_extension.cpp @@ -22,7 +22,8 @@ std::make_shared("StringSplitV2", translate_string_split), \ std::make_shared("RaggedTensorToTensor", translate_ragged_tensor_to_tensor), \ std::make_shared("Equal", translate_equal), \ - std::make_shared("StringToHashBucketFast", translate_string_to_hash_bucket_fast) + std::make_shared("StringToHashBucketFast", translate_string_to_hash_bucket_fast), \ + std::make_shared("Squeeze", translate_squeeze_op) // clang-format off //! [ov_extension:entry_point] diff --git a/src/tensorflow_translators.cpp b/src/tensorflow_translators.cpp index a0d64270c..a7aea2ad6 100644 --- a/src/tensorflow_translators.cpp +++ b/src/tensorflow_translators.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/frontend/complex_type_mark.hpp" #include "openvino/frontend/hash_table.hpp" #include "openvino/op/util/framework_node.hpp" @@ -28,6 +29,7 @@ using namespace TemplateExtension; using namespace ov; +using namespace ov::op; using namespace ov::frontend; using namespace ov::opset13; @@ -40,6 +42,17 @@ namespace { FRONT_END_GENERAL_CHECK(const_value.size() == 1, "Conversion expects " + const_name + " to be a scalar."); return const_value[0]; } + + Output compute_subgraph_scalar_rank(const Output& output, element::Type output_type, bool as_scalar) { + auto shape_of = std::make_shared(output, output_type); + auto rank_of = std::make_shared(shape_of, output_type); + + if (as_scalar) { + auto const_zero = std::make_shared(element::i32, Shape{}, 0); + return std::make_shared(rank_of, const_zero); + } + return rank_of; + } } // namespace OutputVector translate_sentencepiece_op(const NodeContext& node) { @@ -310,6 +323,7 @@ NamedOutputVector translate_string_split(const ov::frontend::NodeContext& node) FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "StringSplitV2 expects two inputs (1D input and separator)"); auto input = node.get_input(0); ov::OutputVector unpacked_input = pre_translate_string_tensor_input(input); + auto begins = unpacked_input[0]; auto sep_const = ov::as_type_ptr(node.get_input(1).get_node_shared_ptr()); FRONT_END_GENERAL_CHECK(sep_const, "[TensorFlow Frontend] internal error: only constant separator is supported for StringSplitV2"); auto sep_value = sep_const->cast_vector(); @@ -331,7 +345,7 @@ NamedOutputVector translate_string_split(const ov::frontend::NodeContext& node) auto maxsplit = node.get_attribute("maxsplit", -1); // compute batch_dim to generate ragged_begins and ragged_ends for RegexSplit - auto input_shape = std::make_shared(input, element::i32); + auto input_shape = std::make_shared(begins, element::i32); auto squeeze_axis = std::make_shared(element::i32, Shape{ 1 }, std::vector{0}); auto batch_dim = std::make_shared(input_shape, squeeze_axis); auto zero_const = std::make_shared(element::i32, Shape{}, std::vector{0}); @@ -486,8 +500,7 @@ ov::OutputVector translate_equal(const ov::frontend::NodeContext& node) { inputs.insert(inputs.end(), unpacked_input2.begin(), unpacked_input2.end()); auto equal_str = std::make_shared(inputs)->output(0); - auto const_one = std::make_shared(ov::element::i32, ov::Shape{}, 1); - result = std::make_shared(equal_str, const_one); + result = std::make_shared(equal_str, element::boolean); } else { result = std::make_shared(input1, input2)->output(0); @@ -516,3 +529,53 @@ ov::OutputVector translate_string_to_hash_bucket_fast(const ov::frontend::NodeCo result.set_names({ node_name + ":0" }); return { result }; } + +OutputVector translate_squeeze_op(const NodeContext& node) { + auto input = node.get_input(0); + auto complex_type_mark = as_type_ptr(input.get_node_shared_ptr()); + auto node_name = node.get_name(); + std::vector axes; + if (node.has_attribute("axis")) { + axes = node.get_attribute>("axis", {}); + } + else { + // check deprecated name + axes = node.get_attribute>("squeeze_dims", {}); + } + auto axis_const = std::make_shared(element::i32, Shape{ axes.size() }, axes); + + if (complex_type_mark) { + element::Type complex_part_type = complex_type_mark->get_complex_part_type(); + input = complex_type_mark->input_value(0); + + auto input_rank = compute_subgraph_scalar_rank(input, element::i32, true); + auto const_one = std::make_shared(element::i32, Shape{}, 1); + auto input_rank_minus_one = std::make_shared(input_rank, const_one)->output(0); + + // adjust axis to make them non-negative + auto axis_complex = std::make_shared(axis_const, input_rank_minus_one); + + auto squeeze = std::make_shared(input, axis_complex); + set_node_name(node_name, squeeze); + auto squeeze_complex = std::make_shared(squeeze, complex_part_type); + return { squeeze_complex->output(0) }; + } + else if (input.get_element_type() == element::string) { + ov::OutputVector unpacked_input = pre_translate_string_tensor_input(input); + auto begins = unpacked_input[0]; + auto ends = unpacked_input[1]; + auto chars = unpacked_input[2]; + + // squeeze begins and ends by given dimensions + begins = std::make_shared(begins, axis_const); + ends = std::make_shared(ends, axis_const); + + auto string_pack_result = post_translate_string_tensor_output(OutputVector{ begins, ends, chars }); + set_node_name(node_name, string_pack_result.get_node_shared_ptr()); + return { string_pack_result }; + } + + auto squeeze = std::make_shared(input, axis_const); + set_node_name(node_name, squeeze); + return { squeeze }; +} diff --git a/src/tensorflow_translators.hpp b/src/tensorflow_translators.hpp index a9eca2f6b..bd00bfcfa 100644 --- a/src/tensorflow_translators.hpp +++ b/src/tensorflow_translators.hpp @@ -18,3 +18,4 @@ ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node); ov::OutputVector translate_ragged_tensor_to_tensor(const ov::frontend::NodeContext& node); ov::OutputVector translate_equal(const ov::frontend::NodeContext& node); ov::OutputVector translate_string_to_hash_bucket_fast(const ov::frontend::NodeContext& node); +ov::OutputVector translate_squeeze_op(const ov::frontend::NodeContext& node);