This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

[SPEC] Fix output_shape input in (Group)ConvolutionBackpropData ops (#…
tsocha authored and postrational committed Dec 6, 2019
1 parent 206bc65 commit eb2ce8e
Showing 6 changed files with 69 additions and 35 deletions.
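In short: the optional third input, output_shape, of the v1 ConvolutionBackpropData and GroupConvolutionBackpropData ops now holds only the spatial dimensions of the output; the batch dimension N is inferred from the data input and the channel dimension C from the filters input. A rough before/after illustration with hypothetical values (the Constant contents below are illustrative, not taken from this commit):

    // Before this change: output_shape carried the full output shape {N, C, spatial...}
    auto full_output_shape = op::Constant::create<int64_t>(element::i64, Shape{4}, {1, 16, 28, 28});
    // After this change: output_shape carries only the spatial dimensions
    auto spatial_output_shape = op::Constant::create<int64_t>(element::i64, Shape{2}, {28, 28});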
33 changes: 21 additions & 12 deletions src/ngraph/op/convolution.cpp
@@ -205,13 +205,28 @@ op::v1::ConvolutionBackpropData::ConvolutionBackpropData(const Output<Node>& dat
 
 const PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const
 {
-    PartialShape shape{PartialShape::dynamic()};
+    PartialShape shape{vector<Dimension>(m_strides.size() + 2)};
+    auto data_pshape = get_input_partial_shape(0);
+    if (data_pshape.rank().is_static())
+    {
+        shape[0] = data_pshape[0]; // N
+    }
+    auto filters_pshape = get_input_partial_shape(1);
+    if (filters_pshape.rank().is_static())
+    {
+        shape[1] = filters_pshape[1]; // C
+    }
     bool is_output_shape_present = get_inputs().size() == 3;
     if (is_output_shape_present)
     {
         if (auto const_op = as_type<op::Constant>(input_value(2).get_node()))
         {
-            shape = const_op->get_shape_val();
+            auto output_shape = const_op->get_shape_val();
+            // Populate spatials
+            for (int i = 0; i < output_shape.size(); ++i)
+            {
+                shape[i + 2] = output_shape[i];
+            }
         }
     }
     return shape;
@@ -270,13 +285,6 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types()
     if (is_output_shape_present)
     {
         set_input_is_relevant_to_shape(2);
-        if (output_pshape.is_static() && data_pshape.is_static())
-        {
-            auto data_shape = data_pshape.to_shape();
-            auto output_shape = output_pshape.to_shape();
-            output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1);
-            output_pshape = output_shape;
-        }
     }
     else
     {
@@ -295,12 +303,13 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types()
             for (size_t i = 0; i < data_spatial_rank; ++i)
             {
                 size_t tmp = m_strides[i] * (data_shape[i + 2] - 1) +
-                             ((filters_shape[i] + 2 - 1) * m_dilations[i] + 1) - m_pads_begin[i] -
+                             ((filters_shape[i + 2] - 1) * m_dilations[i] + 1) - m_pads_begin[i] -
                              m_pads_end[i] + output_padding[i];
                 output_shape.push_back(tmp);
-                output_pshape = output_shape;
             }
-            output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1);
+            output_shape.insert(output_shape.begin(), filters_shape.at(1));
+            output_shape.insert(output_shape.begin(), data_shape.at(0));
+            output_pshape = output_shape;
         }
     }
 
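Note that the hunk above also fixes the filter-dimension indexing in the spatial formula (filters_shape[i + 2] rather than filters_shape[i] + 2). For reference, the spatial size computed in that loop follows the usual transposed-convolution relation; a minimal standalone sketch with hypothetical values, not part of the commit:

    #include <cstddef>

    // Output spatial extent for one axis of ConvolutionBackpropData, mirroring the
    // expression used in validate_and_infer_types() above.
    size_t backprop_spatial_dim(size_t data_dim, size_t filter_dim, size_t stride,
                                size_t dilation, size_t pad_begin, size_t pad_end,
                                size_t output_padding)
    {
        return stride * (data_dim - 1) + ((filter_dim - 1) * dilation + 1) - pad_begin -
               pad_end + output_padding;
    }

    // Example: backprop_spatial_dim(96, 10, 1, 1, 0, 0, 0)
    //          == 1 * (96 - 1) + ((10 - 1) * 1 + 1) - 0 - 0 + 0 == 105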
41 changes: 24 additions & 17 deletions src/ngraph/op/fused/group_conv.cpp
@@ -192,13 +192,28 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
 
 const PartialShape op::v1::GroupConvolutionBackpropData::get_output_shape() const
 {
-    PartialShape shape{PartialShape::dynamic()};
+    PartialShape shape{vector<Dimension>(m_strides.size() + 2)};
+    auto data_pshape = get_input_partial_shape(0);
+    if (data_pshape.rank().is_static())
+    {
+        shape[0] = data_pshape[0]; // N
+    }
+    auto filters_pshape = get_input_partial_shape(1);
+    if (filters_pshape.rank().is_static())
+    {
+        shape[1] = filters_pshape[1]; // C
+    }
     bool is_output_shape_present = get_inputs().size() == 3;
     if (is_output_shape_present)
     {
         if (auto const_op = as_type<op::Constant>(input_value(2).get_node()))
         {
-            shape = const_op->get_shape_val();
+            auto output_shape = const_op->get_shape_val();
+            // Populate spatials
+            for (int i = 0; i < output_shape.size(); ++i)
+            {
+                shape[i + 2] = output_shape[i];
+            }
         }
     }
     return shape;
@@ -257,40 +272,32 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types()
     if (is_output_shape_present)
     {
         set_input_is_relevant_to_shape(2);
-        if (output_pshape.is_static() && data_pshape.is_static())
-        {
-            auto data_shape = data_pshape.to_shape();
-            auto output_shape = output_pshape.to_shape();
-            output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1);
-            output_pshape = output_shape;
-        }
     }
     else
     {
         if (filters_pshape.is_static() && data_pshape.is_static())
         {
             auto filters_shape = filters_pshape.to_shape();
-            filters_shape.erase(filters_shape.begin(),
-                                filters_shape.begin() + 3); // remove {G, O, I}
             auto data_shape = data_pshape.to_shape();
-            data_shape.erase(data_shape.begin(), data_shape.begin() + 2); // remove {N, C}
 
             Shape output_shape;
-            auto data_spatial_rank = data_shape.size();
+            auto data_spatial_rank = data_shape.size() - 2;
             auto output_padding = get_output_padding();
             if (output_padding.size() == 0)
             {
                 output_padding.insert(output_padding.begin(), data_spatial_rank, 0);
             }
             for (size_t i = 0; i < data_spatial_rank; ++i)
             {
-                size_t tmp = m_strides[i] * (data_shape[i] - 1) +
-                             ((filters_shape[i] - 1) * m_dilations[i] + 1) - m_pads_begin[i] -
+                size_t tmp = m_strides[i] * (data_shape[i + 2] - 1) +
+                             ((filters_shape[i + 3] - 1) * m_dilations[i] + 1) - m_pads_begin[i] -
                              m_pads_end[i] + output_padding[i];
                 output_shape.push_back(tmp);
-                output_pshape = output_shape;
             }
-            output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1);
+            output_shape.insert(output_shape.begin(),
+                                filters_shape.at(0) * filters_shape.at(2)); // GROUP * C_OUTPUT
+            output_shape.insert(output_shape.begin(), data_shape.at(0));
+            output_pshape = output_shape;
         }
     }
 
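The grouped variant assembles the non-spatial dimensions the same way, except that the channel count comes from the filters shape (GROUP * C_OUTPUT, per the inline comment above). A rough sketch of just that assembly step, with hypothetical shapes and an assumed {GROUPS, C_INPUT, C_OUTPUT, spatial...} filters layout (the layout is an assumption, not stated in this diff):

    #include <cstddef>
    #include <vector>

    int main()
    {
        // Hypothetical shapes (assumed layouts, see note above):
        std::vector<size_t> filters_shape{2, 4, 8, 3, 3}; // {GROUPS=2, C_INPUT=4, C_OUTPUT=8, 3, 3}
        std::vector<size_t> data_shape{1, 8, 14, 14};     // {N=1, C=8, 14, 14}
        std::vector<size_t> output_shape{28, 28};         // spatial dims from the loop above

        // Mirror the two inserts performed in validate_and_infer_types():
        output_shape.insert(output_shape.begin(),
                            filters_shape.at(0) * filters_shape.at(2)); // GROUP * C_OUTPUT = 16
        output_shape.insert(output_shape.begin(), data_shape.at(0));    // N = 1
        // output_shape is now {1, 16, 28, 28}
    }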
21 changes: 18 additions & 3 deletions src/ngraph/pass/opset0_downgrade.cpp
@@ -135,11 +135,12 @@ namespace
 
     bool op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
    {
-        auto output_shape = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
+        auto output_shape_node =
+            as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
         const auto data_arg = node->input(0).get_source_output();
         const auto filters_arg = node->input(1).get_source_output();
         const auto strides = node->get_strides();
-        NGRAPH_CHECK(output_shape,
+        NGRAPH_CHECK(output_shape_node,
                      "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
                      "if output_shape is not constant. Node: ",
                      *node);
@@ -155,8 +156,22 @@
                      "with output padding other than `0`. Node: ",
                      *node);
 
+        auto data_pshape = data_arg.get_partial_shape();
+        auto filters_pshape = filters_arg.get_partial_shape();
+
+        NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() &&
+                         filters_pshape.rank().is_static() && filters_pshape[1].is_static(),
+                     "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 "
+                     "if data shape N and filters shape C dimensions are not static. Node: ",
+                     *node);
+
+        // Add N and C dimensions to output_shape
+        auto output_shape = output_shape_node->get_shape_val();
+        output_shape.insert(output_shape.begin(), static_cast<size_t>(filters_pshape[1]));
+        output_shape.insert(output_shape.begin(), static_cast<size_t>(data_pshape[0]));
+
         auto replacement_node =
-            make_shared<op::v0::ConvolutionBackpropData>(output_shape->get_shape_val(),
+            make_shared<op::v0::ConvolutionBackpropData>(output_shape,
                                                          filters_arg,
                                                          data_arg,
                                                          node->get_strides(),
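The downgrade pass has to rebuild the full shape because v0::ConvolutionBackpropData still takes the complete data batch shape {N, C, spatial...}: N is read from the data input's partial shape and C from the filters input's partial shape, then both are prepended to the spatial-only constant. A small sketch of just that prepending step, with hypothetical values (not from the commit):

    #include <cstddef>
    #include <vector>

    int main()
    {
        // The v1 output_shape constant holds only the spatial dims;
        // N and C would come from data_pshape[0] and filters_pshape[1].
        std::vector<size_t> output_shape{28, 28};
        size_t n = 1;
        size_t c = 16;
        output_shape.insert(output_shape.begin(), c);
        output_shape.insert(output_shape.begin(), n);
        // output_shape is now {1, 16, 28, 28}, the full shape expected by the v0 op.
    }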
5 changes: 4 additions & 1 deletion src/ngraph/pass/opset1_upgrade.cpp
@@ -179,7 +179,10 @@
         auto replacement_node = make_shared<op::v1::ConvolutionBackpropData>(
             node->input_value(1), // data
             node->input_value(0), // filters
-            op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape),
+            op::Constant::create(
+                element::i64,
+                Shape{data_batch_shape.size() - 2},
+                vector<size_t>(data_batch_shape.begin() + 2, data_batch_shape.end())),
             strides,
             pads_begin,
             pads_end,
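Going the other way, the v0 op's data_batch_shape already contains {N, C, spatial...}, so the upgrade pass keeps only its spatial tail when building the v1 output_shape constant. For the 1-D case exercised by the updated opset_pass test below, {64, 3, 100} becomes {100}. A minimal standalone sketch of that slicing (not from the commit):

    #include <cstdint>
    #include <vector>

    int main()
    {
        std::vector<int64_t> data_batch_shape{64, 3, 100};          // {N, C, spatial...}
        std::vector<int64_t> spatial(data_batch_shape.begin() + 2,  // drop N and C
                                     data_batch_shape.end());
        // spatial == {100}
    }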
2 changes: 1 addition & 1 deletion test/backend/convolution.in.cpp
@@ -175,7 +175,7 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_convolution_backprop_data)
     for (int i = 0; i < 2 * 3 * 5 * 5; i++)
         expected_result.emplace_back(i);
 
-    vector<int64_t> shapes = {2, 3, 5, 5};
+    vector<int64_t> shapes = {5, 5};
 
     // Create some tensors for input/output
     auto a = backend->create_tensor(element::f32, shape_delta);
2 changes: 1 addition & 1 deletion test/opset_pass/convolution_opset_pass.cpp
@@ -78,7 +78,7 @@ TEST(opset_transform, opset1_convolution_downgrade_pass)
 
 TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass)
 {
-    auto data_batch_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {64, 3, 100});
+    auto data_batch_shape = op::Constant::create<int64_t>(element::i64, Shape{1}, {100});
     auto filters = make_shared<op::Parameter>(element::f32, Shape{128, 3, 10});
     auto delta = make_shared<op::Parameter>(element::f32, Shape{64, 128, 96});
     auto strides = Strides{1};
