diff --git a/src/ngraph/op/convolution.cpp b/src/ngraph/op/convolution.cpp index d7121d4b121..c36f901c405 100644 --- a/src/ngraph/op/convolution.cpp +++ b/src/ngraph/op/convolution.cpp @@ -205,13 +205,28 @@ op::v1::ConvolutionBackpropData::ConvolutionBackpropData(const Output& dat const PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const { - PartialShape shape{PartialShape::dynamic()}; + PartialShape shape{vector(m_strides.size() + 2)}; + auto data_pshape = get_input_partial_shape(0); + if (data_pshape.rank().is_static()) + { + shape[0] = data_pshape[0]; // N + } + auto filters_pshape = get_input_partial_shape(1); + if (filters_pshape.rank().is_static()) + { + shape[1] = filters_pshape[1]; // C + } bool is_output_shape_present = get_inputs().size() == 3; if (is_output_shape_present) { if (auto const_op = as_type(input_value(2).get_node())) { - shape = const_op->get_shape_val(); + auto output_shape = const_op->get_shape_val(); + // Populate spatials + for (int i = 0; i < output_shape.size(); ++i) + { + shape[i + 2] = output_shape[i]; + } } } return shape; @@ -270,13 +285,6 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types() if (is_output_shape_present) { set_input_is_relevant_to_shape(2); - if (output_pshape.is_static() && data_pshape.is_static()) - { - auto data_shape = data_pshape.to_shape(); - auto output_shape = output_pshape.to_shape(); - output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1); - output_pshape = output_shape; - } } else { @@ -295,12 +303,13 @@ void op::v1::ConvolutionBackpropData::validate_and_infer_types() for (size_t i = 0; i < data_spatial_rank; ++i) { size_t tmp = m_strides[i] * (data_shape[i + 2] - 1) + - ((filters_shape[i] + 2 - 1) * m_dilations[i] + 1) - m_pads_begin[i] - + ((filters_shape[i + 2] - 1) * m_dilations[i] + 1) - m_pads_begin[i] - m_pads_end[i] + output_padding[i]; output_shape.push_back(tmp); - output_pshape = output_shape; } - output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1); + output_shape.insert(output_shape.begin(), filters_shape.at(1)); + output_shape.insert(output_shape.begin(), data_shape.at(0)); + output_pshape = output_shape; } } diff --git a/src/ngraph/op/fused/group_conv.cpp b/src/ngraph/op/fused/group_conv.cpp index 39246022649..38717017382 100644 --- a/src/ngraph/op/fused/group_conv.cpp +++ b/src/ngraph/op/fused/group_conv.cpp @@ -192,13 +192,28 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData( const PartialShape op::v1::GroupConvolutionBackpropData::get_output_shape() const { - PartialShape shape{PartialShape::dynamic()}; + PartialShape shape{vector(m_strides.size() + 2)}; + auto data_pshape = get_input_partial_shape(0); + if (data_pshape.rank().is_static()) + { + shape[0] = data_pshape[0]; // N + } + auto filters_pshape = get_input_partial_shape(1); + if (filters_pshape.rank().is_static()) + { + shape[1] = filters_pshape[1]; // C + } bool is_output_shape_present = get_inputs().size() == 3; if (is_output_shape_present) { if (auto const_op = as_type(input_value(2).get_node())) { - shape = const_op->get_shape_val(); + auto output_shape = const_op->get_shape_val(); + // Populate spatials + for (int i = 0; i < output_shape.size(); ++i) + { + shape[i + 2] = output_shape[i]; + } } } return shape; @@ -257,26 +272,16 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types() if (is_output_shape_present) { set_input_is_relevant_to_shape(2); - if (output_pshape.is_static() && data_pshape.is_static()) - { - auto data_shape = data_pshape.to_shape(); - auto output_shape = output_pshape.to_shape(); - output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1); - output_pshape = output_shape; - } } else { if (filters_pshape.is_static() && data_pshape.is_static()) { auto filters_shape = filters_pshape.to_shape(); - filters_shape.erase(filters_shape.begin(), - filters_shape.begin() + 3); // remove {G, O, I} auto data_shape = data_pshape.to_shape(); - data_shape.erase(data_shape.begin(), data_shape.begin() + 2); // remove {N, C} Shape output_shape; - auto data_spatial_rank = data_shape.size(); + auto data_spatial_rank = data_shape.size() - 2; auto output_padding = get_output_padding(); if (output_padding.size() == 0) { @@ -284,13 +289,15 @@ void op::v1::GroupConvolutionBackpropData::validate_and_infer_types() } for (size_t i = 0; i < data_spatial_rank; ++i) { - size_t tmp = m_strides[i] * (data_shape[i] - 1) + - ((filters_shape[i] - 1) * m_dilations[i] + 1) - m_pads_begin[i] - + size_t tmp = m_strides[i] * (data_shape[i + 2] - 1) + + ((filters_shape[i + 3] - 1) * m_dilations[i] + 1) - m_pads_begin[i] - m_pads_end[i] + output_padding[i]; output_shape.push_back(tmp); - output_pshape = output_shape; } - output_shape.insert(output_shape.begin(), data_shape.begin(), data_shape.begin() + 1); + output_shape.insert(output_shape.begin(), + filters_shape.at(0) * filters_shape.at(2)); // GROUP * C_OUTPUT + output_shape.insert(output_shape.begin(), data_shape.at(0)); + output_pshape = output_shape; } } diff --git a/src/ngraph/pass/opset0_downgrade.cpp b/src/ngraph/pass/opset0_downgrade.cpp index ecdbacd8e5b..71c2098d6b1 100644 --- a/src/ngraph/pass/opset0_downgrade.cpp +++ b/src/ngraph/pass/opset0_downgrade.cpp @@ -135,11 +135,12 @@ namespace bool op_cast(shared_ptr node) { - auto output_shape = as_type_ptr(node->input_value(2).get_node_shared_ptr()); + auto output_shape_node = + as_type_ptr(node->input_value(2).get_node_shared_ptr()); const auto data_arg = node->input(0).get_source_output(); const auto filters_arg = node->input(1).get_source_output(); const auto strides = node->get_strides(); - NGRAPH_CHECK(output_shape, + NGRAPH_CHECK(output_shape_node, "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 " "if output_shape is not constant. Node: ", *node); @@ -155,8 +156,22 @@ namespace "with output padding other than `0`. Node: ", *node); + auto data_pshape = data_arg.get_partial_shape(); + auto filters_pshape = filters_arg.get_partial_shape(); + + NGRAPH_CHECK(data_pshape.rank().is_static() && data_pshape[0].is_static() && + filters_pshape.rank().is_static() && filters_pshape[1].is_static(), + "Unable to convert ConvolutionBackpropData:v1 to ConvolutionBackpropData:v0 " + "if data shape N and filters shape C dimensions are not static. Node: ", + *node); + + // Add N and C dimenstions to output_shape + auto output_shape = output_shape_node->get_shape_val(); + output_shape.insert(output_shape.begin(), static_cast(filters_pshape[1])); + output_shape.insert(output_shape.begin(), static_cast(data_pshape[0])); + auto replacement_node = - make_shared(output_shape->get_shape_val(), + make_shared(output_shape, filters_arg, data_arg, node->get_strides(), diff --git a/src/ngraph/pass/opset1_upgrade.cpp b/src/ngraph/pass/opset1_upgrade.cpp index 391d1bbff40..11a0a14c5b2 100644 --- a/src/ngraph/pass/opset1_upgrade.cpp +++ b/src/ngraph/pass/opset1_upgrade.cpp @@ -179,7 +179,10 @@ namespace auto replacement_node = make_shared( node->input_value(1), // data node->input_value(0), // filters - op::Constant::create(element::i64, Shape{data_batch_shape.size()}, data_batch_shape), + op::Constant::create( + element::i64, + Shape{data_batch_shape.size() - 2}, + vector(data_batch_shape.begin() + 2, data_batch_shape.end())), strides, pads_begin, pads_end, diff --git a/test/backend/convolution.in.cpp b/test/backend/convolution.in.cpp index 2ae3954b22a..771b0c9b2ff 100644 --- a/test/backend/convolution.in.cpp +++ b/test/backend/convolution.in.cpp @@ -175,7 +175,7 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_convolution_backprop_data) for (int i = 0; i < 2 * 3 * 5 * 5; i++) expected_result.emplace_back(i); - vector shapes = {2, 3, 5, 5}; + vector shapes = {5, 5}; // Create some tensors for input/output auto a = backend->create_tensor(element::f32, shape_delta); diff --git a/test/opset_pass/convolution_opset_pass.cpp b/test/opset_pass/convolution_opset_pass.cpp index acd2998e0a8..8832ec96ca7 100644 --- a/test/opset_pass/convolution_opset_pass.cpp +++ b/test/opset_pass/convolution_opset_pass.cpp @@ -78,7 +78,7 @@ TEST(opset_transform, opset1_convolution_downgrade_pass) TEST(opset_transform, opset1_convolution_backprop_data_downgrade_pass) { - auto data_batch_shape = op::Constant::create(element::i64, Shape{3}, {64, 3, 100}); + auto data_batch_shape = op::Constant::create(element::i64, Shape{1}, {100}); auto filters = make_shared(element::f32, Shape{128, 3, 10}); auto delta = make_shared(element::f32, Shape{64, 128, 96}); auto strides = Strides{1};