From 214bcdc92a664fbe6ab9380c09cdb594f9178823 Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Tue, 10 Sep 2024 12:25:25 +0400 Subject: [PATCH] fix ov rtti, extend the transformation to support logical reduce, added a new test --- .../common_optimizations/nop_elimination.cpp | 5 +++-- .../reduce_reshape_fusion.cpp | 3 +-- .../common_optimizations/nop_elimination.cpp | 16 +++++++++++++++- .../openvino/op/util/arithmetic_reduction.hpp | 2 +- .../op/util/arithmetic_reductions_keep_dims.hpp | 2 +- .../openvino/op/util/logical_reduction.hpp | 2 +- .../op/util/logical_reduction_keep_dims.hpp | 2 +- 7 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index f07aefc216b1a1..b5527d8672effe 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -350,7 +350,7 @@ pass::EliminateReduceReshape::EliminateReduceReshape() { MATCHER_SCOPE(EliminateReduceReshape); using namespace pass::pattern; auto axes = wrap_type(); - auto reduce_pattern = wrap_type({any_input(), axes}); + auto reduce_pattern = wrap_type({any_input(), axes}); auto requested_shape_pattern = wrap_type(); auto reshape_pattern = wrap_type({reduce_pattern, requested_shape_pattern}, consumers_count(1)); @@ -361,7 +361,7 @@ pass::EliminateReduceReshape::EliminateReduceReshape() { auto reduce_node = pattern_map.at(reduce_pattern); auto reshape = ov::as_type_ptr(reshape_node); - auto reduce = ov::as_type_ptr(reduce_node); + auto reduce = ov::as_type_ptr(reduce_node); if (!reshape || !reduce) { return false; } @@ -387,6 +387,7 @@ pass::EliminateReduceReshape::EliminateReduceReshape() { (axes.count(i) && requested_shape_vec[i] == 1)); } + // if the number of dyn dims here is equal to 0 or 1, we can unambiguously define output shape if (cnt_dyn <= 1) { return replace_output_update_name(reshape->output(0), reshape->input_value(0)); } else { diff --git a/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp index b2f8d98e715a30..85342ee6bbd9d8 100644 --- a/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/reduce_reshape_fusion.cpp @@ -25,8 +25,7 @@ ov::pass::ReduceReshapeFusion::ReduceReshapeFusion() { const auto reduce = pattern::wrap_type( {pattern::any_input(), reduce_axes}, pattern::consumers_count(1)); - const auto reshape = - pattern::wrap_type({reduce, pattern::any_input()}, pattern::has_static_shape()); + const auto reshape = pattern::wrap_type({reduce, pattern::any_input()}); matcher_pass_callback callback = [=](pattern::Matcher& m) { auto& pattern_map = m.get_pattern_value_map(); diff --git a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp index f5ebcde9cc18e9..19b5fefd79b9b0 100644 --- a/src/common/transformations/tests/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/tests/common_optimizations/nop_elimination.cpp @@ -245,7 +245,7 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic_negative) { ASSERT_TRUE(count_ops_of_type(f) == 1); } -TEST(nop_elimination, reshape_elimination_v1_dynamic) { +TEST(nop_elimination, reshape_arithmetical_reduce_elimination_dynamic) { auto arg = std::make_shared(element::i64, PartialShape({-1, 96, 100, 100})); auto reduce_axes = ov::op::v0::Constant::create(element::i64, Shape{2}, {2, 3}); auto reduce = std::make_shared(arg, reduce_axes, true); @@ -259,6 +259,20 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) { ASSERT_TRUE(count_ops_of_type(f) == 0); } +TEST(nop_elimination, reshape_logical_reduce_elimination_dynamic) { + auto arg = std::make_shared(element::boolean, PartialShape({-1, 96, 100, 100})); + auto reduce_axes = ov::op::v0::Constant::create(element::i64, Shape{2}, {2, 3}); + auto reduce = std::make_shared(arg, reduce_axes, true); + auto pattern = op::v0::Constant::create(element::i64, Shape{4}, {0, 96, 1, 1}); + auto reshape_v1 = std::make_shared(reduce, pattern, true); + auto nz = std::make_shared(reshape_v1); + auto f = std::make_shared(NodeVector{nz}, ParameterVector{arg}); + pass::Manager pass_manager; + pass_manager.register_pass(false); + pass_manager.run_passes(f); + ASSERT_TRUE(count_ops_of_type(f) == 0); +} + TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) { std::shared_ptr f; { diff --git a/src/core/include/openvino/op/util/arithmetic_reduction.hpp b/src/core/include/openvino/op/util/arithmetic_reduction.hpp index 365444418429b8..1def85392ebfc2 100644 --- a/src/core/include/openvino/op/util/arithmetic_reduction.hpp +++ b/src/core/include/openvino/op/util/arithmetic_reduction.hpp @@ -25,7 +25,7 @@ class OPENVINO_API ArithmeticReduction : public ReductionBase { ArithmeticReduction(const Output& arg, const Output& reduction_axes); public: - OPENVINO_OP("ArithmeticReduction", "util"); + OPENVINO_OP("ArithmeticReduction", "util", ReductionBase); void validate_and_infer_types() override; }; } // namespace util diff --git a/src/core/include/openvino/op/util/arithmetic_reductions_keep_dims.hpp b/src/core/include/openvino/op/util/arithmetic_reductions_keep_dims.hpp index 7eca3778f2760b..ced3382787f02d 100644 --- a/src/core/include/openvino/op/util/arithmetic_reductions_keep_dims.hpp +++ b/src/core/include/openvino/op/util/arithmetic_reductions_keep_dims.hpp @@ -21,7 +21,7 @@ class OPENVINO_API ArithmeticReductionKeepDims : public util::ArithmeticReductio bool visit_attributes(AttributeVisitor& visitor) override; public: - OPENVINO_OP("ArithmeticReductionKeepDims", "util"); + OPENVINO_OP("ArithmeticReductionKeepDims", "util", util::ArithmeticReduction); void validate_and_infer_types() override; /// \return If set to 1 it holds axes that are used for reduction. diff --git a/src/core/include/openvino/op/util/logical_reduction.hpp b/src/core/include/openvino/op/util/logical_reduction.hpp index 1059dd922fdc68..747a1af6401ee0 100644 --- a/src/core/include/openvino/op/util/logical_reduction.hpp +++ b/src/core/include/openvino/op/util/logical_reduction.hpp @@ -29,7 +29,7 @@ class OPENVINO_API LogicalReduction : public ReductionBase { LogicalReduction(const Output& arg, const Output& reduction_axes); public: - OPENVINO_OP("LogicalReduction", "util"); + OPENVINO_OP("LogicalReduction", "util", ReductionBase); void validate_and_infer_types() override; }; } // namespace util diff --git a/src/core/include/openvino/op/util/logical_reduction_keep_dims.hpp b/src/core/include/openvino/op/util/logical_reduction_keep_dims.hpp index cfa03130d44031..85ffae909a91f4 100644 --- a/src/core/include/openvino/op/util/logical_reduction_keep_dims.hpp +++ b/src/core/include/openvino/op/util/logical_reduction_keep_dims.hpp @@ -22,7 +22,7 @@ class OPENVINO_API LogicalReductionKeepDims : public util::LogicalReduction { bool visit_attributes(AttributeVisitor& visitor) override; public: - OPENVINO_OP("LogicalReductionKeepDims", "util"); + OPENVINO_OP("LogicalReductionKeepDims", "util", util::LogicalReduction); void validate_and_infer_types() override; /// \return If set to 1 it holds axes that are used for reduction.