Skip to content

Commit

Permalink
fix ov rtti, extend the transformation to support logical reduce, add…
Browse files Browse the repository at this point in the history
…ed a new test
  • Loading branch information
itikhono committed Sep 10, 2024
1 parent fa03ba8 commit 214bcdc
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ pass::EliminateReduceReshape::EliminateReduceReshape() {
MATCHER_SCOPE(EliminateReduceReshape);
using namespace pass::pattern;
auto axes = wrap_type<ov::op::v0::Constant>();
auto reduce_pattern = wrap_type<ov::op::util::ArithmeticReductionKeepDims>({any_input(), axes});
auto reduce_pattern = wrap_type<ov::op::util::ReductionBase>({any_input(), axes});
auto requested_shape_pattern = wrap_type<ov::op::v0::Constant>();
auto reshape_pattern =
wrap_type<ov::op::v1::Reshape>({reduce_pattern, requested_shape_pattern}, consumers_count(1));
Expand All @@ -361,7 +361,7 @@ pass::EliminateReduceReshape::EliminateReduceReshape() {
auto reduce_node = pattern_map.at(reduce_pattern);

auto reshape = ov::as_type_ptr<ov::op::v1::Reshape>(reshape_node);
auto reduce = ov::as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduce_node);
auto reduce = ov::as_type_ptr<ov::op::util::ReductionBase>(reduce_node);
if (!reshape || !reduce) {
return false;
}
Expand All @@ -387,6 +387,7 @@ pass::EliminateReduceReshape::EliminateReduceReshape() {
(axes.count(i) && requested_shape_vec[i] == 1));
}

// if the number of dyn dims here is equal to 0 or 1, we can unambiguously define output shape
if (cnt_dyn <= 1) {
return replace_output_update_name(reshape->output(0), reshape->input_value(0));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ ov::pass::ReduceReshapeFusion::ReduceReshapeFusion() {
const auto reduce = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{pattern::any_input(), reduce_axes},
pattern::consumers_count(1));
const auto reshape =
pattern::wrap_type<ov::op::v1::Reshape>({reduce, pattern::any_input()}, pattern::has_static_shape());
const auto reshape = pattern::wrap_type<ov::op::v1::Reshape>({reduce, pattern::any_input()});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic_negative) {
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 1);
}

TEST(nop_elimination, reshape_elimination_v1_dynamic) {
TEST(nop_elimination, reshape_arithmetical_reduce_elimination_dynamic) {
auto arg = std::make_shared<op::v0::Parameter>(element::i64, PartialShape({-1, 96, 100, 100}));
auto reduce_axes = ov::op::v0::Constant::create(element::i64, Shape{2}, {2, 3});
auto reduce = std::make_shared<op::v1::ReduceMean>(arg, reduce_axes, true);
Expand All @@ -259,6 +259,20 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) {
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 0);
}

TEST(nop_elimination, reshape_logical_reduce_elimination_dynamic) {
auto arg = std::make_shared<op::v0::Parameter>(element::boolean, PartialShape({-1, 96, 100, 100}));
auto reduce_axes = ov::op::v0::Constant::create(element::i64, Shape{2}, {2, 3});
auto reduce = std::make_shared<op::v1::ReduceLogicalAnd>(arg, reduce_axes, true);
auto pattern = op::v0::Constant::create(element::i64, Shape{4}, {0, 96, 1, 1});
auto reshape_v1 = std::make_shared<op::v1::Reshape>(reduce, pattern, true);
auto nz = std::make_shared<op::v3::NonZero>(reshape_v1);
auto f = std::make_shared<ov::Model>(NodeVector{nz}, ParameterVector{arg});
pass::Manager pass_manager;
pass_manager.register_pass<ov::pass::NopElimination>(false);
pass_manager.run_passes(f);
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 0);
}

TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) {
std::shared_ptr<ov::Model> f;
{
Expand Down
2 changes: 1 addition & 1 deletion src/core/include/openvino/op/util/arithmetic_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class OPENVINO_API ArithmeticReduction : public ReductionBase {
ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);

public:
OPENVINO_OP("ArithmeticReduction", "util");
OPENVINO_OP("ArithmeticReduction", "util", ReductionBase);
void validate_and_infer_types() override;
};
} // namespace util
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class OPENVINO_API ArithmeticReductionKeepDims : public util::ArithmeticReductio
bool visit_attributes(AttributeVisitor& visitor) override;

public:
OPENVINO_OP("ArithmeticReductionKeepDims", "util");
OPENVINO_OP("ArithmeticReductionKeepDims", "util", util::ArithmeticReduction);
void validate_and_infer_types() override;

/// \return If set to 1 it holds axes that are used for reduction.
Expand Down
2 changes: 1 addition & 1 deletion src/core/include/openvino/op/util/logical_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OPENVINO_API LogicalReduction : public ReductionBase {
LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);

public:
OPENVINO_OP("LogicalReduction", "util");
OPENVINO_OP("LogicalReduction", "util", ReductionBase);
void validate_and_infer_types() override;
};
} // namespace util
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class OPENVINO_API LogicalReductionKeepDims : public util::LogicalReduction {
bool visit_attributes(AttributeVisitor& visitor) override;

public:
OPENVINO_OP("LogicalReductionKeepDims", "util");
OPENVINO_OP("LogicalReductionKeepDims", "util", util::LogicalReduction);
void validate_and_infer_types() override;

/// \return If set to 1 it holds axes that are used for reduction.
Expand Down

0 comments on commit 214bcdc

Please sign in to comment.