diff --git a/src/common/snippets/include/snippets/lowered/expression.hpp b/src/common/snippets/include/snippets/lowered/expression.hpp index befbaeb3c526d6..caf450ace5a925 100644 --- a/src/common/snippets/include/snippets/lowered/expression.hpp +++ b/src/common/snippets/include/snippets/lowered/expression.hpp @@ -53,7 +53,7 @@ class Expression : public std::enable_shared_from_this { void set_input_port_connector(size_t port, PortConnectorPtr to); - // Cannot be called in ctor because validate port attributes (descs, connectors) also + // Attention! Cannot be called in ctor because this method validats port attributes (descs, connectors) also virtual void validate() const; ExpressionPort get_input_port(size_t i); @@ -65,8 +65,25 @@ class Expression : public std::enable_shared_from_this { bool needShapeInfer() const { return m_need_shape_infer; } const std::vector& get_loop_ids() const; void set_loop_ids(const std::vector& loops); + + /** + * @brief Clone Expression with new node and input port attributes/ + * Output port descriptors will be cloned from the current expression. + * Output port connecters will be created. + * @param new_node new node + * @param new_inputs new input port connectors + * @param new_in_descs new input port descriptors. If this collection is empty, + * descriptors will be copied from the current expression + * @return the copy + */ ExpressionPtr clone_with_new_inputs(const std::shared_ptr& new_node, const std::vector& new_inputs, const std::vector& new_in_descs = {}) const; + /** + * @brief Clone Expression with new node using `expr_map` to connect to new parent expressions. + * @param expr_map the map with the original and cloned expressions + * @param new_node new node + * @return the copy + */ ExpressionPtr clone_with_new_inputs(const ExpressionMap& expr_map, const std::shared_ptr& new_node) const; virtual bool visit_attributes(AttributeVisitor &visitor); @@ -93,10 +110,8 @@ class Expression : public std::enable_shared_from_this { // The method must be used only by Linear IR builder of expressions! Expression(const std::shared_ptr& n, const std::shared_ptr& factory, bool need_shape_infer = true); - // Virtual clone method wich is called in clone_with_new_inputs with common logic + // Virtual clone method which is called in clone_with_new_inputs with common logic virtual ExpressionPtr clone() const; - // Called in ctors to validate expression attributes - virtual void validate_attributes() const; // used in clone_with_new_inputs. New output port descriptors were inited automatically void update_port_attributes(const std::shared_ptr& new_node, const std::vector& new_inputs, diff --git a/src/common/snippets/include/snippets/lowered/expressions/buffer_expression.hpp b/src/common/snippets/include/snippets/lowered/expressions/buffer_expression.hpp index b359c95faa06c5..cd096b5dfbe461 100644 --- a/src/common/snippets/include/snippets/lowered/expressions/buffer_expression.hpp +++ b/src/common/snippets/include/snippets/lowered/expressions/buffer_expression.hpp @@ -57,7 +57,6 @@ class BufferExpression : public Expression { BufferExpression(const std::shared_ptr& n, const std::shared_ptr& factory); ExpressionPtr clone() const override; - void validate_attributes() const override; size_t m_allocation_size = utils::get_dynamic_value(); size_t m_reg_group = 0; diff --git a/src/common/snippets/src/lowered/expression.cpp b/src/common/snippets/src/lowered/expression.cpp index b7fe496f66896a..b13574e36029e2 100644 --- a/src/common/snippets/src/lowered/expression.cpp +++ b/src/common/snippets/src/lowered/expression.cpp @@ -25,7 +25,6 @@ Expression::Expression(const std::shared_ptr& n, const std::shared_ptroutputs()) { m_output_port_descriptors.push_back(PortDescriptorUtils::get_port_descriptor_ptr(output)); } - validate_attributes(); } Expression::Expression(const Expression& other) : @@ -40,7 +39,6 @@ Expression::Expression(const Expression& other) : // input port connectors and input port descriptors - they must be consistent. m_input_port_descriptors = {}; m_output_port_descriptors = {}; - validate_attributes(); } const PortConnectorPtr& Expression::get_input_port_connector(size_t i) const { @@ -95,13 +93,9 @@ void Expression::set_reg_info(const RegInfo& rinfo) { } } - void Expression::validate_attributes() const { - OPENVINO_ASSERT(m_source_node != nullptr, - "The expression has null source node"); - } - void Expression::validate() const { - validate_attributes(); + OPENVINO_ASSERT(m_source_node != nullptr, + "The expression has null source node"); OPENVINO_ASSERT(m_input_port_descriptors.size() == m_input_port_connectors.size(), "The count of input ports and input port connectors must be equal"); OPENVINO_ASSERT(m_output_port_descriptors.size() == m_output_port_connectors.size(), diff --git a/src/common/snippets/src/lowered/expressions/buffer_expression.cpp b/src/common/snippets/src/lowered/expressions/buffer_expression.cpp index 7bf2b00da7d6ed..aa081f32a773e8 100644 --- a/src/common/snippets/src/lowered/expressions/buffer_expression.cpp +++ b/src/common/snippets/src/lowered/expressions/buffer_expression.cpp @@ -28,11 +28,6 @@ ExpressionPtr BufferExpression::clone() const { return std::shared_ptr(new BufferExpression(*this)); } -void BufferExpression::validate_attributes() const { - Expression::validate_attributes(); - OPENVINO_ASSERT(ov::is_type(get_node()), "BufferExpression expects Buffer op"); -} - bool BufferExpression::visit_attributes(AttributeVisitor &visitor) { auto allocation_size = utils::value2str(m_allocation_size); auto offset = utils::value2str(m_offset); diff --git a/src/common/snippets/src/op/buffer.cpp b/src/common/snippets/src/op/buffer.cpp index 4e20bbabd1fd1b..a99a75dc012f81 100644 --- a/src/common/snippets/src/op/buffer.cpp +++ b/src/common/snippets/src/op/buffer.cpp @@ -59,8 +59,8 @@ void Buffer::IntermediateMemoryImpl::validate_and_infer_types(Buffer* buffer) co OPENVINO_ASSERT(buffer, "Buffer is missed"); OPENVINO_ASSERT(buffer->get_input_size() != 0, "IntermediateMemory Buffer must have inputs"); const auto inputs = buffer->input_values(); - const auto inshape = buffer->get_input_partial_shape(0); - const auto intype = buffer->get_input_element_type(0); + const auto& inshape = buffer->get_input_partial_shape(0); + const auto& intype = buffer->get_input_element_type(0); OPENVINO_ASSERT(std::all_of(inputs.cbegin() + 1, inputs.cend(), [&](const ov::Output& in) { return in.get_partial_shape() == inshape && in.get_element_type() == intype; }), "All inputs of Buffers must have the same shape and element type");