Skip to content

Commit

Permalink
FCTokenization pass. TODO: enable tokenization in product code = reve…
Browse files Browse the repository at this point in the history
…rt changes in transformation pipeline
  • Loading branch information
v-Golubev committed Sep 11, 2024
1 parent fa84133 commit 39805e0
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 7 deletions.
27 changes: 27 additions & 0 deletions src/common/snippets/include/snippets/pass/fc_tokenization.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
namespace pass {

/**
* @interface TokenizeFCSnippets
* @brief Tokenizes FullyConnected-like MatMuls (MatMuls with a constant path on the B input)
*        by wrapping each matched MatMul into a snippets Subgraph.
*        The current matcher only accepts an f32 Constant connected directly to the B input.
* @ingroup snippets
*/
class TokenizeFCSnippets: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TokenizeFCSnippets", "0");
/**
* @brief Registers the matcher for MatMul with constant weights together with its tokenization callback
* @param config tokenization configuration, the same object that is passed to the other
*        snippets tokenization passes by SnippetsTokenization
*/
TokenizeFCSnippets(const SnippetsTokenization::Config& config);
};

} // namespace pass
} // namespace snippets
} // namespace ov
35 changes: 35 additions & 0 deletions src/common/snippets/src/pass/fc_tokenization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/pass/fc_tokenization.hpp"

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "snippets/itt.hpp"
#include "snippets/op/subgraph.hpp"

ov::snippets::pass::TokenizeFCSnippets::TokenizeFCSnippets(const SnippetsTokenization::Config& config) {
    MATCHER_SCOPE(TokenizeFCSnippets);
    namespace pattern = ov::pass::pattern;

    // Pattern: MatMul(any activation, f32 Constant weights).
    // TODO: extend constant path coverage:
    // 1. Add u8/i8/bf16 precisions
    // 2. Add subgraphs (Transpose/Convert)
    // 3. Add Decompression subgraphs support (and all the possible compressed weights related precisions)
    auto weights_pattern = pattern::wrap_type<ov::op::v0::Constant>(pattern::type_matches(ov::element::f32));
    auto matmul_pattern = pattern::wrap_type<ov::opset1::MatMul>({pattern::any_input(), weights_pattern});

    // Replaces the matched MatMul with a Subgraph that wraps it.
    auto tokenize_matmul = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::TokenizeFCSnippets")
        const auto matmul_node = m.get_pattern_value_map().at(matmul_pattern).get_node_shared_ptr();
        const auto wrapped = op::Subgraph::wrap_node_as_subgraph(matmul_node);

        // Keep track of the original layer name on the new Subgraph node
        auto& subgraph_rt_info = wrapped->get_rt_info();
        subgraph_rt_info["originalLayersNames"] = matmul_node->get_friendly_name();

        // MatMul weights are stored outside the subgraph
        wrapped->set_virtual_port_count(1);
        op::update_out_tensor_name(wrapped);
        ov::replace_node(matmul_node, wrapped);
        return true;
    };

    const auto fc_matcher = std::make_shared<pattern::Matcher>(matmul_pattern, matcher_name);
    register_matcher(fc_matcher, tokenize_matmul);
}
21 changes: 14 additions & 7 deletions src/common/snippets/src/pass/tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/itt.hpp"
#include "snippets/pass/tokenization.hpp"

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/manager.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/itt.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/pass/common_optimizations.hpp"
#include "snippets/pass/extract_reshapes_from_mha.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/fc_tokenization.hpp"
#include "snippets/pass/gn_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"

#include "snippets/pass/mha_tokenization.hpp"

namespace ov {
namespace snippets {
Expand Down Expand Up @@ -81,9 +82,15 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {

manager.register_pass<EnumerateNodes>();
manager.register_pass<ExtractReshapesFromMHA>();
// This pass mustn't be registered in GraphRewrite with other tokenization passes
// since it changes the nodes after the matched root node
manager.register_pass<TokenizeMHASnippets>(m_config);
manager.register_pass<TokenizeGNSnippets>();
manager.register_pass<TokenizeSnippets>(m_config);

auto tokenization_passes = manager.register_pass<ov::pass::GraphRewrite>();
tokenization_passes->add_matcher<TokenizeGNSnippets>();
tokenization_passes->add_matcher<TokenizeFCSnippets>(m_config);
tokenization_passes->add_matcher<TokenizeSnippets>(m_config);

manager.register_pass<CommonOptimizations>(m_config);
manager.run_passes(m);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@
// Snippets
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/fc_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/pass/common_optimizations.hpp"
#include "snippets/pass/split_dimension_m.hpp"
Expand Down Expand Up @@ -928,6 +929,8 @@ void Transformations::MainSnippets(void) {
CPU_REGISTER_PASS_ARM(snippetsManager, SnippetsMarkSkipped);
#else
CPU_REGISTER_PASS_X64(snippetsManager, SnippetsMarkSkipped, inferencePrecision == ov::element::bf16);
// TODO: remove this line once FC tokenization is enabled in product code
CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeFCSnippets);
#endif
}
CPU_REGISTER_PASS_X64(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);
Expand Down

0 comments on commit 39805e0

Please sign in to comment.