Skip to content

Commit

Permalink
[TF FE] Support StringToHashBucketFast operation (#23946)
Browse files Browse the repository at this point in the history
**Details:** Support the StringToHashBucketFast operation, which is needed for
several models from TF Hub. Merge after
openvinotoolkit/openvino_tokenizers#115

**Ticket:**

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
  • Loading branch information
rkazants authored Apr 10, 2024
1 parent 0b9f1dd commit 7a10b94
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| StringSplitV2<sup><mark style="background-color: #00FF00">NEW</mark></sup> | YES | openvino-tokenizers required |
| StringStrip | NO | |
| StringToHashBucket | NO | |
| StringToHashBucketFast | NO | |
| StringToHashBucketFast | YES | openvino-tokenizers required |
| StringToHashBucketStrong | NO | |
| StringToNumber | NO | |
| StringUpper | NO | |
Expand Down
7 changes: 6 additions & 1 deletion src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
};

// Returns the list of TensorFlow operation names that this function reports as
// supported via tokenizers (supported_ops.md marks e.g. StringSplitV2 and
// StringToHashBucketFast as "openvino-tokenizers required").
// NOTE(review): this is a diff rendering - the single-line return below is the
// pre-change version and the multi-line return after it is its replacement,
// which additionally lists "StringToHashBucketFast".
const std::vector<std::string> get_supported_ops_via_tokenizers() {
return {"RaggedTensorToSparse", "RaggedTensorToTensor", "StaticRegexReplace", "StringLower", "StringSplitV2"};
return {"RaggedTensorToSparse",
"RaggedTensorToTensor",
"StaticRegexReplace",
"StringLower",
"StringSplitV2",
"StringToHashBucketFast"};
}
} // namespace op
} // namespace tensorflow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import platform

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest
from common.utils.tf_utils import run_in_jenkins

rng = np.random.default_rng()


class TestStringToHashBucketFast(CommonTFLayerTest):
    """Layer test for the TensorFlow ``StringToHashBucketFast`` raw operation.

    The graph under test is a single string placeholder named ``input`` fed
    into ``tf.raw_ops.StringToHashBucketFast``. Input strings are sampled at
    random from the ``strings_dictionary`` supplied by the test parameters.
    """

    def _prepare_input(self, inputs_info):
        """Generate random string data for the ``input:0`` placeholder.

        Args:
            inputs_info: Mapping from input tensor name to its shape; it must
                contain the key ``'input:0'``.

        Returns:
            A dict with the single key ``'input:0'`` whose value is drawn from
            ``self.strings_dictionary`` with the requested shape.
        """
        assert 'input:0' in inputs_info
        input_shape = inputs_info['input:0']
        # self.strings_dictionary is stored by
        # create_string_to_hash_bucket_fast_net, so that method must be called
        # before input data is prepared.
        sample_data = rng.choice(self.strings_dictionary, input_shape)
        return {'input:0': sample_data}

    def create_string_to_hash_bucket_fast_net(self, input_shape, strings_dictionary, num_buckets):
        """Build a TF1 graph containing one ``StringToHashBucketFast`` node.

        Args:
            input_shape: Shape of the string placeholder.
            strings_dictionary: Candidate strings, remembered on the instance
                for later sampling in ``_prepare_input``.
            num_buckets: Value passed as the ``num_buckets`` argument of the op.

        Returns:
            A ``(tf_net, ref_net)`` tuple: the session's graph definition and
            ``None`` (no reference network is built here).
        """
        self.strings_dictionary = strings_dictionary

        tf.compat.v1.reset_default_graph()
        with tf.compat.v1.Session() as sess:
            # Named input_tensor rather than `input` to avoid shadowing the builtin.
            input_tensor = tf.compat.v1.placeholder(tf.string, input_shape, 'input')
            tf.raw_ops.StringToHashBucketFast(input=input_tensor, num_buckets=num_buckets)

            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        ref_net = None

        return tf_net, ref_net

    @pytest.mark.parametrize("input_shape", [[], [2], [3, 4], [1, 3, 2]])
    @pytest.mark.parametrize("num_buckets", [1, 4, 7, 11])
    @pytest.mark.parametrize("strings_dictionary",
                             [['UPPER CASE SENTENCE', 'lower case sentence', ' UppEr LoweR CAse SENtence', ' '],
                              ['Первое Предложение', 'второе предложение', ' ', ' ТРЕТЬЕ ПРЕДЛОЖЕНИЕ '],
                              ['第一句話在這裡', '第二句話在這裡', '第三句話在這裡'],
                              ['', ' ', '12345 ']])
    @pytest.mark.precommit
    @pytest.mark.nightly
    @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ['arm', 'armv7l',
                                                                                                     'aarch64',
                                                                                                     'arm64', 'ARM64'],
                       reason='Ticket - 126314, 132699')
    def test_string_to_hash_bucket_fast(self, input_shape, num_buckets, strings_dictionary, ie_device, precision,
                                        ir_version, temp_dir,
                                        use_legacy_frontend):
        """Run the layer test for every shape / bucket count / dictionary combination.

        The test is skipped on GPU and when ``run_in_jenkins()`` is true, and
        is marked as an expected failure on ARM machines under Linux or macOS.
        """
        # Report the two skip causes separately so the skip reason shown in
        # the test report matches the condition that actually triggered it.
        if ie_device == 'GPU':
            pytest.skip("operation extension is not supported on GPU")
        if run_in_jenkins():
            pytest.skip("test is skipped when running in Jenkins")
        self._test(*self.create_string_to_hash_bucket_fast_net(input_shape=input_shape, num_buckets=num_buckets,
                                                               strings_dictionary=strings_dictionary),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_legacy_frontend=use_legacy_frontend)

0 comments on commit 7a10b94

Please sign in to comment.