diff --git a/src/core/include/openvino/op/constant.hpp b/src/core/include/openvino/op/constant.hpp index df630e96576db7..9f4e313d0b51a6 100644 --- a/src/core/include/openvino/op/constant.hpp +++ b/src/core/include/openvino/op/constant.hpp @@ -186,6 +186,23 @@ class OPENVINO_API Constant : public Op { /// \param data A void* to constant data. Constant(const element::Type& type, const Shape& shape, const void* data); + /// \brief Construct a tensor constant from shared memory. + /// + /// The constant byte size (defined by type and shape) must be lower than shared memory size, otherwise throw + /// exception. + /// The Constant can take ownership of shared memory if provided shared object is not null manges memory lifetime. + /// + /// \param type The element type of the tensor constant. + /// \param shape The shape of the tensor constant. + /// \param data The pointer to shared memory. + /// \param size The byte size of shared memory. + /// \param so The shared object to take it ownership (default: nullptr). + Constant(const element::Type& type, + const Shape& shape, + const void* data, + size_t size, + std::shared_ptr so = nullptr); + Constant(const element::Type& type, const Shape& shape, const std::shared_ptr& data); Constant(const Constant& other); diff --git a/src/core/src/op/constant.cpp b/src/core/src/op/constant.cpp index 90b293fb564672..4eed7740fe08d1 100644 --- a/src/core/src/op/constant.cpp +++ b/src/core/src/op/constant.cpp @@ -340,6 +340,27 @@ Constant::Constant(const Constant& other, const Shape& new_shape) constructor_validate_and_infer_types(); } +Constant::Constant(const element::Type& type, + const Shape& shape, + const void* data, + size_t size, + std::shared_ptr so) + : Constant( + type, + shape, + // Note: const_cast used to store pointer only + std::make_shared>>(reinterpret_cast(const_cast(data)), + element::get_memory_size(type, shape_size(shape)), + so)) { + const auto const_size = get_byte_size(); + NODE_VALIDATION_CHECK(this, + const_size <= size, + "The given precision and shape has size larger than the memory size: ", + const_size, + " > ", + size); +} + Constant::~Constant() = default; struct ValueToString : ov::element::NotSupported { diff --git a/src/core/tests/constant.cpp b/src/core/tests/constant.cpp index 012a2aab6cfa13..f54b269c1b06b8 100644 --- a/src/core/tests/constant.cpp +++ b/src/core/tests/constant.cpp @@ -36,6 +36,11 @@ struct TestDType { float value; }; +template +size_t container_byte_size(const Container& c) { + return c.size() * sizeof(typename Container::value_type); +} + using std::string; using std::vector; @@ -2625,6 +2630,116 @@ TEST(constant, hold_tensor_custom_strides_revalidate) { EXPECT_EQ(const_op->get_tensor_view().get_strides(), strides); } +TEST(constant, hold_shared_memory_invalid_shape) { + auto storage = std::vector{1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1}; + + OV_EXPECT_THROW( + std::ignore = op::v0::Constant(element::i32, Shape{2, 3, 3}, storage.data(), container_byte_size(storage)), + AssertFailure, + AllOf(HasSubstr("i32[2,3,3]"), + HasSubstr("The given precision and shape has size larger than the memory size: 72 > 44"))); +} + +TEST(constant, hold_shared_memory_invalid_precision) { + auto storage = std::vector{1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1}; + + OV_EXPECT_THROW( + std::ignore = op::v0::Constant(element::i64, Shape{2, 3}, storage.data(), container_byte_size(storage)), + AssertFailure, + AllOf(HasSubstr("i64[2,3]"), + HasSubstr("The given precision and shape has size larger than the memory size: 48 > 44"))); +} + +TEST(constant, hold_shared_memory_same_size) { + auto storage = + std::make_shared>(std::initializer_list{1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1}); + { + auto c = op::v0::Constant(element::i32, Shape{11}, storage->data(), container_byte_size(*storage)); + std::fill_n(storage->begin() + 3, 4, 0); + + EXPECT_EQ(storage.use_count(), 1); + EXPECT_EQ(c.get_data_ptr(), storage->data()); + EXPECT_EQ(c.get_vector(), std::vector({1, 2, 3, 0, 0, 0, 0, 4, 3, 2, 1})); + EXPECT_EQ(c.cast_vector(), std::vector({1, 2, 3, 0, 0, 0, 0, 4, 3, 2, 1})); + } + EXPECT_EQ(storage.use_count(), 1); +} + +TEST(constant, hold_shared_memory_shape_within_memory_size) { + auto storage = std::vector{1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1}; + auto c = op::v0::Constant(element::u8, Shape{2, 3}, storage.data(), container_byte_size(storage)); + + EXPECT_EQ(c.get_data_ptr(), storage.data()); + EXPECT_EQ(c.get_vector(), std::vector({1, 2, 3, 4, 5, 6})); + EXPECT_EQ(c.cast_vector(), std::vector({1, 2, 3, 4, 5, 6})); +} + +TEST(constant, hold_shared_memory_different_precision) { + auto storage = std::vector{1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1}; + auto c = op::v0::Constant(element::u8, Shape{2, 3}, storage.data(), container_byte_size(storage)); + + EXPECT_EQ(c.get_data_ptr(), storage.data()); + EXPECT_EQ(c.get_vector(), std::vector({1, 0, 0, 0, 2, 0})); + EXPECT_EQ(c.cast_vector(), std::vector({1, 0, 0, 0, 2, 0})); +} + +TEST(constant, own_shared_memory) { + struct CustomStorage { + CustomStorage(std::initializer_list values) : values{std::move(values)} { + ON_CALL(*this, dtor_impl).WillByDefault(testing::Return()); + } + + ~CustomStorage() { + dtor_impl(); + } + + MOCK_METHOD(void, dtor_impl, ()); + + size_t byte_size() const { + return container_byte_size(values); + } + + constexpr ov::element::Type get_element_type() const { + return ov::element::i16; + } + + std::vector values{}; + }; + + { + auto storage = std::make_shared(std::initializer_list{1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1}); + auto c = std::make_shared(storage->get_element_type(), + Shape{2}, + storage->values.data(), + container_byte_size(storage->values), + storage); + + EXPECT_EQ(storage.use_count(), 2); + + c = nullptr; + EXPECT_EQ(storage.use_count(), 1); + EXPECT_CALL(*storage, dtor_impl).Times(1); + } + + { + std::shared_ptr c; + CustomStorage* s_ptr; + { + auto storage = std::make_shared>( + std::initializer_list{1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1}); + s_ptr = storage.get(); + c = std::make_shared(storage->get_element_type(), + Shape{2}, + storage->values.data(), + container_byte_size(storage->values), + storage); + } + + EXPECT_CALL(*s_ptr, dtor_impl).Times(1); + c = nullptr; + } +} + // Test verifies 2 things: // a) Checks that bitwise comparison happens on first call of 'get_all_data_elements_bitwise_identical' // b) Next call of 'get_all_data_elements_bitwise_identical' takes already calculated value