Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Snippets] Introduced BufferExpression #26413

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions src/common/snippets/include/snippets/lowered/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ namespace ov {
namespace snippets {
namespace lowered {

class ExpressionFactory;
class LinearIR;
using ExpressionPtr = std::shared_ptr<Expression>;
using ExpressionMap = std::unordered_map<Expression*, ExpressionPtr>;
class Expression : public std::enable_shared_from_this<Expression> {
friend class LinearIR;
friend class ExpressionFactory;
friend class ExpressionPort;

public:
Expression() = default;
virtual ~Expression() = default;

std::shared_ptr<Node> get_node() const;
std::shared_ptr<Emitter> get_emitter() const;
Expand All @@ -50,7 +53,8 @@ class Expression : public std::enable_shared_from_this<Expression> {

void set_input_port_connector(size_t port, PortConnectorPtr to);

void validate() const;
// Attention! Cannot be called in ctor because this method validates port attributes (descs, connectors)
virtual void validate() const;

ExpressionPort get_input_port(size_t i);
ExpressionPort get_output_port(size_t i);
Expand All @@ -61,16 +65,52 @@ class Expression : public std::enable_shared_from_this<Expression> {
bool needShapeInfer() const { return m_need_shape_infer; }
const std::vector<size_t>& get_loop_ids() const;
void set_loop_ids(const std::vector<size_t>& loops);
ExpressionPtr clone_with_new_inputs(const std::vector<PortConnectorPtr>& new_inputs,
const std::shared_ptr<Node>& new_node) const;

/**
* @brief Clone Expression with new node and input port attributes.
* Output port descriptors will be cloned from the current expression.
* Output port connectors will be created.
* @param new_node new node
* @param new_inputs new input port connectors
* @param new_in_descs new input port descriptors. If this collection is empty,
* descriptors will be copied from the current expression
* @return the copy
*/
ExpressionPtr clone_with_new_inputs(const std::shared_ptr<Node>& new_node, const std::vector<PortConnectorPtr>& new_inputs,
const std::vector<PortDescriptorPtr>& new_in_descs = {}) const;
/**
* @brief Clone Expression with new node using `expr_map` to connect to new parent expressions.
* @param expr_map the map with the original and cloned expressions
* @param new_node new node
* @return the copy
*/
ExpressionPtr clone_with_new_inputs(const ExpressionMap& expr_map, const std::shared_ptr<Node>& new_node) const;

virtual bool visit_attributes(AttributeVisitor &visitor);

// Note that get_type_info_static and get_type_info are needed to mimic OPENVINO_RTTI interface,
// so the standard OPENVINO_RTTI(...) macros could be used in derived classes.
_OPENVINO_HIDDEN_METHOD static const ::ov::DiscreteTypeInfo& get_type_info_static() {
static ::ov::DiscreteTypeInfo type_info_static {"Expression"};
type_info_static.hash();
return type_info_static;
}

virtual const DiscreteTypeInfo& get_type_info() const {
return get_type_info_static();
}

const char* get_type_name() const {
return get_type_info().name;
}

protected:
Expression(const Expression& other);
// Note: The constructor initialization is private since an expression can be created only by Linear IR.
// The method must be used only by Linear IR builder of expressions!
Expression(const std::shared_ptr<Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& factory, bool need_shape_infer = true);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Open question for discussion: can we move this ctor to public section to avoid writing friend class ExpressionFactory; in each SpecialExpression 🤔

void update_node_and_connectors(const std::vector<PortConnectorPtr>& new_inputs, const std::shared_ptr<Node>& new_node);

// Virtual clone method which is called in clone_with_new_inputs with common logic
virtual ExpressionPtr clone() const;

std::shared_ptr<Node> m_source_node{nullptr};
std::shared_ptr<Emitter> m_emitter{nullptr};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,65 +4,72 @@

#pragma once

#include "linear_ir.hpp"
#include "expression.hpp"
#include "expressions/buffer_expression.hpp"

#include "snippets/snippets_isa.hpp"
#include "snippets/op/loop.hpp"
#include "snippets/op/buffer.hpp"
#include "snippets/op/perf_count.hpp"

namespace ov {
namespace snippets {
namespace lowered {

class LinearIR::ExpressionFactory {
class ExpressionFactory {
public:
template<class... Args>
static ExpressionPtr build(const std::shared_ptr<Node>& n, Args&&... params) {
if (const auto par = ov::as_type_ptr<ov::op::v0::Parameter>(n)) {
return create(par, params...);
} else if (const auto res = ov::as_type_ptr<ov::op::v0::Result>(n)) {
return create(res, params...);
} else if (const auto loop_begin = ov::as_type_ptr<op::LoopBegin>(n)) {
return create(loop_begin, params...);
} else if (const auto loop_end = ov::as_type_ptr<op::LoopEnd>(n)) {
return create(loop_end, params...);
#ifdef SNIPPETS_DEBUG_CAPS
} else if (const auto perf_counter = ov::as_type_ptr<op::PerfCountBeginBase>(n)) {
return create(perf_counter, params...);
} else if (const auto perf_counter = ov::as_type_ptr<op::PerfCountEndBase>(n)) {
return create(perf_counter, params...);
#endif
}
return create(n, params...);
ExpressionFactory(std::shared_ptr<IShapeInferSnippetsFactory> shape_infer_factory)
: m_shape_infer_factory(std::move(shape_infer_factory)) {}

template <typename T = Expression, typename... Args,
v-Golubev marked this conversation as resolved.
Show resolved Hide resolved
typename std::enable_if<std::is_base_of<Expression, T>::value, bool>::type = true>
std::shared_ptr<T> build(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs, Args... args) {
return create<T>(n, inputs, m_shape_infer_factory, args...);
}

private:
/* -- Default Builders - initialize input port connectors from parents and create new output port connectors themselves */
static ExpressionPtr create(const std::shared_ptr<ov::op::v0::Parameter>& par, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<ov::op::v0::Result>& res, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<ov::Node>& n, const LinearIR& linear_ir);

/* -- Input Builders - get input port connectors from method parameters and create new output port connectors themselves */
static ExpressionPtr create(const std::shared_ptr<op::LoopBegin>& n, const std::vector<PortConnectorPtr>& inputs, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<op::LoopEnd>& n, const std::vector<PortConnectorPtr>& inputs, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<ov::Node>& n, const std::vector<PortConnectorPtr>& inputs, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<ov::op::v0::Parameter>& par, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create(const std::shared_ptr<ov::op::v0::Result>& res, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create(const std::shared_ptr<op::LoopBegin>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create(const std::shared_ptr<op::LoopEnd>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);

// Note: PerfCountBegin nodes have a PerfCountEnd ov::Output, but corresponding expression should not have any outputs to avoid register allocation
#ifdef SNIPPETS_DEBUG_CAPS
static ExpressionPtr create(const std::shared_ptr<op::PerfCountBeginBase>& n,
const std::vector<PortConnectorPtr>& inputs,
const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<op::PerfCountEndBase>& n,
const std::vector<PortConnectorPtr>& inputs,
const LinearIR& linear_ir);
static ExpressionPtr create_without_connections(const std::shared_ptr<ov::Node>& n, const LinearIR& linear_ir);
static ExpressionPtr create(const std::shared_ptr<op::PerfCountBeginBase>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create(const std::shared_ptr<op::PerfCountEndBase>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create_without_connections(const std::shared_ptr<ov::Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
#endif

// Creates inputs for expression using parent output port connectors
static void create_expression_inputs(const LinearIR& linear_ir, const ExpressionPtr& expr);
template <typename T = Expression, typename... Args,
typename std::enable_if<std::is_base_of<Expression, T>::value, bool>::type = true>
static std::shared_ptr<T> create(const std::shared_ptr<ov::Node>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory, Args... args) {
auto expr = std::shared_ptr<T>(new T(n, shape_infer_factory, args...));
init_expression_inputs(expr, inputs);
create_expression_outputs(expr);
expr->validate();
IvanNovoselov marked this conversation as resolved.
Show resolved Hide resolved
// todo: here we blindly synchronize input shapes from parent and child. Remove this when shapes are stored in the port connector itself
if (shape_infer_factory)
expr->updateShapes();
return expr;
}

// Creates new output port connectors
static void create_expression_outputs(const ExpressionPtr& expr);
// The method checks that the expression is registered as a consumer of each input port connector and adds it if it is missing
static void init_expression_inputs(const ExpressionPtr& expr, const std::vector<PortConnectorPtr>& inputs);

const std::shared_ptr<IShapeInferSnippetsFactory> m_shape_infer_factory = nullptr;
};
using ExpressionFactoryPtr = std::shared_ptr<ExpressionFactory>;

template<>
std::shared_ptr<Expression> ExpressionFactory::build(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs);

} // namespace lowered
} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/expression.hpp"

#include "snippets/utils/utils.hpp"


namespace ov {
namespace snippets {
namespace lowered {

// To avoid a cyclic dependency between includes, we forward-declare LoopManager
class LoopManager;
/**
* @interface BufferExpression
* @brief This is a base class for memory storage.
* Note that Buffer should be a single consumer for operation output port
* @param m_allocation_size - memory size for allocation in bytes. Dynamic value means undefined size.
* @param m_offset - offset in common Buffer scratchpad
* @param m_reg_group - number of register group. The Buffers from the same register group will have the same GPR
* @param m_cluster_id - number of cluster. The Buffers from the same cluster shares memory between them and will have the same offset.
* @ingroup snippets
*/
class BufferExpression : public Expression {
// ExpressionFactory is a friend so that it can reach the protected constructor declared below.
friend class ExpressionFactory;
public:
// Declares the RTTI entry for this type with Expression as the parent type.
OPENVINO_RTTI("BufferExpression", "0", Expression)
BufferExpression() = default;

// Overrides Expression::visit_attributes; the definition is not in this header.
bool visit_attributes(AttributeVisitor &visitor) override;

// Plain read accessors for the fields declared at the bottom of the class.
size_t get_reg_group() const { return m_reg_group; }
size_t get_cluster_id() const { return m_cluster_id; }
size_t get_offset() const { return m_offset; }
size_t get_allocation_size() const { return m_allocation_size; }
// Defined out of line.
// NOTE(review): presumably the size of the buffer in bytes - confirm against the implementation.
size_t get_byte_size() const;

// Plain write accessors: each one stores its argument into the matching field without any validation.
void set_reg_group(size_t reg_group) { m_reg_group = reg_group; }
void set_cluster_id(size_t cluster) { m_cluster_id = cluster; }
void set_allocation_size(size_t size) { m_allocation_size = size; }
void set_offset(size_t offset) { m_offset = offset; }

// Virtual, so derived buffer expressions may supply their own implementation.
// NOTE(review): presumably computes and stores the allocation size using the loop information
// and `allocation_rank` - the definition is not visible here, confirm against the implementation.
virtual void init_allocation_size(const std::shared_ptr<LoopManager>& loop_manager, size_t allocation_rank);

// Returns true if the allocation size is known; returns false if the allocation size is undefined
bool is_defined() const;

// Returns true if the memory is independent, i.e. the expression has no inputs (no parent/source expressions)
bool is_independent_memory() const { return get_input_count() == 0; }

protected:
// Protected: not callable by arbitrary code, only by this class, derived classes and friends.
BufferExpression(const std::shared_ptr<Node>& n, const std::shared_ptr<IShapeInferSnippetsFactory>& factory);

// Overrides the virtual Expression::clone hook.
ExpressionPtr clone() const override;

// Starts as the dynamic value, i.e. the allocation size is not defined yet.
size_t m_allocation_size = utils::get_dynamic_value<size_t>();
// Register group number; defaults to 0.
size_t m_reg_group = 0;
// Cluster number; defaults to 0.
size_t m_cluster_id = 0;
// Starts as the dynamic value until an offset is assigned through set_offset().
size_t m_offset = utils::get_dynamic_value<size_t>();
};
using BufferExpressionPtr = std::shared_ptr<BufferExpression>;

} // namespace lowered
} // namespace snippets
} // namespace ov
47 changes: 33 additions & 14 deletions src/common/snippets/include/snippets/lowered/linear_ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <list>

#include "snippets/lowered/expression.hpp"
#include "snippets/lowered/expression_factory.hpp"
#include "snippets/lowered/expressions/buffer_expression.hpp"
#include "snippets/target_machine.hpp"
#include "snippets/shape_inference/shape_inference.hpp"
#ifdef SNIPPETS_DEBUG_CAPS
Expand Down Expand Up @@ -51,7 +53,6 @@ using LoopManagerPtr = std::shared_ptr<LoopManager>;
*/
class LinearIR {
friend class LinearIRBuilder;
class ExpressionFactory;
public:
using container = std::list<ExpressionPtr>;
using exprIt = container::iterator;
Expand All @@ -62,12 +63,12 @@ class LinearIR {
LinearIR(Config config = {}, const std::shared_ptr<IShapeInferSnippetsFactory>& factory = {});
LinearIR(const std::shared_ptr<ov::Model>& m, const std::shared_ptr<IShapeInferSnippetsFactory>& factory, Config config = {});

ExpressionPtr create_expression(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& inputs) const;
const ExpressionFactoryPtr& get_expr_factory() const;

const container& get_ops() const { return m_expressions; }
const container& get_buffers() const { return m_buffer_expressions; }
const container& get_parameters() const { return m_parameter_expressions; }
const container& get_results() const { return m_result_expressions; }
const std::vector<ExpressionPtr>& get_parameters() const { return m_parameter_expressions; }
const std::vector<ExpressionPtr>& get_results() const { return m_result_expressions; }
const std::vector<BufferExpressionPtr>& get_buffers() const { return m_buffer_expressions; }
const Config& get_config() const { return m_config; }
size_t get_static_buffer_scratchpad_size() const { return m_static_buffer_scratchpad_size; }

Expand Down Expand Up @@ -186,6 +187,20 @@ class LinearIR {
return std::make_pair(expr_it, node);
}

/**
* @brief Insert new Expression to LinearIR, sets `loop_ids` as loop identifiers and inserts the expression on the `place` in LinearIR.
* Also connects output ports to `consumers`
* @param new_expr the target expr which was created by ExpressionFactory
* @param loop_ids vector of loops ids that will be set for the expression
* @param update_loop_ports true - the helper updates the corresponding loop ports after insertion, otherwise - skip
* @param place before this place expression will be inserted
* @param consumers vector of expression port sets. These expression ports will be consumers of the expression.
* The vector may be empty or size of vector must be equal to output port count
* @return new expression iterator in LinearIR
*/
exprIt insert_expr(const ExpressionPtr& new_expr, const std::vector<size_t>& loop_ids,
bool update_loop_ports, const constExprIt& place, const std::vector<std::set<ExpressionPort>>& consumers);

/**
* @brief Replace the several existing expressions with the one new expression that contains `new_node`.
* Calls the helper `insert_node` and performs substitution: removes `old_exprs`.
Expand Down Expand Up @@ -248,21 +263,22 @@ class LinearIR {
private:
class LIRShapeInfer : public ShapeInferSnippetsNode {
public:
explicit LIRShapeInfer(const container& body_exprs, const container& param_exprs, const container& result_exprs);
explicit LIRShapeInfer(const container& body_exprs, const std::vector<ExpressionPtr>& param_exprs, const std::vector<ExpressionPtr>& result_exprs);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;

private:
const container& m_exprs;
const container& m_input_exprs;
const container& m_output_exprs;
const std::vector<ExpressionPtr>& m_input_exprs;
const std::vector<ExpressionPtr>& m_output_exprs;
};

static ov::NodeVector get_ordered_ops(const std::shared_ptr<ov::Model>& model);
// Default way: expr port connectors are constructed basing on ov::Node connection
ExpressionPtr create_expression(const std::shared_ptr<Node>& n);
ExpressionPtr create_expression(const std::shared_ptr<Node>& n, const std::vector<PortConnectorPtr>& new_inputs,
const std::vector<size_t>& loop_ids, bool update_loop_ports, const std::vector<std::set<ExpressionPort>>& consumers = {});

// Creates inputs for expression using parent output port connectors
std::vector<PortConnectorPtr> get_expression_inputs_by_node(const std::shared_ptr<Node>& n) const;

void register_expression(const ExpressionPtr& expr, bool io_allowed, double exec_num);
void unregister_expression(const ExpressionPtr& expr);

Expand All @@ -271,13 +287,16 @@ class LinearIR {

container m_expressions{};
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Expression>> m_node2expression_map;
container m_parameter_expressions{};
container m_result_expressions{};
container m_buffer_expressions{};
// Note: Parameters and Results are stored in the order of Subgraph inputs/outputs
std::vector<ExpressionPtr> m_parameter_expressions{};
std::vector<ExpressionPtr> m_result_expressions{};
// Note: BufferExpressions are not stored in the order of execution numbers
std::vector<BufferExpressionPtr> m_buffer_expressions{};
Config m_config{};
LoopManagerPtr m_loop_manager;
std::shared_ptr<IShapeInferSnippetsFactory> m_shape_infer_factory;
std::shared_ptr<IShapeInferSnippetsFactory> m_shape_infer_factory = nullptr;
std::shared_ptr<ShapeInferSnippetsNode> m_shape_infer = nullptr;
std::shared_ptr<ExpressionFactory> m_expression_factory = nullptr;
bool m_is_dynamic = false;

// Size of static Buffer Scratchpad (Buffers with defined allocation size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,9 @@ namespace pass {
class ComputeBufferAllocationSize : public RangedPass {
public:
OPENVINO_RTTI("ComputeBufferAllocationSize", "RangedPass")
ComputeBufferAllocationSize(size_t buffer_allocation_rank) : m_buffer_allocation_rank(buffer_allocation_rank) {}
ComputeBufferAllocationSize() = default;

bool run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) override;

static size_t get_allocation_size(const LoopManagerPtr& loop_manager, const ExpressionPtr& buffer_expr, size_t allocation_rank);

private:
const size_t m_buffer_allocation_rank = 0;
};

} // namespace pass
Expand Down
Loading
Loading