Skip to content

Commit

Permalink
[op]: S[SearchSorted]: Added new op definition with tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
pkowalc1 committed Oct 3, 2024
1 parent 187f1c4 commit ccb2134
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/search_sorted.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/selu.hpp"
#include "openvino/op/shape_of.hpp"
Expand Down
46 changes: 46 additions & 0 deletions src/core/include/openvino/op/search_sorted.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace v15 {
/// \brief SearchSorted operation.
///
/// \ingroup ov_ops_cpp_api
class OPENVINO_API SearchSorted : public Op {
public:
OPENVINO_OP("SearchSorted", "opset15", Op);

SearchSorted() = default;
/// \brief Constructs a SearchSorted operation.
/// \param sorted_sequence Sorted sequence to search in.
/// \param values Values to search indexs for.
/// \param right_mode If False, return the first suitable index that is found for given value. If True, return
/// the last such index.
SearchSorted(const Output<Node>& sorted_sequence, const Output<Node>& values, bool right_mode = false);

void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

bool get_right_mode() const {
return m_right_mode;
}

void set_right_mode(bool right_mode) {
m_right_mode = right_mode;
}

private:
void validate();
void infer_type();
bool m_right_mode;
};
} // namespace v15
} // namespace op
} // namespace ov
33 changes: 33 additions & 0 deletions src/core/shape_inference/include/search_sorted_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/search_sorted.hpp"
#include "utils.hpp"

namespace ov {
namespace op {
namespace v15 {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const SearchSorted* op, const std::vector<TShape>& input_shapes) {
const auto& sorted_shape = input_shapes[0];
const auto& values_shape = input_shapes[1];
auto output_shape = values_shape;
TShape::merge_into(output_shape, sorted_shape);

if (output_shape.rank().is_static()) {
auto last_it = output_shape.end() - 1;
if (values_shape.rank().is_static()) {
*last_it = *(input_shapes[1].end() - 1);
} else {
*last_it = Dimension::dynamic();
}
}

return {output_shape};
}
} // namespace v15
} // namespace op
} // namespace ov
68 changes: 68 additions & 0 deletions src/core/src/op/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <openvino/op/search_sorted.hpp>

#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "search_sorted_shape_inference.hpp"

namespace ov {
namespace op {
namespace v15 {

SearchSorted::SearchSorted(const Output<Node>& sorted_sequence, const Output<Node>& values, bool right_mode)
: Op({sorted_sequence, values}),
m_right_mode(right_mode) {
constructor_validate_and_infer_types();
}

void SearchSorted::infer_type() {
const auto& output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this));
set_output_type(0, ov::element::i64, output_shapes[0]);
}

void SearchSorted::validate() {
NODE_VALIDATION_CHECK(this,
get_input_element_type(0) == get_input_element_type(1),
"Sorted sequence and values must have the same element type.");

const auto& sorted_shape = get_input_partial_shape(0);
const auto& values_shape = get_input_partial_shape(1);

if (sorted_shape.rank().is_static() && values_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
sorted_shape.rank().get_length() == values_shape.rank().get_length(),
"Sorted sequence and values have different ranks.");

for (int64_t i = 0; i < sorted_shape.rank().get_length() - 1; ++i) {
NODE_VALIDATION_CHECK(this,
sorted_shape[i].compatible(values_shape[i]),
"Sorted sequence and values has different ",
i,
" dimension.");
}
}
}

void SearchSorted::validate_and_infer_types() {
OV_OP_SCOPE(v15_SearchSorted_validate_and_infer_types);
validate();
infer_type();
}

bool SearchSorted::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v15_SearchSorted_visit_attributes);
visitor.on_attribute("right_mode", m_right_mode);
return true;
}

std::shared_ptr<Node> SearchSorted::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v15_SearchSorted_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<SearchSorted>(new_args.at(0), new_args.at(1), get_right_mode());
}
} // namespace v15
} // namespace op
} // namespace ov
84 changes: 84 additions & 0 deletions src/core/tests/type_prop/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/search_sorted.hpp"

#include "common_test_utils/type_prop.hpp"

using namespace std;
using namespace ov;

#define EXPECT_THROW_SUBSTRING(STATEMENT, SUBSTRING) \
try { \
STATEMENT; \
FAIL() << "Exception not thrown"; \
} catch (const NodeValidationFailure& error) { \
EXPECT_THAT(error.what(), testing::HasSubstr(SUBSTRING)); \
} catch (...) { \
FAIL() << "Unexpected exception thrown"; \
}

static void PerformShapeTest(const PartialShape& sorted_shape,
const PartialShape& values_shape,
const PartialShape& expected_output_shape) {
auto sorted = make_shared<op::v0::Parameter>(element::i32, sorted_shape);
auto values = make_shared<op::v0::Parameter>(element::i32, values_shape);
auto search_sorted_op = make_shared<op::v15::SearchSorted>(sorted, values);
EXPECT_EQ(search_sorted_op->get_element_type(), element::i64);
EXPECT_EQ(search_sorted_op->get_output_partial_shape(0), expected_output_shape);
}

TEST(type_prop, search_sorted_shape_infer_equal_inputs) {
PerformShapeTest({1, 3, 6}, {1, 3, 6}, {1, 3, 6});
}

TEST(type_prop, search_sorted_shape_infer_sorted_dynamic) {
PerformShapeTest(PartialShape::dynamic(), {1, 3, 6}, {1, 3, 6});
}

TEST(type_prop, search_sorted_shape_infer_values_dynamic) {
PerformShapeTest({1, 3, 7, 5}, PartialShape::dynamic(), {1, 3, 7, -1});
}

TEST(type_prop, search_sorted_shape_infer_different_last_dim) {
PerformShapeTest({1, 3, 7, 100}, {1, 3, 7, 10}, {1, 3, 7, 10});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_1) {
PerformShapeTest({1, -1, 7, -1}, {-1, 3, -1, 10}, {1, 3, 7, 10});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_2) {
PerformShapeTest({1, -1, 7, 50}, {-1, 3, -1, -1}, {1, 3, 7, -1});
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_3) {
PerformShapeTest(PartialShape::dynamic(), PartialShape::dynamic(), PartialShape::dynamic());
}

TEST(type_prop, search_sorted_shape_infer_both_dynamic_4) {
PerformShapeTest({-1, -1, 50}, {-1, -1, 20}, {-1, -1, 20});
}

TEST(type_prop, search_sorted_shape_infer_different_types) {
auto sorted = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 6});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 3, 6});
EXPECT_THROW_SUBSTRING(make_shared<op::v15::SearchSorted>(values, sorted),
std::string("must have the same element type"));
}

TEST(type_prop, search_sorted_shape_infer_wrong_rank) {
auto sorted = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 1, 3, 6});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 3, 6});
EXPECT_THROW_SUBSTRING(make_shared<op::v15::SearchSorted>(sorted, values),
std::string("Sorted sequence and values have different ranks"));
}

TEST(type_prop, search_sorted_shape_infer_wrong_dim) {
auto sorted = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 1, 3, 6});
auto values = make_shared<ov::op::v0::Parameter>(element::i32, Shape{1, 1, 5, 6});
EXPECT_THROW_SUBSTRING(make_shared<op::v15::SearchSorted>(sorted, values), std::string(" different 2 dimension."));
}

#undef EXPECT_THROW_SUBSTRING

0 comments on commit ccb2134

Please sign in to comment.