[Snippets] Support Fully Connected tokenization #26498

Merged
27 changes: 27 additions & 0 deletions src/common/snippets/include/snippets/pass/fc_tokenization.hpp
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface TokenizeFCSnippets
* @brief The pass tokenizes FullyConnected-like MatMuls (i.e. MatMuls with a constant path on the B input)
* @ingroup snippets
*/
class TokenizeFCSnippets: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TokenizeFCSnippets", "0");
TokenizeFCSnippets(const SnippetsTokenization::Config& config);
};

} // namespace pass
} // namespace snippets
} // namespace ov
28 changes: 28 additions & 0 deletions src/common/snippets/include/snippets/utils/tokenization_utils.hpp
@@ -0,0 +1,28 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

/**
* @brief A file that contains tokenization-related utilities.
* @file tokenization_utils.hpp
*/
#pragma once

#include "snippets/op/subgraph.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
namespace utils {
/**
* @brief Tokenizes a node into a Subgraph. Two options are possible (depending on the config values and internal logic):
* 1. The node is wrapped in a trivial Subgraph which contains only this node
* 2. The node is fused into a parent's Subgraph
* @param node the node which should be tokenized
* @param config tokenization config which regulates the tokenization behavior
* @return whether the node was tokenized or not
*/
bool tokenize_node(const std::shared_ptr<ov::Node>& node, const ov::snippets::pass::SnippetsTokenization::Config& config);
Contributor
Probably we need to create a static public method in the "TokenizeSnippets" pass instead of creating a new file with one function. What do you think?

Contributor Author

I decided to move this into a separate file because

  1. This is a common helper which can be reused in numerous tokenization passes in the future
  2. The helper is large, so it's a bit easier to have it in a separate file from a readability perspective (At least that's how it seems to me 😄 )
  3. This file may be extended with a bool tokenize_nodes(const ov::NodeVector& nodes) helper which tokenizes a sequence of nodes. Currently this is a part of the TokenizeMHASnippets pass, but this code can be easily extracted and reused (e.g. for MLP tokenization)

Contributor

We can also move AppropriateForSubgraph here 😁
In general, whether we want it or not, we are moving from one universal tokenization pass towards several transformations responsible for particular pattern families (eltwise/MHA/MLP).
From this perspective, this attempt to derive and reuse some common logic seems like a step in the right direction.

Contributor

I think we should go even further and create a dedicated class for tokenization (SnippetsNodeTokenizer or smth). This class should have 3 methods: private bool check_necessary(node), public virtual bool check_sufficient(node), and public bool fuse(node).

  1. check_necessary will perform all the fundamental plugin-independent checks (like cyclic dependencies, num_results_children etc) and can potentially also gather some aggregated info (internal/external inputs, body parameters etc). If possible, we should separate the check from the aggregated info gathering (another method, analyze).
  2. check_sufficient will be overridden by the plugin and will contain backend-specific checks, possibly partially based on the gathered aggregated info (hidden_data_count, unique_buffer_count etc).
  3. Finally, fuse will check if (check_necessary() && check_sufficient()) and create a subgraph based on the aggregated info.
    So the idea is that the plugin creates its own instance of this SnippetsNodeTokenizer and passes it to all the tokenization transformations. And the transformations can use this class to do all the dirty work, plus they can impose their own limitations of course. For example, the eltwise tokenization pass will additionally check op types and precisions.

@a-sidorova, @v-Golubev what do you think? I'm not saying we should refactor it now, but let's try to outline a more convenient and scalable design.
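For illustration, a rough sketch of the interface proposed above (all names are hypothetical; nothing below exists in the codebase yet):

    #include <memory>
    #include "openvino/core/node.hpp"

    // Sketch only: one possible shape for the proposed tokenizer class.
    class SnippetsNodeTokenizer {
    public:
        virtual ~SnippetsNodeTokenizer() = default;

        // Backend-specific checks, overridden by the plugin; may rely on
        // aggregated info gathered during the necessary checks
        // (hidden_data_count, unique_buffer_count, ...).
        virtual bool check_sufficient(const std::shared_ptr<ov::Node>& node) { return true; }

        // Runs both groups of checks and, on success, creates or extends
        // a Subgraph based on the aggregated info.
        bool fuse(const std::shared_ptr<ov::Node>& node) {
            return check_necessary(node) && check_sufficient(node) && do_fuse(node);
        }

    private:
        // Fundamental plugin-independent checks: cyclic dependencies,
        // number of result children, etc.
        bool check_necessary(const std::shared_ptr<ov::Node>& node);
        // Actual Subgraph creation/fusion based on the gathered info.
        bool do_fuse(const std::shared_ptr<ov::Node>& node);
    };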

Contributor Author

@IvanNovoselov your idea looks interesting, but I need some time to think about whether the proposed architecture is sufficient. Can we allocate a time slot in the next team meeting to discuss it?

Contributor

Seems like we really need to discuss how the current tokenization logic can be refactored and improved.

Contributor

Yes, sure, let's discuss it on the sync.

} // namespace utils
} // namespace snippets
} // namespace ov
432 changes: 5 additions & 427 deletions src/common/snippets/src/pass/collapse_subgraph.cpp

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions src/common/snippets/src/pass/fc_tokenization.cpp
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/pass/fc_tokenization.hpp"

#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils/tokenization_utils.hpp"

ov::snippets::pass::TokenizeFCSnippets::TokenizeFCSnippets(const SnippetsTokenization::Config& config) {
MATCHER_SCOPE(TokenizeFCSnippets);
// TODO: extend constant path coverage
// Ticket: 153480
auto constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto m_matmul = ov::pass::pattern::wrap_type<ov::opset1::MatMul>({ov::pass::pattern::any_input(), constant});

auto callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher &m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::TokenizeFCSnippets")
const auto matmul = m.get_match_root();
if (transformation_callback(matmul)) {
return false;
}
return ov::snippets::utils::tokenize_node(matmul, config);
Contributor

Do we really need to extract the common tokenization logic into utilities just to tokenize one op?
Can we use the following code here?

        auto subgraph = op::Subgraph::wrap_node_as_subgraph(matmul);
        subgraph->get_rt_info()["originalLayersNames"] = matmul->get_friendly_name();
        ov::replace_node(matmul, subgraph);
        op::update_out_tensor_name(subgraph);

Probably I missed something 🤔 Could you please elaborate on your decision to use ov::snippets::utils::tokenize_node?

My only thought is the connection to the existing Subgraphs on the first input 🤔

Contributor

But why not?
It promotes code reuse and thus simplifies maintenance. We would also be able to extend this pass to cover MLP patterns more easily.
I agree that the tokenize_node itself is quite complex, but that's because we have a bunch of checks there (like the ones related to Converts or num registers) that shouldn't really be in this function.

Contributor Author

My only thought is the connection to the existing Subgraphs on the first input

You are right, that is one of the reasons why I reused the helper for FC tokenization. Thanks to that, we cover e.g. "Transposes on inputs + FC" cases with tests

Also, moving the tokenization logic into a separate helper is a good prerequisite for the tokenization refactoring. Currently, the tokenization looks as follows (I trimmed the code a bit):

    auto tokenization_passes = manager.register_pass<ov::pass::GraphRewrite>();
    tokenization_passes->add_matcher<TokenizeGNSnippets>();
    tokenization_passes->add_matcher<TokenizeFCSnippets>(m_config);
    tokenization_passes->add_matcher<TokenizeSnippets>(m_config);

TokenizeSnippets here is responsible for tokenization of the ops which were not tokenized by the previous passes. The pass contains a lot of checks related to different ops in one place (the is_supported_op lambda); as a result, the pass looks quite cumbersome.

Ideally, I would want to separate this large pass into a sequence of small matcher passes, each of which matches only specific ops and contains only the checks needed for those ops:

    auto tokenization_passes = manager.register_pass<ov::pass::GraphRewrite>();
    tokenization_passes->add_matcher<TokenizeGNSnippets>();
    tokenization_passes->add_matcher<TokenizeFCSnippets>(m_config);
    tokenization_passes->add_matcher<TokenizeUnaryEltwise>(m_config);
    tokenization_passes->add_matcher<TokenizeBinaryEltwise>(m_config);
    tokenization_passes->add_matcher<TokenizeTranspose>(m_config);
    tokenization_passes->add_matcher<TokenizeMatmul>(m_config);
    ...

I believe this will improve tokenization code readability, and could allow us to configure the pipeline more precisely (e.g. TokenizeTranspose and TokenizeMatmul are used only in tests, so they could be disabled by the plugin in production code).

And each of these new matchers can reuse ov::snippets::utils::tokenize_node -- that's another reason why I moved the tokenization code into a separate helper.
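As a hypothetical sketch of that direction (not part of this PR; TokenizeUnaryEltwise and the matched op set are illustrative assumptions), such a fine-grained pass could boil down to a pattern plus a call to the shared helper:

    #include "openvino/pass/pattern/op/wrap_type.hpp"
    #include "snippets/utils/tokenization_utils.hpp"

    ov::snippets::pass::TokenizeUnaryEltwise::TokenizeUnaryEltwise(const SnippetsTokenization::Config& config) {
        // Match only the unary eltwise ops this pass is responsible for.
        auto m_unary = ov::pass::pattern::wrap_type<ov::op::v0::Relu, ov::op::v0::Sigmoid>();
        auto callback = [config](ov::pass::pattern::Matcher& m) {
            // Generic checks and Subgraph creation/fusion live in the shared helper.
            return ov::snippets::utils::tokenize_node(m.get_match_root(), config);
        };
        register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_unary, "TokenizeUnaryEltwise"), callback);
    }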

Contributor

You are right, that is one of the reasons why I reused the helper for FC tokenization. Thanks to that, we cover e.g. "Transposes on inputs + FC" cases with tests

Cool! Thank you for the explanation! I just wanted to check that I fully got the goal of this PR 😊
Yeah, I saw some tests like FC->Eltwise->FC.

Ideally, I would want to separate this large pass into sequence of small matcher passes, which will match only on the specific ops, and contain only the checks which are needed for the matched ops

This is a really good idea! 🤔

};

auto matcher = std::make_shared<ov::pass::pattern::Matcher>(m_matmul, matcher_name);
register_matcher(matcher, callback);
}
22 changes: 15 additions & 7 deletions src/common/snippets/src/pass/tokenization.cpp
@@ -2,16 +2,17 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/itt.hpp"
#include "snippets/pass/tokenization.hpp"

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/manager.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/itt.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/pass/common_optimizations.hpp"
#include "snippets/pass/extract_reshapes_from_mha.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/fc_tokenization.hpp"
#include "snippets/pass/gn_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"

#include "snippets/pass/mha_tokenization.hpp"

namespace ov {
namespace snippets {
@@ -81,9 +82,16 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {

manager.register_pass<EnumerateNodes>();
manager.register_pass<ExtractReshapesFromMHA>();
// This pass mustn't be registered in GraphRewrite with other tokenization passes for 2 reasons:
// 1. It has higher priority than other tokenization passes
// 2. It changes the nodes after the matched root node
manager.register_pass<TokenizeMHASnippets>(m_config);
manager.register_pass<TokenizeGNSnippets>();
manager.register_pass<TokenizeSnippets>(m_config);

auto tokenization_passes = manager.register_pass<ov::pass::GraphRewrite>();
tokenization_passes->add_matcher<TokenizeGNSnippets>();
tokenization_passes->add_matcher<TokenizeFCSnippets>(m_config);
tokenization_passes->add_matcher<TokenizeSnippets>(m_config);

manager.register_pass<CommonOptimizations>(m_config);
manager.run_passes(m);

@@ -197,16 +197,9 @@ Result BrgemmShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
size_t max_rank = arg0_shape_tmp.size();
VectorDims output_shape(max_rank);
for (size_t i = 0; i < max_rank - 2; ++i) {
if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
output_shape[i] = arg0_shape_tmp[i];
} else {
if (arg0_shape_tmp[i] == 1 || utils::is_dynamic_value(arg0_shape_tmp[i]))
output_shape[i] = arg1_shape_tmp[i];
else if (arg1_shape_tmp[i] == 1 || utils::is_dynamic_value(arg1_shape_tmp[i]))
output_shape[i] = arg0_shape_tmp[i];
else
OPENVINO_THROW("Incompatible Brgemm batch dimension");
}
if (!utils::broadcast_merge_dim(output_shape[i], arg0_shape_tmp[i], arg1_shape_tmp[i]))
OPENVINO_THROW("Incompatible MatMul batch dimension. Can't merge dim ", arg0_shape_tmp[i],
" with dim ", arg1_shape_tmp[i], " at index=", i);
}
output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
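For reference, a simplified sketch of the broadcasting rule that the utils::broadcast_merge_dim call above encapsulates (an illustration only; the real helper also handles dynamic values):

    #include <cstddef>

    // Merge two batch dims numpy-style: equal dims pass through, a dim of 1
    // broadcasts to the other one, anything else is incompatible.
    bool merge_dim_sketch(size_t& dst, size_t d0, size_t d1) {
        if (d0 == d1 || d1 == 1) { dst = d0; return true; }
        if (d0 == 1) { dst = d1; return true; }
        return false;
    }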