Skip to content

Commit

Permalink
[Snippets] Applied Vladislav comments 2
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Sep 12, 2024
1 parent 26ae711 commit 0432d44
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 20 deletions.
23 changes: 19 additions & 4 deletions src/common/snippets/include/snippets/lowered/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Expression : public std::enable_shared_from_this<Expression> {

void set_input_port_connector(size_t port, PortConnectorPtr to);

// Cannot be called in ctor because validate port attributes (descs, connectors) also
// Attention! Cannot be called in ctor because this method validats port attributes (descs, connectors) also
virtual void validate() const;

ExpressionPort get_input_port(size_t i);
Expand All @@ -65,8 +65,25 @@ class Expression : public std::enable_shared_from_this<Expression> {
bool needShapeInfer() const { return m_need_shape_infer; }
const std::vector<size_t>& get_loop_ids() const;
void set_loop_ids(const std::vector<size_t>& loops);

/**
* @brief Clone Expression with new node and input port attributes/
* Output port descriptors will be cloned from the current expression.
* Output port connecters will be created.
* @param new_node new node
* @param new_inputs new input port connectors
* @param new_in_descs new input port descriptors. If this collection is empty,
* descriptors will be copied from the current expression
* @return the copy
*/
ExpressionPtr clone_with_new_inputs(const std::shared_ptr<Node>& new_node, const std::vector<PortConnectorPtr>& new_inputs,
const std::vector<PortDescriptorPtr>& new_in_descs = {}) const;
/**
* @brief Clone Expression with new node using `expr_map` to connect to new parent expressions.
* @param expr_map the map with the original and cloned expressions
* @param new_node new node
* @return the copy
*/
ExpressionPtr clone_with_new_inputs(const ExpressionMap& expr_map, const std::shared_ptr<Node>& new_node) const;

virtual bool visit_attributes(AttributeVisitor &visitor);
Expand All @@ -93,10 +110,8 @@ class Expression : public std::enable_shared_from_this<Expression> {
// The method must be used only by Linear IR builder of expressions!
Expression(const std::shared_ptr<Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& factory, bool need_shape_infer = true);

// Virtual clone method wich is called in clone_with_new_inputs with common logic
// Virtual clone method which is called in clone_with_new_inputs with common logic
virtual ExpressionPtr clone() const;
// Called in ctors to validate expression attributes
virtual void validate_attributes() const;

// used in clone_with_new_inputs. New output port descriptors were inited automatically
void update_port_attributes(const std::shared_ptr<Node>& new_node, const std::vector<PortConnectorPtr>& new_inputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class BufferExpression : public Expression {
BufferExpression(const std::shared_ptr<Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);

ExpressionPtr clone() const override;
void validate_attributes() const override;

size_t m_allocation_size = utils::get_dynamic_value<size_t>();
size_t m_reg_group = 0;
Expand Down
10 changes: 2 additions & 8 deletions src/common/snippets/src/lowered/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ Expression::Expression(const std::shared_ptr<Node>& n, const std::shared_ptr<ISh
for (const auto& output : n->outputs()) {
m_output_port_descriptors.push_back(PortDescriptorUtils::get_port_descriptor_ptr(output));
}
validate_attributes();
}

Expression::Expression(const Expression& other) :
Expand All @@ -40,7 +39,6 @@ Expression::Expression(const Expression& other) :
// input port connectors and input port descriptors - they must be consistent.
m_input_port_descriptors = {};
m_output_port_descriptors = {};
validate_attributes();
}

const PortConnectorPtr& Expression::get_input_port_connector(size_t i) const {
Expand Down Expand Up @@ -95,13 +93,9 @@ void Expression::set_reg_info(const RegInfo& rinfo) {
}
}

void Expression::validate_attributes() const {
OPENVINO_ASSERT(m_source_node != nullptr,
"The expression has null source node");
}

void Expression::validate() const {
validate_attributes();
OPENVINO_ASSERT(m_source_node != nullptr,
"The expression has null source node");
OPENVINO_ASSERT(m_input_port_descriptors.size() == m_input_port_connectors.size(),
"The count of input ports and input port connectors must be equal");
OPENVINO_ASSERT(m_output_port_descriptors.size() == m_output_port_connectors.size(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ ExpressionPtr BufferExpression::clone() const {
return std::shared_ptr<BufferExpression>(new BufferExpression(*this));
}

void BufferExpression::validate_attributes() const {
Expression::validate_attributes();
OPENVINO_ASSERT(ov::is_type<op::Buffer>(get_node()), "BufferExpression expects Buffer op");
}

bool BufferExpression::visit_attributes(AttributeVisitor &visitor) {
auto allocation_size = utils::value2str(m_allocation_size);
auto offset = utils::value2str(m_offset);
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/op/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ void Buffer::IntermediateMemoryImpl::validate_and_infer_types(Buffer* buffer) co
OPENVINO_ASSERT(buffer, "Buffer is missed");
OPENVINO_ASSERT(buffer->get_input_size() != 0, "IntermediateMemory Buffer must have inputs");
const auto inputs = buffer->input_values();
const auto inshape = buffer->get_input_partial_shape(0);
const auto intype = buffer->get_input_element_type(0);
const auto& inshape = buffer->get_input_partial_shape(0);
const auto& intype = buffer->get_input_element_type(0);
OPENVINO_ASSERT(std::all_of(inputs.cbegin() + 1, inputs.cend(),
[&](const ov::Output<ov::Node>& in) { return in.get_partial_shape() == inshape && in.get_element_type() == intype; }),
"All inputs of Buffers must have the same shape and element type");
Expand Down

0 comments on commit 0432d44

Please sign in to comment.