Commit

[TF FE] Extend RaggedTensorToSparse conversion extension (openvinotoolkit#103)

* [TF FE] Extend RaggedTensorToSparse conversion extension

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Remove check for initialized variable

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix tensor name

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
rkazants authored Apr 4, 2024
1 parent 4dba76d commit 01f15a0
Showing 3 changed files with 70 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/ov_extension.cpp
@@ -16,7 +16,7 @@
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("NormalizeUTF8", translate_normalize_utf8), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("CaseFoldUTF8", translate_case_fold_utf8), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("SentencepieceOp", translate_sentencepiece_op), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("RaggedTensorToSparse", translate_sentencepiece_tokenizer), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("RaggedTensorToSparse", translate_ragged_tensor_to_sparse), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StringLower", translate_string_lower), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("StaticRegexReplace", translate_static_regex_replace), \
std::make_shared<ov::frontend::tensorflow::ConversionExtension>("LookupTableFind", translate_lookup_table_find_op), \
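For context, a conversion extension registered in this macro is picked up when the extension library is loaded into the runtime. A minimal sketch of that flow, assuming the library and model paths below (both hypothetical, not part of this change):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Load the custom-operation library so that RaggedTensorToSparse
    // (and the other registered ops) can be converted during read_model.
    core.add_extension("libuser_ov_extensions.so");  // hypothetical library path
    auto model = core.read_model("tf_model_with_ragged_ops.pb");  // hypothetical model
    return 0;
}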
99 changes: 68 additions & 31 deletions src/tensorflow_translators.cpp
@@ -51,49 +51,90 @@ OutputVector translate_sentencepiece_op(const NodeContext& node) {
     return { sp_model_const };
 }
 
-NamedOutputVector translate_sentencepiece_tokenizer(const NodeContext& node) {
+NamedOutputVector translate_ragged_tensor_to_sparse(const NodeContext& node) {
     // this is a custom translator that converts a sub-graph with SentencePieceOp, SentencePieceTokenizer,
     // and RaggedTensorToSparse operations into a custom operation SentencepieceTokenizerExtensionOp
+    FRONT_END_GENERAL_CHECK(node.get_input_size() > 0, "RaggedTensorToSparse expects at least one input.");
     auto node_name = node.get_name();
 
-    // check that the producer of RaggedTensorToSparse is SentencePieceTokenizer
-    auto sp_tokenize_op = node.get_input(0).get_node_shared_ptr();
-    FRONT_END_GENERAL_CHECK(sp_tokenize_op->get_input_size() > 6,
-        "SentencepieceTokenizeOp expects at least six inputs");
+    ov::Output<ov::Node> sparse_indices, sparse_values, sparse_dense_shape;
+    if (ov::as_type_ptr<ov::op::util::FrameworkNode>(node.get_input(0).get_node_shared_ptr())) {
+        auto sp_tokenize_op = node.get_input(0).get_node_shared_ptr();
+
+        FRONT_END_GENERAL_CHECK(sp_tokenize_op->get_input_size() > 6,
+            "SentencepieceTokenizeOp expects at least six inputs");
+
+        // prepare inputs that go to custom operation
+        // prepare input 0 - SentencePieceTokenizer configuration model
+        auto sp_model_const = as_type_ptr<Constant>(sp_tokenize_op->input_value(0).get_node_shared_ptr());
+        FRONT_END_GENERAL_CHECK(sp_model_const, "Conversion expects SentencePiece model to be constant.");
+
+        // prepare input
+        auto inputs = sp_tokenize_op->input_value(1);
+
+        // extract values for nbest_size, alpha, add_bos, add_eos, reverse attributes
+        auto nbest_size = extract_scalar_const_value<int32_t>(sp_tokenize_op->input_value(2).get_node_shared_ptr(), "nbest_size");
+        auto alpha = extract_scalar_const_value<float>(sp_tokenize_op->input_value(3).get_node_shared_ptr(), "alpha");
+        auto add_bos = extract_scalar_const_value<bool>(sp_tokenize_op->input_value(4).get_node_shared_ptr(), "add_bos");
+        auto add_eos = extract_scalar_const_value<bool>(sp_tokenize_op->input_value(5).get_node_shared_ptr(), "add_eos");
+        auto reverse = extract_scalar_const_value<bool>(sp_tokenize_op->input_value(6).get_node_shared_ptr(), "reverse");
+
+        OutputVector inputs_vector = OutputVector{ sp_model_const, inputs };
+
+        // create a node with custom operation
+        auto sp_tokenizer_ext = std::make_shared<SentencepieceTokenizer>(inputs_vector, nbest_size, alpha, add_bos, add_eos, reverse);
+        FRONT_END_GENERAL_CHECK(sp_tokenizer_ext->get_output_size() == 3,
+            "Internal error: SentencepieceTokenizer operation extension must have three outputs.");
+        sparse_indices = sp_tokenizer_ext->output(0);
+        sparse_values = sp_tokenizer_ext->output(1);
+        sparse_dense_shape = sp_tokenizer_ext->output(2);
+    }
+    else {
+        FRONT_END_GENERAL_CHECK(node.get_input_size() == 2, "RaggedTensorToSparse is supported only for one dimension raggedness");
+        auto rt_nested_splits = node.get_input(0);
+        auto rt_dense_values = node.get_input(1);
 
-    // prepare inputs that go to custom operation
-    // prepare input 0 - SentencePieceTokenizer configuration model
-    auto sp_model_const = as_type_ptr<Constant>(sp_tokenize_op->input_value(0).get_node_shared_ptr());
-    FRONT_END_GENERAL_CHECK(sp_model_const, "Conversion expects SentencePiece model to be constant.");
+        rt_nested_splits = std::make_shared<Convert>(rt_nested_splits, ov::element::i32);
 
-    // prepare input
-    auto inputs = sp_tokenize_op->input_value(1);
+        // compute vectors of begins and ends
+        auto rpt_shape = std::make_shared<ShapeOf>(rt_nested_splits, ov::element::i32)->output(0);
+        auto const_one = std::make_shared<Constant>(ov::element::i32, Shape{}, 1);
+        auto rpt_shape_minus_one = std::make_shared<Subtract>(rpt_shape, const_one)->output(0);
+        auto begins_start = std::make_shared<Constant>(ov::element::i32, Shape{ 1 }, 0);
+        auto ends_start = std::make_shared<Constant>(ov::element::i32, Shape{ 1 }, 1);
+        auto step = std::make_shared<Constant>(ov::element::i32, Shape{ 1 }, 1);
+        auto begins = std::make_shared<Slice>(rt_nested_splits, begins_start, rpt_shape_minus_one, step);
+        auto ends = std::make_shared<Slice>(rt_nested_splits, ends_start, rpt_shape, step);
+        auto longest_batch = rpt_shape_minus_one;
 
-    // extract values for nbest_size, alpha, add_bos, add_eos, reverse attributes
-    auto nbest_size = extract_scalar_const_value<int32_t>(sp_tokenize_op->input_value(2).get_node_shared_ptr(), "nbest_size");
-    auto alpha = extract_scalar_const_value<float>(sp_tokenize_op->input_value(3).get_node_shared_ptr(), "alpha");
-    auto add_bos = extract_scalar_const_value<bool>(sp_tokenize_op->input_value(4).get_node_shared_ptr(), "add_bos");
-    auto add_eos = extract_scalar_const_value<bool>(sp_tokenize_op->input_value(5).get_node_shared_ptr(), "add_eos");
-    auto reverse = extract_scalar_const_value<bool>(sp_tokenize_op->input_value(6).get_node_shared_ptr(), "reverse");
+        // compute the longest row in a tensor
+        auto longest_row_size = std::make_shared<Subtract>(ends, begins)->output(0);
+        auto reduce_axis = std::make_shared<Constant>(ov::element::i32, Shape{ 1 }, 0);
+        longest_row_size = std::make_shared<ReduceMax>(longest_row_size, reduce_axis, true);
 
-    OutputVector inputs_vector = OutputVector{ sp_model_const, inputs };
+        sparse_dense_shape = std::make_shared<Concat>(ov::OutputVector{ longest_batch, longest_row_size }, 0);
+        sparse_indices = std::make_shared<RaggedToSparse>(ov::OutputVector{ begins, ends })->output(0);
+        sparse_values = rt_dense_values;
 
-    // create a node with custom operation
-    auto sp_tokenizer_ext = std::make_shared<SentencepieceTokenizer>(inputs_vector, nbest_size, alpha, add_bos, add_eos, reverse);
-    FRONT_END_GENERAL_CHECK(sp_tokenizer_ext->get_output_size() == 3,
-        "Internal error: SentencepieceTokenizer operation extension must have three outputs.");
+        sparse_indices = std::make_shared<Convert>(sparse_indices, ov::element::i64);
+        sparse_dense_shape = std::make_shared<Convert>(sparse_dense_shape, ov::element::i64);
+    }
 
     // set tensor names
-    sp_tokenizer_ext->output(0).add_names({ node_name + ":0" });
-    sp_tokenizer_ext->output(1).add_names({ node_name + ":1" });
-    sp_tokenizer_ext->output(2).add_names({ node_name + ":2" });
+    sparse_indices.add_names({ node_name + ":0" });
+    if (!ov::as_type_ptr<Parameter>(sparse_values.get_node_shared_ptr())) {
+        // for a case without SentencePiece tokenizer
+        // we must not corrupt the input tensor name due to the skip connection
+        sparse_values.add_names({ node_name + ":1" });
+    }
+    sparse_dense_shape.add_names({ node_name + ":2" });
 
     // create named outputs for the conversion extension
     NamedOutputVector named_results;
-    named_results.push_back({ "sparse_indices", sp_tokenizer_ext->output(0) });
-    named_results.push_back({ "sparse_values", sp_tokenizer_ext->output(1) });
-    named_results.push_back({ "sparse_dense_shape", sp_tokenizer_ext->output(2) });
+    named_results.push_back({ "sparse_indices", sparse_indices });
+    named_results.push_back({ "sparse_values", sparse_values });
+    named_results.push_back({ "sparse_dense_shape", sparse_dense_shape });
 
     return named_results;
 }
@@ -176,10 +217,6 @@ OutputVector translate_lookup_table_find_op(const ov::frontend::tensorflow::Node
         node,
         table_handle,
         "[TensorFlow Frontend] internal error: LookupTableFind operation expects table_handle by the first input");
-    TENSORFLOW_OP_VALIDATION(
-        node,
-        table_handle->is_initialized(),
-        "[TensorFlow Frontend] internal error: LookupTableFind operation expects initialized table_handle");
     auto keys = node.get_input(1);
     auto default_value = node.get_input(2);
 
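To see what the new else branch computes, here is a standalone C++ sketch of the same ragged-to-sparse math over a row-splits vector (an illustration of the decomposition above for one ragged dimension, not the actual translator code; RaggedToSparse itself is the custom op referenced in the diff):

#include <algorithm>
#include <array>
#include <cstdint>
#include <vector>

// From row splits [0, 2, 2, 5] derive begins = [0, 2, 2], ends = [2, 2, 5],
// dense shape = [num_rows, longest_row], and the 2D index of each value.
int main() {
    std::vector<int32_t> splits = {0, 2, 2, 5};  // example rt_nested_splits
    std::vector<int32_t> begins(splits.begin(), splits.end() - 1);  // Slice [0 : n-1]
    std::vector<int32_t> ends(splits.begin() + 1, splits.end());    // Slice [1 : n]

    int64_t num_rows = static_cast<int64_t>(splits.size()) - 1;     // rpt_shape - 1
    int64_t longest = 0;
    for (size_t i = 0; i < begins.size(); ++i)
        longest = std::max<int64_t>(longest, ends[i] - begins[i]);  // ReduceMax(ends - begins)
    std::array<int64_t, 2> dense_shape = {num_rows, longest};       // Concat

    // RaggedToSparse(begins, ends): for each row, emit one [row, position]
    // index per element of that row.
    std::vector<std::array<int64_t, 2>> indices;
    for (int64_t row = 0; row < num_rows; ++row)
        for (int32_t pos = 0; pos < ends[row] - begins[row]; ++pos)
            indices.push_back({row, pos});
    // Here: indices = (0,0),(0,1),(2,0),(2,1),(2,2); dense_shape = {3, 3}.
    return 0;
}

The sparse values output is simply the flat rt_dense_values input passed through, which is why the translator avoids renaming that tensor when it comes straight from a Parameter.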
2 changes: 1 addition & 1 deletion src/tensorflow_translators.hpp
@@ -12,7 +12,7 @@ ov::frontend::NamedOutputVector translate_string_split(const ov::frontend::NodeC
 #endif
 
 ov::OutputVector translate_sentencepiece_op(const ov::frontend::NodeContext& node);
-ov::frontend::NamedOutputVector translate_sentencepiece_tokenizer(const ov::frontend::NodeContext& node);
+ov::frontend::NamedOutputVector translate_ragged_tensor_to_sparse(const ov::frontend::NodeContext& node);
 ov::OutputVector translate_case_fold_utf8(const ov::frontend::NodeContext& node);
 ov::OutputVector translate_normalize_utf8(const ov::frontend::NodeContext& node);
 ov::OutputVector translate_static_regex_replace(const ov::frontend::NodeContext& node);