From 6185b55d7adf61125f21fe712aede343472e7a34 Mon Sep 17 00:00:00 2001
From: Roman Kazantsev
Date: Mon, 23 Sep 2024 15:00:08 +0400
Subject: [PATCH] [TF FE] Fix ArgMin and ArgMax translators and stabilize
 tests (#26725)

**Details:** Fix ArgMin and ArgMax translators and stabilize tests

**Ticket:** TBD

---------

Signed-off-by: Kazantsev, Roman
---
 .../tensorflow_common/src/op/arg_min_max.cpp  | 17 +++-
 .../tensorflow_tests/test_tf_ArgMinMax.py     | 79 +++++++++++--------
 2 files changed, 59 insertions(+), 37 deletions(-)

diff --git a/src/frontends/tensorflow_common/src/op/arg_min_max.cpp b/src/frontends/tensorflow_common/src/op/arg_min_max.cpp
index af7d2986524492..e2b53a8dbcceaa 100644
--- a/src/frontends/tensorflow_common/src/op/arg_min_max.cpp
+++ b/src/frontends/tensorflow_common/src/op/arg_min_max.cpp
@@ -4,6 +4,7 @@
 
 #include "common_op_table.hpp"
 #include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/squeeze.hpp"
 #include "openvino/op/topk.hpp"
 
@@ -31,16 +32,26 @@ OutputVector translate_arg_min_max(const NodeContext& node, std::string mode) {
         axis = axes[0];
     }
     auto output_type = node.get_attribute<element::Type>("output_type", element::i64);
+    auto topk_output_type = output_type;
+    if (topk_output_type != element::i32 && topk_output_type != element::i64) {
+        // OV TopK supports only element::i32 and element::i64,
+        // so temporarily use element::i64
+        topk_output_type = element::i64;
+    }
 
     // compute indices of max/min values using TopK
     auto k = make_shared<v0::Constant>(element::i64, Shape{}, 1);
     auto top_k_mode = (mode == "max" ? v11::TopK::Mode::MAX : v11::TopK::Mode::MIN);
     auto sort_type = v11::TopK::SortType::SORT_VALUES;
-    auto top_k = make_shared<v11::TopK>(input, k, axis, top_k_mode, sort_type, output_type, true);
+    auto top_k = make_shared<v11::TopK>(input, k, axis, top_k_mode, sort_type, topk_output_type, true);
 
     auto axis_to_remove = make_shared<v0::Constant>(element::i64, Shape{1}, vector<int64_t>({axis}));
-    auto res = make_shared<v0::Squeeze>(top_k->output(1), axis_to_remove);
-    set_node_name(node.get_name(), res);
+    auto res = make_shared<v0::Squeeze>(top_k->output(1), axis_to_remove)->output(0);
+    if (topk_output_type != output_type) {
+        // use the required type for the output
+        res = make_shared<v0::Convert>(res, output_type);
+    }
+    set_node_name(node.get_name(), res.get_node_shared_ptr());
     return {res};
 }
 
diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py b/tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py
index b9cec257e89e0b..785ef72a60f3a1 100644
--- a/tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py
+++ b/tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py
@@ -1,43 +1,45 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import platform
-
 import numpy as np
+import platform
 import pytest
 import tensorflow as tf
 from common.tf_layer_test_class import CommonTFLayerTest
 
-
 # Testing operation ArgMin, ArgMax (Initial Implementation)
 # Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/ArgMin
 #                https://www.tensorflow.org/api_docs/python/tf/raw_ops/ArgMax
 
-OPS = {
-    'tf.raw_ops.ArgMax': tf.raw_ops.ArgMax,
-    'tf.raw_ops.ArgMin': tf.raw_ops.ArgMin
-}
+rng = np.random.default_rng(2323534)
+
 
 class TestArgMinMax(CommonTFLayerTest):
     def _prepare_input(self, inputs_info):
-        assert 'input:0' in inputs_info
-        input_shape = inputs_info['input:0']
+        assert 'input:0' in inputs_info, "Test error: inputs_info must contain `input`"
+        x_shape = inputs_info['input:0']
         inputs_data = {}
-        rng = np.random.default_rng()
-        inputs_data['input:0'] = rng.integers(-8, 8, input_shape).astype(self.input_type)
+        if np.issubdtype(self.input_type, np.floating):
+            inputs_data['input:0'] = rng.uniform(-5.0, 5.0, x_shape).astype(self.input_type)
+        elif np.issubdtype(self.input_type, np.signedinteger):
+            inputs_data['input:0'] = rng.integers(-8, 8, x_shape).astype(self.input_type)
+        else:
+            inputs_data['input:0'] = rng.integers(0, 8, x_shape).astype(self.input_type)
         return inputs_data
 
     def create_argmin_max_net(self, input_shape, dimension, input_type, output_type, op_type):
+        OPS = {
+            'tf.raw_ops.ArgMax': tf.raw_ops.ArgMax,
+            'tf.raw_ops.ArgMin': tf.raw_ops.ArgMin
+        }
+
         self.input_type = input_type
         tf.compat.v1.reset_default_graph()
 
         # Create the graph and model
         with tf.compat.v1.Session() as sess:
-            tf_input = tf.compat.v1.placeholder(input_type, input_shape, 'input')
-            tf_dimension = tf.constant(dimension)
-
-            op_type(input=tf_input, dimension=tf_dimension, output_type=output_type)
-
+            input = tf.compat.v1.placeholder(input_type, input_shape, 'input')
+            OPS[op_type](input=input, dimension=dimension, output_type=output_type)
             tf.compat.v1.global_variables_initializer()
             tf_net = sess.graph_def
 
@@ -45,25 +47,34 @@ def create_argmin_max_net(self, input_shape, dimension, input_type, output_type,
 
         return tf_net, ref_net
 
-    test_data = [
-        [[20], 0],
-        [[20, 30], 1],
-        [[2, 30, 3, 4], 2],
-    ]
-
-    @pytest.mark.parametrize("input_shape, dimension", test_data)
-    @pytest.mark.parametrize("input_type", [np.float32, np.int32])
-    @pytest.mark.parametrize("output_type", [tf.int32, tf.int64])
-    @pytest.mark.parametrize("op_type", ['tf.raw_ops.ArgMax', 'tf.raw_ops.ArgMin'])
+    @pytest.mark.parametrize('input_shape, dimension', [([10], 0),
+                                                        ([10, 15], 1),
+                                                        ([10, 15, 20], 2),
+                                                        ([10, 15, 20], -1)
+                                                        ])
+    @pytest.mark.parametrize('input_type', [np.float16, np.float32, np.float64,
+                                            np.int32, np.uint8, np.int16, np.int8, np.int64,
+                                            np.uint16, np.uint32, np.uint64])
+    @pytest.mark.parametrize('op_type, output_type', [('tf.raw_ops.ArgMax', np.int16),
+                                                      ('tf.raw_ops.ArgMax', np.uint16),
+                                                      ('tf.raw_ops.ArgMax', np.int32),
+                                                      ('tf.raw_ops.ArgMax', np.int64),
+                                                      # tf.raw_ops.ArgMin supports only int32 and int64
+                                                      ('tf.raw_ops.ArgMin', np.int32),
+                                                      ('tf.raw_ops.ArgMin', np.int64)
+                                                      ])
     @pytest.mark.precommit
     @pytest.mark.nightly
-    @pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ['arm', 'armv7l',
-                                                                                                     'aarch64',
-                                                                                                     'arm64', 'ARM64'],
-                       reason='Ticket - 126314, 132699')
-    def test_argmin_max_net(self, input_shape, dimension, input_type, output_type, op_type, ie_device, precision, ir_version, temp_dir, use_legacy_frontend):
-        params = dict(input_shape=input_shape, dimension=dimension)
-        self._test(*self.create_argmin_max_net(**params, input_type=input_type,
-                                               output_type=output_type, op_type=OPS[op_type]),
+    def test_argmin_max_net(self, input_shape, dimension, input_type, output_type, op_type,
+                            ie_device, precision, ir_version, temp_dir, use_legacy_frontend):
+        if platform.machine() in ['aarch64', 'arm64', 'ARM64']:
+            pytest.skip('153077: Segmentation fault on ARM')
+        if ie_device == 'GPU' and input_type == np.uint8:
+            pytest.skip('153078: No layout format available for topk')
+        if ie_device == 'GPU' and input_type == np.float32 and input_shape == [10, 15, 20]:
+            pytest.skip('153079: Accuracy error on GPU')
+        self._test(*self.create_argmin_max_net(input_shape=input_shape, dimension=dimension,
+                                               input_type=input_type, output_type=output_type,
+                                               op_type=op_type),
                    ie_device, precision, ir_version, temp_dir=temp_dir,
                    use_legacy_frontend=use_legacy_frontend)
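
**Trying the fix by hand:** a minimal sketch of the case this patch targets, assuming an OpenVINO build that already contains it, the `openvino` Python API (with `ov.convert_model` accepting a frozen `GraphDef`), and an available `CPU` device; the shape and seed are illustrative, not taken from the test suite. TF's ArgMax may request 16-bit indices, which OV TopK cannot emit directly, so the translator now computes the indices in i64 and appends a Convert to the requested type.

```python
import numpy as np
import tensorflow as tf
import openvino as ov

# Build a tiny TF graph: ArgMax over axis 1 that asks for int16 indices,
# an output_type that previously reached TopK unchanged and failed.
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
    x = tf.compat.v1.placeholder(tf.float32, [2, 5], 'input')
    tf.raw_ops.ArgMax(input=x, dimension=1, output_type=tf.int16)
    graph_def = sess.graph_def

# After the patch the frontend emits TopK(i64) -> Squeeze -> Convert(i16),
# so conversion succeeds and the output keeps the requested index type.
compiled = ov.compile_model(ov.convert_model(graph_def), 'CPU')
data = np.random.default_rng(0).uniform(-5.0, 5.0, [2, 5]).astype(np.float32)
indices = compiled(data)[0]

assert indices.dtype == np.int16
assert np.array_equal(indices, np.argmax(data, axis=1).astype(np.int16))
```

Passing `data` positionally works because the model has a single input; the asserts mirror what the layer tests check, namely that the indices agree with NumPy's argmax and keep the requested int16 dtype. Only ArgMax needs the Convert path in practice: as the new `op_type, output_type` parametrization notes, tf.raw_ops.ArgMin supports only int32 and int64, which TopK already produces natively.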