Skip to content

Commit

Permalink
Result tensor uses shared tensor names
Browse files Browse the repository at this point in the history
  • Loading branch information
praasz authored and olpipi committed Aug 28, 2024
1 parent 1f876d6 commit 1b7eedf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/core/src/op/result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ class SharedTensor : public ITensorDescriptor {
}

const std::unordered_set<std::string>& get_names() const override {
return m_output_names.empty() ? m_shared_tensor->get_names() : m_output_names;
return m_shared_tensor->get_names();
}

const std::string& get_any_name() const override {
return m_output_names.empty() ? m_shared_tensor->get_any_name() : *m_output_names.begin();
return m_shared_tensor->get_any_name();
}

/** @brief Gets runtime map from shared tensor. */
Expand Down
25 changes: 13 additions & 12 deletions src/core/tests/type_prop/result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ TEST_F(TypePropResultV0Test, set_specific_output_name_by_output) {
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("input"));

result->output(0).set_names({"out"});
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(a->output(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(a->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "out"));
}
Expand All @@ -83,8 +83,8 @@ TEST_F(TypePropResultV0Test, set_specific_output_name_by_tensor_desc) {
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("input"));

result->get_output_tensor(0).set_names({"out"});
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(a->output(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(a->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "out"));
}
Expand All @@ -99,15 +99,15 @@ TEST_F(TypePropResultV0Test, change_specific_output_name) {

result->get_output_tensor(0).set_names({"out"});

EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(a->output(0).get_names(), UnorderedElementsAre("input", "out"));
EXPECT_THAT(a->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "out"));

result->output(0).set_names({"new output"});

EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("new output"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("new output"));
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("input", "new output"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "new output"));
EXPECT_THAT(a->output(0).get_names(), UnorderedElementsAre("input", "new output"));
EXPECT_THAT(a->get_output_tensor(0).get_names(), UnorderedElementsAre("input", "new output"));
}
Expand All @@ -124,8 +124,9 @@ TEST_F(TypePropResultV0Test, add_specific_output_name) {
result->get_output_tensor(0).add_names({"extra output name", "o1"});
result->output(0).add_names({"extra output name", "o2"});

EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("out", "extra output name", "o1", "o2"));
EXPECT_THAT(result->get_output_tensor(0).get_names(), UnorderedElementsAre("out", "extra output name", "o1", "o2"));
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("input", "out", "extra output name", "o1", "o2"));
EXPECT_THAT(result->get_output_tensor(0).get_names(),
UnorderedElementsAre("input", "out", "extra output name", "o1", "o2"));
EXPECT_THAT(a->output(0).get_names(), UnorderedElementsAre("input", "out", "extra output name", "o1", "o2"));
EXPECT_THAT(a->get_output_tensor(0).get_names(),
UnorderedElementsAre("input", "out", "extra output name", "o1", "o2"));
Expand All @@ -138,15 +139,15 @@ TEST_F(TypePropResultV0Test, preserve_specific_name_on_input_replace) {
const auto result = make_op(a);
result->output(0).set_names({"out"});

EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("input a", "out"));

const auto b = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
b->get_output_tensor(0).set_names({"input b"});

result->input(0).replace_source_output(b);
result->validate_and_infer_types();

EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("out"));
EXPECT_THAT(result->output(0).get_names(), UnorderedElementsAre("input b", "out"));
EXPECT_THAT(a->output(0).get_names(), UnorderedElementsAre("input a"));
}
} // namespace test
Expand Down

0 comments on commit 1b7eedf

Please sign in to comment.