Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
[ONNX] Add support for negative axes (#3643)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ewa Tusień authored and diyessi committed Sep 20, 2019
1 parent 1a5288a commit 0c181e9
Show file tree
Hide file tree
Showing 16 changed files with 163 additions and 99 deletions.
6 changes: 1 addition & 5 deletions src/ngraph/frontend/onnx_import/op/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ namespace ngraph
{
NodeVector inputs{node.get_ng_inputs()};
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");

size_t valid_axis =
common::convert_negative_axis(axis, inputs.at(0)->get_shape().size());

ASSERT_VALID_ARGUMENT(node, valid_axis >= 0)
<< "Incorrect value of axis attribute: " << axis;
common::validate_axis(node, axis, inputs.at(0)->get_shape().size());

return {std::make_shared<ngraph::op::Concat>(inputs, valid_axis)};
}
Expand Down
11 changes: 6 additions & 5 deletions src/ngraph/frontend/onnx_import/op/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "exceptions.hpp"
#include "flatten.hpp"
#include "ngraph/builder/reshape.hpp"

#include "utils/common.hpp"
namespace ngraph
{
namespace onnx_import
Expand All @@ -33,11 +33,12 @@ namespace ngraph
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto data_rank = data->get_shape().size();
// Accepted range is [-r, r] where r = rank(input).
auto valid_axis =
common::validate_axis(node, axis, data_rank, -data_rank, data_rank);

ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid.";

return {ngraph::builder::flatten(data, axis)};
return {ngraph::builder::flatten(data, valid_axis)};
}

} // namespace set_1
Expand Down
8 changes: 3 additions & 5 deletions src/ngraph/frontend/onnx_import/op/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/gather.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -34,12 +35,9 @@ namespace ngraph
auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0);
if (axis < 0)
{
axis += data->get_shape().size();
}
auto valid_axis = common::validate_axis(node, axis, data->get_shape().size());

return {std::make_shared<ngraph::op::Gather>(data, indices, axis)};
return {std::make_shared<ngraph::op::Gather>(data, indices, valid_axis)};
}

} // namespace set_1
Expand Down
6 changes: 2 additions & 4 deletions src/ngraph/frontend/onnx_import/op/hardmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ namespace ngraph
const auto& input_shape = input->get_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);

ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis < input_shape.size())
<< "The provided axis value " << axis
<< " does not match the input tensor dimensions";
auto valid_axis = common::validate_axis(node, axis, input_shape.size());

// reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::builder::flatten(input, axis);
const auto coerced_tensor = ngraph::builder::flatten(input, valid_axis);
const auto& coerced_shape = coerced_tensor->get_shape();

const std::shared_ptr<ngraph::Node> argmax_2d =
Expand Down
10 changes: 4 additions & 6 deletions src/ngraph/frontend/onnx_import/op/lp_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "ngraph/builder/norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/divide.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -37,17 +38,14 @@ namespace ngraph
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};

if (axis < 0)
{
axis += data->get_shape().size();
}
std::size_t valid_axis =
common::validate_axis(node, axis, data->get_shape().size());

ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported.";

const AxisSet reduction_axes{static_cast<std::size_t>(axis)};
const AxisSet reduction_axes{valid_axis};
std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm(
data, reduction_axes, static_cast<std::size_t>(p_norm));
norm = std::make_shared<ngraph::op::Broadcast>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mean_variance_normalization.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand Down Expand Up @@ -47,9 +48,11 @@ namespace ngraph
NodeVector mean_variance_normalization(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<size_t>>("axes", {0, 2, 3});
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3});
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, data->get_shape().size());

return {std::make_shared<ngraph::op::MVN>(data, AxisSet(axes))};
return {std::make_shared<ngraph::op::MVN>(data, AxisSet(valid_axes))};
}

} // namespace set_9
Expand Down
21 changes: 11 additions & 10 deletions src/ngraph/frontend/onnx_import/op/onehot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "onehot.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -51,14 +52,13 @@ namespace ngraph
std::make_shared<ngraph::op::Slice>(values, Coordinate{1}, Coordinate{2});
auto axis = node.get_attribute_value<std::int64_t>("axis", -1);

if (axis < 0)
{
axis += indices_shape.size() + 1;
}

ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= indices_shape.size()))
<< "invalid 'axis' attribute: "
<< node.get_attribute_value<std::int64_t>("axis", -1);
// Accepted range for axis is [-r-1, r] where r = rank(indices). Validate
// against rank+1.
std::size_t valid_axis = common::validate_axis(node,
axis,
indices_shape.size() + 1,
-indices_shape.size() - 1,
indices_shape.size());

auto constant_depth = std::dynamic_pointer_cast<ngraph::op::Constant>(depth);

Expand All @@ -74,10 +74,11 @@ namespace ngraph
// axis = 1
// depth = 10
// output_shape = (2, 10, 2)
output_shape.insert(std::next(std::begin(output_shape), axis), depth_value);
output_shape.insert(std::next(std::begin(output_shape), valid_axis),
depth_value);

std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
std::make_shared<ngraph::op::OneHot>(indices, output_shape, valid_axis),
values->get_element_type());
auto broadcasted_values =
ngraph::op::numpy_style_broadcast({one_hot, on_value, off_value});
Expand Down
13 changes: 9 additions & 4 deletions src/ngraph/frontend/onnx_import/op/reverse_sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "ngraph/op/convert.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -40,22 +41,26 @@ namespace ngraph
node.get_ng_inputs().at(1), element::i32);

const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
std::size_t valid_batch_axis =
common::validate_axis(node, batch_axis, data->get_shape().size());
const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
std::size_t valid_time_axis =
common::validate_axis(node, time_axis, data->get_shape().size());

NGRAPH_CHECK(batch_axis == 0 || batch_axis == 1,
NGRAPH_CHECK(valid_batch_axis == 0 || valid_batch_axis == 1,
"Allowed values of the 'batch_axis' attribute for ReverseSequence "
"operator are 0 and 1");

NGRAPH_CHECK(time_axis == 0 || time_axis == 1,
NGRAPH_CHECK(valid_time_axis == 0 || valid_time_axis == 1,
"Allowed values of the 'time_axis' attribute for ReverseSequence "
"operator are 0 and 1");

NGRAPH_CHECK(batch_axis != time_axis,
NGRAPH_CHECK(valid_batch_axis != valid_time_axis,
"'batch_axis' and 'time_axis' attributes of the ReverseSequence "
"operator can't point to the same dimension");

return {std::make_shared<ngraph::op::ReverseSequence>(
data, sequence_lengths_i32, batch_axis, time_axis)};
data, sequence_lengths_i32, valid_batch_axis, valid_time_axis)};
}

} // namespace set_1
Expand Down
16 changes: 4 additions & 12 deletions src/ngraph/frontend/onnx_import/op/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "exceptions.hpp"
#include "ngraph/op/softmax.hpp"
#include "softmax.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -35,22 +36,13 @@ namespace ngraph
auto data_shape = data->get_shape();

int axis = node.get_attribute_value<int64_t>("axis", 1);

if (axis < 0)
{
axis = data_shape.size() + axis;
}

ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
<< "provided 'axis' value:" << axis
<< " is out of input tensor dimensions range.";
std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size());

// create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
std::vector<size_t> axes(data_shape.size() - valid_axis);
std::iota(std::begin(axes), std::end(axes), valid_axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}

} // namespace set_1

} // namespace op
Expand Down
7 changes: 5 additions & 2 deletions src/ngraph/frontend/onnx_import/op/split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "ngraph/op/fused/split.hpp"
#include "op/split.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -33,13 +34,15 @@ namespace ngraph
const auto input = node.get_ng_inputs().at(0);
const auto outputs_number = node.get_output_names().size();
const auto axis = node.get_attribute_value<int64_t>("axis", 0);
std::size_t valid_axis =
common::validate_axis(node, axis, input->get_shape().size());

try
{
const auto length_parts =
node.get_attribute_value<std::vector<std::size_t>>("split");
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, length_parts);
std::make_shared<ngraph::op::Split>(input, valid_axis, length_parts);

return fused_split->decompose_op();
}
Expand All @@ -49,7 +52,7 @@ namespace ngraph
// the 'split' attribute - this means we should split the input tensor
// into same-length parts equal to the number of node outputs
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, outputs_number);
std::make_shared<ngraph::op::Split>(input, valid_axis, outputs_number);

return fused_split->decompose_op();
}
Expand Down
16 changes: 6 additions & 10 deletions src/ngraph/frontend/onnx_import/op/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "squeeze.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -32,17 +33,12 @@ namespace ngraph
NodeVector squeeze(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});

for (auto axis : axes)
{
ASSERT_VALID_ARGUMENT(node, axis >= 0)
<< "provided axes attribute is invalid. Only non-negative "
<< "integers are allowed, got " << axis << ".";
}

std::vector<std::int64_t> axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, data->get_shape().size());
auto axes_node = std::make_shared<ngraph::op::Constant>(
element::u64, Shape{axes.size()}, axes);
element::u64, Shape{valid_axes.size()}, valid_axes);
return {std::make_shared<ngraph::op::Squeeze>(data, axes_node)};
}

Expand Down
14 changes: 4 additions & 10 deletions src/ngraph/frontend/onnx_import/op/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "ngraph/op/topk.hpp"
#include "ngraph/type/element_type.hpp"
#include "topk.hpp"
#include "utils/common.hpp"

namespace ngraph
{
Expand All @@ -35,21 +36,14 @@ namespace ngraph
NodeVector topk(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t k{node.get_attribute_value<std::int64_t>("k")};

auto num_dimensions = data->get_shape().size();

if (axis < 0)
{
axis += num_dimensions;
}

ASSERT_VALID_ARGUMENT(node, axis < num_dimensions)
<< "`axis` parameter is out of range: " << axis;
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t valid_axis = common::validate_axis(node, axis, num_dimensions);

std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k);
std::make_shared<ngraph::op::TopK>(data, valid_axis, element::i64, k);

std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
Expand Down
47 changes: 47 additions & 0 deletions src/ngraph/frontend/onnx_import/utils/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,53 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type)));
}

std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank)
{
// Accepted range of value for axis is [-tensor_rank, tensor_rank-1].
return validate_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
}

std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
// Accepted range of value for axis is [axis_range_min, axis_range_max].
NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
node.get_description(),
"Parameter axis ",
axis,
" out of the tensor rank [-",
axis_range_min,
", ",
axis_range_max,
"].");

if (axis < 0)
{
axis = axis + tensor_rank;
}

return static_cast<size_t>(axis);
}

std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank)
{
std::vector<std::size_t> new_axes;

for (auto a : axes)
{
new_axes.push_back(validate_axis(node, a, tensor_rank));
}

return new_axes;
}

} // namespace common
} // namespace onnx_import
} // namespace ngraph
Loading

0 comments on commit 0c181e9

Please sign in to comment.