[PT FE]: support aten::take_along_dim (openvinotoolkit#21625)
eaidova authored Dec 14, 2023
1 parent 9f6c3e9 commit 27bf494
Showing 3 changed files with 114 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/frontends/pytorch/src/op/take_along_dim.cpp
@@ -0,0 +1,56 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather_elements.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_take_along_dim(const NodeContext& context) {
    // aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor
    // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 3, 4);
    auto x = context.get_input(0);
    auto index = context.get_input(1);
    index = context.mark_node(std::make_shared<ov::op::v0::Convert>(index, element::i32));
    int64_t axis = 0;

    if (context.input_is_none(2)) {
        // if dim is not provided, flatten the input first
        auto minus_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1}));
        x = context.mark_node(std::make_shared<ov::op::v1::Reshape>(x, minus_1, false));
    } else {
        axis = context.const_input<int64_t>(2);
        // OpenVINO GatherElements requires index and input to have equal dimensions on all axes except the
        // specified one, while PyTorch only requires them to be broadcastable
        auto axis_node = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {axis}));
        auto const_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {1}));
        auto const_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
        auto x_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(x, element::i32));
        auto broadcast_shape =
            context.mark_node(std::make_shared<ov::op::v3::ScatterUpdate>(x_shape, axis_node, const_1, const_0));
        index = context.mark_node(
            std::make_shared<ov::op::v3::Broadcast>(index, broadcast_shape, ov::op::BroadcastType::BIDIRECTIONAL));
    }
    auto gather_elements = context.mark_node(std::make_shared<ov::op::v6::GatherElements>(x, index, axis));
    if (!context.input_is_none(3)) {
        context.mutate_input(3, gather_elements);
    }
    return {gather_elements};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
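
For context (not part of the commit), a minimal PyTorch sketch of the semantics this converter reproduces: torch.take_along_dim broadcasts indices against input on every axis except dim, which is what the ScatterUpdate + bidirectional Broadcast above prepares before GatherElements; with dim=None, both tensors are flattened first (the Reshape(-1) branch). Shapes and values here are illustrative:

import torch

x = torch.randn(3, 4)
idx = torch.tensor([[0], [2], [1]])        # shape (3, 1), int64

# dim given: indices broadcast against x on every axis except dim,
# then elements are picked along dim (GatherElements semantics)
out = torch.take_along_dim(x, idx, dim=1)  # shape (3, 1)

# mirror of the conversion: broadcast idx to x's shape, keeping idx's
# own size along dim, then gather along dim
target = list(x.shape)
target[1] = idx.shape[1]
assert torch.equal(out, torch.gather(x, 1, idx.broadcast_to(target)))

# dim omitted: input and indices are treated as flattened 1-D tensors
flat = torch.take_along_dim(x, torch.tensor([0, 5, 11]))  # shape (3,)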
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -197,6 +197,7 @@ OP_CONVERTER(translate_sub);
OP_CONVERTER(translate_sub_);
OP_CONVERTER(translate_sum);
OP_CONVERTER(translate_t);
OP_CONVERTER(translate_take_along_dim);
OP_CONVERTER(translate_to);
OP_CONVERTER(translate_topk);
OP_CONVERTER(translate_transpose);
@@ -536,6 +537,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::swapaxes", op::quantizable_op<op::translate_transpose>},
{"aten::t", op::translate_t},
{"aten::t_", op::inplace_op<op::translate_t>},
{"aten::take_along_dim", op::translate_take_along_dim},
{"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>},
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tan>>},
{"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tanh>},
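Once registered in the table above, the converter is dispatched whenever a traced or scripted model contains aten::take_along_dim. A minimal conversion sketch using the public openvino.convert_model API (model and shapes are illustrative):

import torch
import openvino as ov

class Model(torch.nn.Module):
    def forward(self, x, idx):
        return torch.take_along_dim(x, idx, dim=1)

example = (torch.randn(3, 4), torch.zeros(3, 1, dtype=torch.int64))
ov_model = ov.convert_model(Model(), example_input=example)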
56 changes: 56 additions & 0 deletions tests/layer_tests/pytorch_tests/test_take_along_dim.py
@@ -0,0 +1,56 @@
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestTakeAlongDim(PytorchLayerTest):
    def _prepare_input(self, m, n, max_val, out=False, flattenize=False):
        import numpy as np
        index = np.random.randint(0, max_val, (m, n) if not flattenize else (m * n,))
        inp = np.random.randn(m, n).astype(np.float32)
        if out:
            axis = int(max_val == n)
            if flattenize:
                out = np.zeros_like(np.take(inp, index))
            else:
                out = np.zeros_like(np.take(inp, index, axis))
            return (inp, index, out)
        return (inp, index)

    def create_model(self, axis, out):
        import torch

        class aten_take_along_dim(torch.nn.Module):
            def __init__(self, axis, out=False):
                super(aten_take_along_dim, self).__init__()
                self.axis = axis
                if self.axis is None:
                    self.forward = self.forward_no_dim
                if out:
                    self.forward = self.forward_out if self.axis is not None else self.forward_no_dim_out

            def forward(self, x, index):
                return torch.take_along_dim(x, index, dim=self.axis)

            def forward_out(self, x, index, out):
                return torch.take_along_dim(x, index, dim=self.axis, out=out), out

            def forward_no_dim(self, x, index):
                return torch.take_along_dim(x, index)

            def forward_no_dim_out(self, x, index, out):
                return torch.take_along_dim(x, index, out=out)

        ref_net = None

        return aten_take_along_dim(axis, out), ref_net, "aten::take_along_dim"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("m", [2, 10, 100])
    @pytest.mark.parametrize("n", [2, 10, 100])
    @pytest.mark.parametrize("axis", [0, 1, None])
    @pytest.mark.parametrize("out", [True, False])
    def test_take_along_dim(self, m, n, axis, out, ie_device, precision, ir_version):
        self._test(*self.create_model(axis, out), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"m": m, "n": n, "max_val": m if axis == 0 else n,
                                            "out": out, "flattenize": axis is None})

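The new layer test can be run on its own with, e.g., pytest tests/layer_tests/pytorch_tests/test_take_along_dim.py (assuming the layer-test requirements under tests/layer_tests are installed).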