Skip to content

Commit

Permalink
[TF FE] Fix ArgMin and ArgMax translators and stabilize tests (openvi…
Browse files Browse the repository at this point in the history
…notoolkit#26725)

**Details:** Fix ArgMin and ArgMax translators and stabilize tests

**Ticket:** TBD

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
  • Loading branch information
rkazants authored Sep 23, 2024
1 parent 17f976d commit 6185b55
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 37 deletions.
17 changes: 14 additions & 3 deletions src/frontends/tensorflow_common/src/op/arg_min_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "common_op_table.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/topk.hpp"

Expand Down Expand Up @@ -31,16 +32,26 @@ OutputVector translate_arg_min_max(const NodeContext& node, std::string mode) {
axis = axes[0];
}
auto output_type = node.get_attribute<element::Type>("output_type", element::i64);
auto topk_output_type = output_type;
if (topk_output_type != element::i32 && topk_output_type != element::i64) {
// OV TopK supports only element::i32 and element::i64
// so use temporarily element::i64
topk_output_type = element::i64;
}

// compute indices of max/min values using TopK
auto k = make_shared<v0::Constant>(element::i64, Shape{}, 1);
auto top_k_mode = (mode == "max" ? v11::TopK::Mode::MAX : v11::TopK::Mode::MIN);
auto sort_type = v11::TopK::SortType::SORT_VALUES;
auto top_k = make_shared<v11::TopK>(input, k, axis, top_k_mode, sort_type, output_type, true);
auto top_k = make_shared<v11::TopK>(input, k, axis, top_k_mode, sort_type, topk_output_type, true);

auto axis_to_remove = make_shared<v0::Constant>(element::i64, Shape{1}, vector<int64_t>({axis}));
auto res = make_shared<v0::Squeeze>(top_k->output(1), axis_to_remove);
set_node_name(node.get_name(), res);
auto res = make_shared<v0::Squeeze>(top_k->output(1), axis_to_remove)->output(0);
if (topk_output_type != output_type) {
// use required type for the output
res = make_shared<v0::Convert>(res, output_type);
}
set_node_name(node.get_name(), res.get_node_shared_ptr());
return {res};
}

Expand Down
79 changes: 45 additions & 34 deletions tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,80 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import platform

import numpy as np
import platform
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest


# Testing operation ArgMin, ArgMax (Initial Implementation)
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/ArgMin
# https://www.tensorflow.org/api_docs/python/tf/raw_ops/ArgMax

OPS = {
'tf.raw_ops.ArgMax': tf.raw_ops.ArgMax,
'tf.raw_ops.ArgMin': tf.raw_ops.ArgMin
}
rng = np.random.default_rng(2323534)


class TestArgMinMax(CommonTFLayerTest):
def _prepare_input(self, inputs_info):
assert 'input:0' in inputs_info
input_shape = inputs_info['input:0']
assert 'input:0' in inputs_info, "Test error: inputs_info must contain `input`"
x_shape = inputs_info['input:0']
inputs_data = {}
rng = np.random.default_rng()
inputs_data['input:0'] = rng.integers(-8, 8, input_shape).astype(self.input_type)
if np.issubdtype(self.input_type, np.floating):
inputs_data['input:0'] = rng.uniform(-5.0, 5.0, x_shape).astype(self.input_type)
elif np.issubdtype(self.input_type, np.signedinteger):
inputs_data['input:0'] = rng.integers(-8, 8, x_shape).astype(self.input_type)
else:
inputs_data['input:0'] = rng.integers(0, 8, x_shape).astype(self.input_type)
return inputs_data

def create_argmin_max_net(self, input_shape, dimension, input_type, output_type, op_type):
OPS = {
'tf.raw_ops.ArgMax': tf.raw_ops.ArgMax,
'tf.raw_ops.ArgMin': tf.raw_ops.ArgMin
}

self.input_type = input_type
tf.compat.v1.reset_default_graph()

# Create the graph and model
with tf.compat.v1.Session() as sess:
tf_input = tf.compat.v1.placeholder(input_type, input_shape, 'input')
tf_dimension = tf.constant(dimension)

op_type(input=tf_input, dimension=tf_dimension, output_type=output_type)

input = tf.compat.v1.placeholder(input_type, input_shape, 'input')
OPS[op_type](input=input, dimension=dimension, output_type=output_type)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

ref_net = None

return tf_net, ref_net

test_data = [
[[20], 0],
[[20, 30], 1],
[[2, 30, 3, 4], 2],
]

@pytest.mark.parametrize("input_shape, dimension", test_data)
@pytest.mark.parametrize("input_type", [np.float32, np.int32])
@pytest.mark.parametrize("output_type", [tf.int32, tf.int64])
@pytest.mark.parametrize("op_type", ['tf.raw_ops.ArgMax', 'tf.raw_ops.ArgMin'])
@pytest.mark.parametrize('input_shape, dimension', [([10], 0),
([10, 15], 1),
([10, 15, 20], 2),
([10, 15, 20], -1)
])
@pytest.mark.parametrize('input_type', [np.float16, np.float32, np.float64,
np.int32, np.uint8, np.int16, np.int8, np.int64,
np.uint16, np.uint32, np.uint64])
@pytest.mark.parametrize('op_type, output_type', [('tf.raw_ops.ArgMax', np.int16),
('tf.raw_ops.ArgMax', np.uint16),
('tf.raw_ops.ArgMax', np.int32),
('tf.raw_ops.ArgMax', np.int64),
# tf.raw_ops.ArgMin supports only int32 and int64
('tf.raw_ops.ArgMin', np.int32),
('tf.raw_ops.ArgMin', np.int64)
])
@pytest.mark.precommit
@pytest.mark.nightly
@pytest.mark.xfail(condition=platform.system() in ('Darwin', 'Linux') and platform.machine() in ['arm', 'armv7l',
'aarch64',
'arm64', 'ARM64'],
reason='Ticket - 126314, 132699')
def test_argmin_max_net(self, input_shape, dimension, input_type, output_type, op_type, ie_device, precision, ir_version, temp_dir, use_legacy_frontend):
params = dict(input_shape=input_shape, dimension=dimension)
self._test(*self.create_argmin_max_net(**params, input_type=input_type,
output_type=output_type, op_type=OPS[op_type]),
def test_argmin_max_net(self, input_shape, dimension, input_type, output_type, op_type,
ie_device, precision, ir_version, temp_dir, use_legacy_frontend):
if platform.machine() in ['aarch64', 'arm64', 'ARM64']:
pytest.skip('153077: Segmentation fault on ARM')
if ie_device == 'GPU' and input_type == np.uint8:
pytest.skip('153078: No layout format available for topk')
if ie_device == 'GPU' and input_type == np.float32 and input_shape == [10, 15, 20]:
pytest.skip('153079: Accuracy error on GPU')
self._test(*self.create_argmin_max_net(input_shape=input_shape, dimension=dimension,
input_type=input_type, output_type=output_type,
op_type=op_type),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_legacy_frontend=use_legacy_frontend)

0 comments on commit 6185b55

Please sign in to comment.