Skip to content

Commit

Permalink
[TF FE][Tokenizers] Optimize TF FE extensions (openvinotoolkit#232)
Browse files Browse the repository at this point in the history
Avoid extra StringTensorPack and StringTensorUnpack operations in a graph

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Aug 26, 2024
1 parent f7cd828 commit 612cba8
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/ov_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
std::make_shared<ov::frontend::ConversionExtension>("StringSplitV2", translate_string_split), \
std::make_shared<ov::frontend::ConversionExtension>("RaggedTensorToTensor", translate_ragged_tensor_to_tensor), \
std::make_shared<ov::frontend::ConversionExtension>("Equal", translate_equal), \
std::make_shared<ov::frontend::ConversionExtension>("StringToHashBucketFast", translate_string_to_hash_bucket_fast)
std::make_shared<ov::frontend::ConversionExtension>("StringToHashBucketFast", translate_string_to_hash_bucket_fast), \
std::make_shared<ov::frontend::ConversionExtension>("Squeeze", translate_squeeze_op)

// clang-format off
//! [ov_extension:entry_point]
Expand Down
69 changes: 66 additions & 3 deletions src/tensorflow_translators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/frontend/hash_table.hpp"

#include "openvino/op/util/framework_node.hpp"
Expand All @@ -28,6 +29,7 @@

using namespace TemplateExtension;
using namespace ov;
using namespace ov::op;
using namespace ov::frontend;
using namespace ov::opset13;

Expand All @@ -40,6 +42,17 @@ namespace {
FRONT_END_GENERAL_CHECK(const_value.size() == 1, "Conversion expects " + const_name + " to be a scalar.");
return const_value[0];
}

Output<Node> compute_subgraph_scalar_rank(const Output<Node>& output, element::Type output_type, bool as_scalar) {
auto shape_of = std::make_shared<ShapeOf>(output, output_type);
auto rank_of = std::make_shared<ShapeOf>(shape_of, output_type);

if (as_scalar) {
auto const_zero = std::make_shared<Constant>(element::i32, Shape{}, 0);
return std::make_shared<Squeeze>(rank_of, const_zero);
}
return rank_of;
}
} // namespace

OutputVector translate_sentencepiece_op(const NodeContext& node) {
Expand Down Expand Up @@ -310,6 +323,7 @@ NamedOutputVector translate_string_split(const ov::frontend::NodeContext& node)
FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "StringSplitV2 expects two inputs (1D input and separator)");
auto input = node.get_input(0);
ov::OutputVector unpacked_input = pre_translate_string_tensor_input(input);
auto begins = unpacked_input[0];
auto sep_const = ov::as_type_ptr<Constant>(node.get_input(1).get_node_shared_ptr());
FRONT_END_GENERAL_CHECK(sep_const, "[TensorFlow Frontend] internal error: only constant separator is supported for StringSplitV2");
auto sep_value = sep_const->cast_vector<std::string>();
Expand All @@ -331,7 +345,7 @@ NamedOutputVector translate_string_split(const ov::frontend::NodeContext& node)
auto maxsplit = node.get_attribute<int64_t>("maxsplit", -1);

// compute batch_dim to generate ragged_begins and ragged_ends for RegexSplit
auto input_shape = std::make_shared<ShapeOf>(input, element::i32);
auto input_shape = std::make_shared<ShapeOf>(begins, element::i32);
auto squeeze_axis = std::make_shared<Constant>(element::i32, Shape{ 1 }, std::vector<int32_t>{0});
auto batch_dim = std::make_shared<Squeeze>(input_shape, squeeze_axis);
auto zero_const = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
Expand Down Expand Up @@ -486,8 +500,7 @@ ov::OutputVector translate_equal(const ov::frontend::NodeContext& node) {
inputs.insert(inputs.end(), unpacked_input2.begin(), unpacked_input2.end());

auto equal_str = std::make_shared<EqualStr>(inputs)->output(0);
auto const_one = std::make_shared<Constant>(ov::element::i32, ov::Shape{}, 1);
result = std::make_shared<Equal>(equal_str, const_one);
result = std::make_shared<Convert>(equal_str, element::boolean);
}
else {
result = std::make_shared<Equal>(input1, input2)->output(0);
Expand Down Expand Up @@ -516,3 +529,53 @@ ov::OutputVector translate_string_to_hash_bucket_fast(const ov::frontend::NodeCo
result.set_names({ node_name + ":0" });
return { result };
}

OutputVector translate_squeeze_op(const NodeContext& node) {
auto input = node.get_input(0);
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr());
auto node_name = node.get_name();
std::vector<int64_t> axes;
if (node.has_attribute("axis")) {
axes = node.get_attribute<std::vector<int64_t>>("axis", {});
}
else {
// check deprecated name
axes = node.get_attribute<std::vector<int64_t>>("squeeze_dims", {});
}
auto axis_const = std::make_shared<Constant>(element::i32, Shape{ axes.size() }, axes);

if (complex_type_mark) {
element::Type complex_part_type = complex_type_mark->get_complex_part_type();
input = complex_type_mark->input_value(0);

auto input_rank = compute_subgraph_scalar_rank(input, element::i32, true);
auto const_one = std::make_shared<Constant>(element::i32, Shape{}, 1);
auto input_rank_minus_one = std::make_shared<Subtract>(input_rank, const_one)->output(0);

// adjust axis to make them non-negative
auto axis_complex = std::make_shared<FloorMod>(axis_const, input_rank_minus_one);

auto squeeze = std::make_shared<Squeeze>(input, axis_complex);
set_node_name(node_name, squeeze);
auto squeeze_complex = std::make_shared<ComplexTypeMark>(squeeze, complex_part_type);
return { squeeze_complex->output(0) };
}
else if (input.get_element_type() == element::string) {
ov::OutputVector unpacked_input = pre_translate_string_tensor_input(input);
auto begins = unpacked_input[0];
auto ends = unpacked_input[1];
auto chars = unpacked_input[2];

// squeeze begins and ends by given dimensions
begins = std::make_shared<Squeeze>(begins, axis_const);
ends = std::make_shared<Squeeze>(ends, axis_const);

auto string_pack_result = post_translate_string_tensor_output(OutputVector{ begins, ends, chars });
set_node_name(node_name, string_pack_result.get_node_shared_ptr());
return { string_pack_result };
}

auto squeeze = std::make_shared<Squeeze>(input, axis_const);
set_node_name(node_name, squeeze);
return { squeeze };
}
1 change: 1 addition & 0 deletions src/tensorflow_translators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ ov::OutputVector translate_string_lower(const ov::frontend::NodeContext& node);
ov::OutputVector translate_ragged_tensor_to_tensor(const ov::frontend::NodeContext& node);
ov::OutputVector translate_equal(const ov::frontend::NodeContext& node);
ov::OutputVector translate_string_to_hash_bucket_fast(const ov::frontend::NodeContext& node);
ov::OutputVector translate_squeeze_op(const ov::frontend::NodeContext& node);

0 comments on commit 612cba8

Please sign in to comment.