forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PT FE]: support aten::take_along_dim (openvinotoolkit#21625)
- Loading branch information
Showing
3 changed files
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "openvino/op/broadcast.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/op/gather_elements.hpp" | ||
#include "openvino/op/reshape.hpp" | ||
#include "openvino/op/scatter_update.hpp" | ||
#include "openvino/op/shape_of.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
OutputVector translate_take_along_dim(const NodeContext& context) { | ||
// aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor | ||
// aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) | ||
num_inputs_check(context, 3, 4); | ||
auto x = context.get_input(0); | ||
auto index = context.get_input(1); | ||
index = context.mark_node(std::make_shared<ov::op::v0::Convert>(index, element::i32)); | ||
int64_t axis = 0; | ||
|
||
if (context.input_is_none(2)) { | ||
// if dimension is not provided, flattenize input first | ||
auto minus_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1})); | ||
x = context.mark_node(std::make_shared<ov::op::v1::Reshape>(x, minus_1, false)); | ||
} else { | ||
axis = context.const_input<int64_t>(2); | ||
// OpenVINO GatherElements requires to have equal dims between index and input except specified axis | ||
// while PyTorch requires to have them broadcastable | ||
auto axis_node = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {axis})); | ||
auto const_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {1})); | ||
auto const_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0})); | ||
auto x_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(x, element::i32)); | ||
auto broadcast_shape = | ||
context.mark_node(std::make_shared<ov::op::v3::ScatterUpdate>(x_shape, axis_node, const_1, const_0)); | ||
index = context.mark_node( | ||
std::make_shared<ov::op::v3::Broadcast>(index, broadcast_shape, ov::op::BroadcastType::BIDIRECTIONAL)); | ||
} | ||
auto gather_elements = context.mark_node(std::make_shared<ov::op::v6::GatherElements>(x, index, axis)); | ||
if (!context.input_is_none(3)) { | ||
context.mutate_input(3, gather_elements); | ||
} | ||
return {gather_elements}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import pytest | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
class TestTakeAlongDim(PytorchLayerTest): | ||
def _prepare_input(self, m, n, max_val, out=False, flattenize=False): | ||
import numpy as np | ||
index = np.random.randint(0, max_val, (m, n) if not flattenize else (m*n, )) | ||
inp = np.random.randn(m, n).astype(np.float32) | ||
if out: | ||
axis = int(max_val == n) | ||
if flattenize: | ||
out = np.zeros_like(np.take(inp, index)) | ||
else: | ||
out = np.zeros_like(np.take(inp, index, axis)) | ||
return (inp, index, out) | ||
return (inp, index) | ||
|
||
def create_model(self, axis, out): | ||
import torch | ||
|
||
class aten_take_along_dim(torch.nn.Module): | ||
def __init__(self, axis, out=False): | ||
super(aten_take_along_dim, self).__init__() | ||
self.axis = axis | ||
if self.axis is None: | ||
self.forward = self.forward_no_dim | ||
if out: | ||
self.forward = self.forward_out if self.axis is not None else self.forward_no_dim_out | ||
|
||
def forward(self, x, index): | ||
return torch.take_along_dim(x, index, dim=self.axis) | ||
|
||
def forward_out(self, x, index, out): | ||
return torch.take_along_dim(x, index, dim=self.axis, out=out), out | ||
|
||
def forward_no_dim(self, x, index): | ||
return torch.take_along_dim(x, index) | ||
|
||
def forward_no_dim_out(self, x, index, out): | ||
return torch.take_along_dim(x, index, out=out) | ||
|
||
ref_net = None | ||
|
||
return aten_take_along_dim(axis, out), ref_net, "aten::take_along_dim" | ||
|
||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
@pytest.mark.parametrize("m", [2, 10, 100]) | ||
@pytest.mark.parametrize("n", [2, 10, 100]) | ||
@pytest.mark.parametrize("axis", [0, 1, None]) | ||
@pytest.mark.parametrize("out", [True, False]) | ||
def test_gather(self, m, n, axis, out, ie_device, precision, ir_version): | ||
self._test(*self.create_model(axis, out), ie_device, precision, ir_version, kwargs_to_prepare_input={ | ||
"m": m, "n": n, "max_val": m if axis == 0 else n, "out": out, "flattenize": axis is None | ||
}) |