Skip to content

Commit

Permalink
[TF FE] Support Conj operation (openvinotoolkit#21947)
Browse files Browse the repository at this point in the history
* Conjugate

* Moved common logic into make_conj helper.

* Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp

* Moved helper to conj_transpose

* Applied helper to both conj and conj_transpose

* Deleted conj.cpp

* Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Removed additional Shape:: scope resolution from get_conj helper

* Added Conj and ConjugateTranspose to supported ops

* Update src/frontends/tensorflow/src/op_table.cpp

Change "Conjugate" to "Conj"

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Apply suggestions from code review

* Removed perms call and moved test data directly to parametrize macro

* Apply suggestions from code review

* Apply suggestions from code review

* Changed input types from float32 to complex64

* Changed input type back to np.float32 and removed real tensor test

* Update src/frontends/tensorflow_common/src/op/conj_transpose.cpp

---------

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
  • Loading branch information
3 people authored Jan 31, 2024
1 parent 4e302aa commit 79b4645
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/frontends/tensorflow/docs/supported_ops.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TensorFlow Operations Supported by OpenVINO TensorFlow Frontend
# TensorFlow Operations Supported by OpenVINO TensorFlow Frontend

Here is a table of operations supported by the TensorFlow Frontend from [tf.raw_ops](https://www.tensorflow.org/api_docs/python/tf/raw_ops).
A "supported operation" is one that TensorFlow Frontend can convert to the OpenVINO representation.
Expand Down Expand Up @@ -232,7 +232,7 @@ A "supported operation" is one that TensorFlow Frontend can convert to the OpenV
| ConditionalAccumulator | NO | |
| ConfigureDistributedTPU | NO | |
| ConfigureTPUEmbedding | NO | |
| Conj | NO | |
| Conj | YES | |
| ConjugateTranspose | YES | |
| Const | YES | |
| ConsumeMutexLock | NO | |
Expand Down
1 change: 1 addition & 0 deletions src/frontends/tensorflow/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"ClipByValue", CreatorFunction(translate_clip_by_value_op)},
{"Complex", CreatorFunction(translate_complex_op)},
{"ComplexAbs", CreatorFunction(translate_complex_abs_op)},
{"Conj", CreatorFunction(translate_conj_op)},
{"ConjugateTranspose", CreatorFunction(translate_conj_transpose_op)},
{"Concat", CreatorFunction(translate_concat_op)},
{"ConcatV2", CreatorFunction(translate_concat_op)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ OP_CONVERTER(translate_clip_by_value_op);
OP_CONVERTER(translate_complex_op);
OP_CONVERTER(translate_complex_abs_op);
OP_CONVERTER(translate_concat_op);
OP_CONVERTER(translate_conj_op);
OP_CONVERTER(translate_conj_transpose_op);
OP_CONVERTER(translate_const_op);
OP_CONVERTER(translate_conv_2d_op);
Expand Down
46 changes: 36 additions & 10 deletions src/frontends/tensorflow_common/src/op/conj_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,59 @@
using namespace std;
using namespace ov::op;

std::shared_ptr<ov::op::v0::Concat> get_conj_ptr(const ov::Output<ov::Node>& node) {
auto real_index = make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{1}, 0);
auto imag_index = make_shared<v0::Constant>(ov::element::i32, ov::Shape{1}, 1);

auto gather_axis = make_shared<v0::Constant>(ov::element::i32, ov::Shape{1}, -1);

auto real = make_shared<v8::Gather>(node, real_index, gather_axis)->output(0);
auto imag = make_shared<v8::Gather>(node, imag_index, gather_axis)->output(0);

imag = make_shared<v0::Negative>(imag);

auto conj = make_shared<v0::Concat>(ov::OutputVector{real, imag}, -1);
return conj;
}

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_conj_transpose_op(const NodeContext& node) {
default_op_checks(node, 2, {"ConjugateTranspose"}, true);
OutputVector translate_conj_op(const NodeContext& node) {
default_op_checks(node, 1, {"Conj"}, true);

auto x = node.get_input(0);
auto perm = node.get_input(1);

auto complex_type_mark = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());

std::shared_ptr<Node> conj{x.get_node_shared_ptr()};
if (complex_type_mark) {
element::Type complex_part_type = complex_type_mark->get_complex_part_type();
auto x = complex_type_mark->input_value(0);
auto conj = get_conj_ptr(x);

auto real_index = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto imag_index = make_shared<v0::Constant>(element::i32, Shape{1}, 1);
set_node_name(node.get_name(), conj);
auto complex_conj = make_shared<ComplexTypeMark>(conj, complex_part_type);
return {complex_conj->output(0)};
}

auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, -1);
set_node_name(node.get_name(), conj);
return {conj};
}

auto real = make_shared<v8::Gather>(x, real_index, gather_axis)->output(0);
auto imag = make_shared<v8::Gather>(x, imag_index, gather_axis)->output(0);
OutputVector translate_conj_transpose_op(const NodeContext& node) {
default_op_checks(node, 2, {"ConjugateTranspose"}, true);

imag = make_shared<v0::Negative>(imag);
auto x = node.get_input(0);
auto perm = node.get_input(1);

auto conj_tensor = make_shared<v0::Concat>(OutputVector{real, imag}, -1)->output(0);
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(x.get_node_shared_ptr());
if (complex_type_mark) {
element::Type complex_part_type = complex_type_mark->get_complex_part_type();
auto x = complex_type_mark->input_value(0);
auto conj_tensor = get_conj_ptr(x);

OutputVector concat_inputs;
concat_inputs.push_back(perm);
Expand Down
61 changes: 61 additions & 0 deletions tests/layer_tests/tensorflow_tests/test_tf_Conj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

# Testing operation Conj
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Conj

class TestComplexConjugate(CommonTFLayerTest):

def _prepare_input(self, inputs_info):

rng = np.random.default_rng()
assert 'real_part' in inputs_info
real_part_shape = inputs_info['real_part']
assert 'imag_part' in inputs_info
imag_part_shape = inputs_info['imag_part']

inputs_data = {}
inputs_data['real_part'] = 4 * rng.random(real_part_shape).astype(np.float32) - 2
inputs_data['imag_part'] = 4 * rng.random(imag_part_shape).astype(np.float32) - 2

return inputs_data
def create_complex_conjugate_net(self, input_shape):
"""
TensorFlow net IR net
Placeholder->Conjugate => Placeholder->Conjugate
"""

tf.compat.v1.reset_default_graph()

# Create the graph and model
with tf.compat.v1.Session() as sess:
real_part = tf.compat.v1.placeholder(np.float32, input_shape, 'real_part')
imag_part = tf.compat.v1.placeholder(np.float32, input_shape, 'imag_part')

complex_input = tf.raw_ops.Complex(real=real_part, imag=imag_part)

conj= tf.raw_ops.Conj(input=complex_input, name = "Operation")
real = tf.raw_ops.Real(input=conj)
img = tf.raw_ops.Imag(input=conj)

tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def

ref_net = None

return tf_net, ref_net


@pytest.mark.parametrize("input_shape", [[1,2], [1,2,3], [1,2,3,4], [1,2,3,4,5,6]])
@pytest.mark.precommit_tf_fe
@pytest.mark.nightly
def test_conjugate(self, input_shape, ie_device, precision, ir_version, temp_dir,
use_new_frontend):
self._test(*self.create_complex_conjugate_net(input_shape),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend)

0 comments on commit 79b4645

Please sign in to comment.