From 27bf49435539442a5ef75a8d87a6ec36837b6ae6 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Thu, 14 Dec 2023 15:42:32 +0400 Subject: [PATCH] [PT FE]: support aten::take_along_dim (#21625) --- .../pytorch/src/op/take_along_dim.cpp | 56 +++++++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 2 + .../pytorch_tests/test_take_along_dim.py | 56 +++++++++++++++++++ 3 files changed, 114 insertions(+) create mode 100644 src/frontends/pytorch/src/op/take_along_dim.cpp create mode 100644 tests/layer_tests/pytorch_tests/test_take_along_dim.py diff --git a/src/frontends/pytorch/src/op/take_along_dim.cpp b/src/frontends/pytorch/src/op/take_along_dim.cpp new file mode 100644 index 00000000000000..b9bc20846e6a30 --- /dev/null +++ b/src/frontends/pytorch/src/op/take_along_dim.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather_elements.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_update.hpp" +#include "openvino/op/shape_of.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_take_along_dim(const NodeContext& context) { + // aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor + // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + num_inputs_check(context, 3, 4); + auto x = context.get_input(0); + auto index = context.get_input(1); + index = context.mark_node(std::make_shared(index, element::i32)); + int64_t axis = 0; + + if (context.input_is_none(2)) { + // if dimension is not provided, flattenize input first + auto minus_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1})); + x = context.mark_node(std::make_shared(x, minus_1, false)); + } else { + axis = context.const_input(2); + // OpenVINO GatherElements requires to have equal dims between index and input except specified axis + // while PyTorch requires to have them broadcastable + auto axis_node = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {axis})); + auto const_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {1})); + auto const_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0})); + auto x_shape = context.mark_node(std::make_shared(x, element::i32)); + auto broadcast_shape = + context.mark_node(std::make_shared(x_shape, axis_node, const_1, const_0)); + index = context.mark_node( + std::make_shared(index, broadcast_shape, ov::op::BroadcastType::BIDIRECTIONAL)); + } + auto gather_elements = context.mark_node(std::make_shared(x, index, axis)); + if (!context.input_is_none(3)) { + context.mutate_input(3, gather_elements); + } + return {gather_elements}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index d42839a92fde35..4ddf3a5b68a9b5 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -197,6 +197,7 @@ OP_CONVERTER(translate_sub); OP_CONVERTER(translate_sub_); OP_CONVERTER(translate_sum); OP_CONVERTER(translate_t); +OP_CONVERTER(translate_take_along_dim); OP_CONVERTER(translate_to); OP_CONVERTER(translate_topk); OP_CONVERTER(translate_transpose); @@ -536,6 +537,7 @@ const std::map get_supported_ops_ts() { {"aten::swapaxes", op::quantizable_op}, {"aten::t", op::translate_t}, {"aten::t_", op::inplace_op}, + {"aten::take_along_dim", op::translate_take_along_dim}, {"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::tan_", op::inplace_op>}, {"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, diff --git a/tests/layer_tests/pytorch_tests/test_take_along_dim.py b/tests/layer_tests/pytorch_tests/test_take_along_dim.py new file mode 100644 index 00000000000000..81843750802f13 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_take_along_dim.py @@ -0,0 +1,56 @@ +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + +class TestTakeAlongDim(PytorchLayerTest): + def _prepare_input(self, m, n, max_val, out=False, flattenize=False): + import numpy as np + index = np.random.randint(0, max_val, (m, n) if not flattenize else (m*n, )) + inp = np.random.randn(m, n).astype(np.float32) + if out: + axis = int(max_val == n) + if flattenize: + out = np.zeros_like(np.take(inp, index)) + else: + out = np.zeros_like(np.take(inp, index, axis)) + return (inp, index, out) + return (inp, index) + + def create_model(self, axis, out): + import torch + + class aten_take_along_dim(torch.nn.Module): + def __init__(self, axis, out=False): + super(aten_take_along_dim, self).__init__() + self.axis = axis + if self.axis is None: + self.forward = self.forward_no_dim + if out: + self.forward = self.forward_out if self.axis is not None else self.forward_no_dim_out + + def forward(self, x, index): + return torch.take_along_dim(x, index, dim=self.axis) + + def forward_out(self, x, index, out): + return torch.take_along_dim(x, index, dim=self.axis, out=out), out + + def forward_no_dim(self, x, index): + return torch.take_along_dim(x, index) + + def forward_no_dim_out(self, x, index, out): + return torch.take_along_dim(x, index, out=out) + + ref_net = None + + return aten_take_along_dim(axis, out), ref_net, "aten::take_along_dim" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("m", [2, 10, 100]) + @pytest.mark.parametrize("n", [2, 10, 100]) + @pytest.mark.parametrize("axis", [0, 1, None]) + @pytest.mark.parametrize("out", [True, False]) + def test_gather(self, m, n, axis, out, ie_device, precision, ir_version): + self._test(*self.create_model(axis, out), ie_device, precision, ir_version, kwargs_to_prepare_input={ + "m": m, "n": n, "max_val": m if axis == 0 else n, "out": out, "flattenize": axis is None + }) \ No newline at end of file