Skip to content

Commit

Permalink
[FIX] Remove copy mode, set it as default
Browse files Browse the repository at this point in the history
  • Loading branch information
PiotrKrzem committed Sep 30, 2024
1 parent 57d8fed commit 30daa6d
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 58 deletions.
13 changes: 2 additions & 11 deletions src/core/include/openvino/op/identity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,13 @@ class OPENVINO_API Identity : public Op {
OPENVINO_OP("Identity", "opset15");
Identity() = default;
/**
* @brief Identity operation is used as a placeholder. It either passes the tensor down to the next layer,
* or copies the tensor to the output.
*
* @param copy Boolean that determines whether to copy the input to the output, or just return the output.
* @brief Identity operation is used as a placeholder. It copies the tensor data to the output.
*/
Identity(const Output<Node>& data, const bool copy = false);
Identity(const Output<Node>& data);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

bool get_copy() const;
void set_copy(const bool copy);

private:
bool m_copy;
};
} // namespace v15
} // namespace op
Expand Down
13 changes: 2 additions & 11 deletions src/core/reference/include/openvino/reference/identity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,9 @@ namespace identity {
*
* @param input Input matrix (matrices) pointer.
* @param output Output matrix (matrices) pointer.
* @param copy Boolean that determines whether to return the input as output or
* copy the input to a new memory address.
**/
template <typename T>
void identity(const T** input, T** output, const Shape& shape, const bool copy) {
const auto total_elements = shape_size<Shape>(shape);

if (!copy) {
*output = *input;
} else {
std::memcpy(*output, *input, total_elements * sizeof(T));
}
void identity(const char* input, char* output, const size_t size_in_bytes) {
std::memcpy(output, input, size_in_bytes);
}
} // namespace identity
} // namespace reference
Expand Down
13 changes: 2 additions & 11 deletions src/core/src/op/identity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ namespace ov {
namespace op {
namespace v15 {

Identity::Identity(const Output<Node>& data, const bool copy) : Op({data}), m_copy(copy) {
Identity::Identity(const Output<Node>& data) : Op({data}) {
constructor_validate_and_infer_types();
}

bool Identity::Identity::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v15_Identity_visit_attributes);
visitor.on_attribute("copy", m_copy);
return true;
}

Expand All @@ -36,14 +35,6 @@ std::shared_ptr<Node> Identity::Identity::clone_with_new_inputs(const OutputVect
OV_OP_SCOPE(v15_Identity_clone_with_new_inputs);
check_new_args_count(this, new_args);

return std::make_shared<Identity>(new_args.at(0), m_copy);
}

bool Identity::get_copy() const {
return m_copy;
}

void Identity::set_copy(const bool copy) {
m_copy = copy;
return std::make_shared<Identity>(new_args.at(0));
}
} // namespace ov
1 change: 0 additions & 1 deletion src/core/tests/type_prop/identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,4 @@ TEST_F(TypePropIdentityV15Test, default_ctor) {
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), ov::element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), ov::PartialShape({2, 2}));
EXPECT_EQ(op->get_copy(), false);
}
5 changes: 2 additions & 3 deletions src/core/tests/visitors/op/identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ TEST(attributes, Identity) {
NodeBuilder::opset().insert<ov::op::v15::Identity>();
const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2});

const auto op = std::make_shared<ov::op::v15::Identity>(data, true);
const auto op = std::make_shared<ov::op::v15::Identity>(data);
NodeBuilder builder(op, {data});
auto g_identity = ov::as_type_ptr<ov::op::v15::Identity>(builder.create());

constexpr auto expected_attr_count = 1;
constexpr auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
EXPECT_EQ(op->get_copy(), g_identity->get_copy());
}
5 changes: 4 additions & 1 deletion src/plugins/template/backend/ops/identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ inline bool evaluate(const std::shared_ptr<ov::op::v15::Identity>& op,
using T = typename ov::element_type_traits<ET>::value_type;

const std::vector<ov::PartialShape> input_shapes{op->get_input_shape(0)};
const auto total_size = get_shape_size(out_shape);
const auto total_size_in_bytes = total_size * inputs[0].get_dtype().get_element_size();

outputs[0].set_shape(input_shapes[0]);

ov::reference::Identity<T>(inputs[0].data<const T>(), outputs[0].data<T>(), out_shape, op->get_copy());
ov::reference::Identity(static_cast<const char*>(inputs[0].data()), static_cast<char*>(outputs[0].data()), total_size_in_bytes);
return true;
}

Expand Down
22 changes: 3 additions & 19 deletions src/plugins/template/tests/functional/op_reference/identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ namespace {
struct IdentityParams {
IdentityParams(const reference_tests::Tensor& matrices, bool copy, std::string name)
: matrices{matrices},
copy(copy),
test_case_name{std::move(name)} {}

reference_tests::Tensor matrices;
bool copy;
std::string test_case_name;
};

Expand All @@ -27,7 +25,6 @@ class ReferenceIdentity : public testing::TestWithParam<IdentityParams>, public
function = CreateFunction(params);
inputData = {params.matrices.data};
refOutData = {params.matrices.data};
m_copy = params.copy;
}

static std::string getTestCaseName(const testing::TestParamInfo<IdentityParams>& obj) {
Expand All @@ -37,26 +34,15 @@ class ReferenceIdentity : public testing::TestWithParam<IdentityParams>, public
name << obj.param.matrices.type;
name << "_shape_";
name << obj.param.matrices.shape;
name << "_copy_";
name << obj.param.copy;
return name.str();
}

void Validate() {
CommonReferenceTest::Validate();

bool pointers_match = refOutData[0].data() == actualOutData[0].data();
ASSERT_EQ(pointers_match, !m_copy);
}

private:
static std::shared_ptr<ov::Model> CreateFunction(const IdentityParams& params) {
const auto in_matrices = std::make_shared<ov::op::v0::Parameter>(params.matrices.type, params.matrices.shape);
const auto identity = std::make_shared<ov::op::v15::Identity>(in_matrices, params.copy);
const auto identity = std::make_shared<ov::op::v15::Identity>(in_matrices);
return std::make_shared<ov::Model>(identity->outputs(), ov::ParameterVector{in_matrices});
}

bool m_copy;
};

template <ov::element::Type_t ET>
Expand Down Expand Up @@ -91,10 +77,8 @@ std::vector<IdentityParams> generateIdentityParams() {
0.0f});

std::vector<IdentityParams> params;
params.emplace_back(matrices_2_2, false, "single_simple");
params.emplace_back(matrices_2_2, true, "single_simple");
params.emplace_back(matrices_2_3_3, false, "many_simple");
params.emplace_back(matrices_2_3_3, true, "many_simple");
params.emplace_back(matrices_2_2, "single");
params.emplace_back(matrices_2_3_3, "many");

return params;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v0::Interpolat

std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v15::Identity>& node) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{4, 4, 4})};
const auto identity = std::make_shared<ov::op::v15::Identity>(params[0], false);
const auto identity = std::make_shared<ov::op::v15::Identity>(params[0]);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(identity)};
return std::make_shared<ov::Model>(results, params, "Identity");
}
Expand Down

0 comments on commit 30daa6d

Please sign in to comment.