Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xp/decompse matmul split or matmul gather #25196

Open
wants to merge 53 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6d67fc9
init
ceciliapeng2011 May 14, 2024
2541834
support fp16
ceciliapeng2011 May 14, 2024
a39065f
support fp16
ceciliapeng2011 May 14, 2024
361b151
Add chrome trace
usstq Nov 20, 2023
9e2855f
fix accuracy issue, matmul transpose_B
ceciliapeng2011 May 17, 2024
6a7f280
fork initial version from cecilia
xipingyan Jun 12, 2024
ccadbc5
remove profiler
xipingyan Jun 12, 2024
6fbaf4e
tmp version, phi verify pass.
xipingyan Jun 18, 2024
6ea2baf
refactor test code.
xipingyan Jun 19, 2024
25c5e17
Add test for MatMulSplit
xipingyan Jun 19, 2024
43dea54
Improve test.
xipingyan Jun 20, 2024
23665d0
Support without bias
xipingyan Jun 21, 2024
7153b31
B=1,L=1 can't match pattern, why?
xipingyan Jun 21, 2024
b37ca44
scalar input test pass.
xipingyan Jun 21, 2024
f0b38a7
Test pass, ready to review.
xipingyan Jun 22, 2024
350f8a8
dynamic shape should be got from partial shape.
xipingyan Jun 24, 2024
93c3e8a
compatible with int8 quantized model.
xipingyan Jun 25, 2024
57e8b56
fix replace fq node error.
xipingyan Jun 26, 2024
537bb39
clang format
xipingyan Jun 26, 2024
f35f8b3
Add enable_fq to test.
xipingyan Jun 26, 2024
218027b
clange issue.
xipingyan Jun 26, 2024
c7528d2
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Jul 4, 2024
efd55e7
remove debug code, and remove MatMulVariadicSplitDecomposition
xipingyan Jul 4, 2024
8760ced
fix CI clang issue.
xipingyan Jul 4, 2024
2e56d87
update based on comments
xipingyan Jul 5, 2024
3dc306e
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Jul 5, 2024
9290aa9
Fix transpose_b = false test fail.
xipingyan Jul 5, 2024
a154e26
try to fix CI arm test fail
xipingyan Jul 5, 2024
aefb31f
Baichun also match this pattern, but it only has rank 4, and have per…
xipingyan Jul 8, 2024
0393c50
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Jul 8, 2024
f574620
Updated based on comments.
xipingyan Jul 16, 2024
e0cc4eb
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Jul 16, 2024
eea39c7
Fix merge conflict.
xipingyan Jul 16, 2024
d094d70
Fix clang issue.
xipingyan Jul 16, 2024
8a45461
fix CI fail: because Diff: 1.75813e-05, so set abs_threshold = 1e-4
xipingyan Jul 16, 2024
1f68706
Merge remote-tracking branch 'origin/master' into xp/decompse_matmul_…
xipingyan Jul 22, 2024
54a5798
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Jul 23, 2024
461b708
Updated based on comments.
xipingyan Jul 23, 2024
53f9eca
Merge remote-tracking branch 'origin/master' into xp/decompse_matmul_…
xipingyan Jul 23, 2024
c2d71f4
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Jul 24, 2024
3dec73b
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Jul 29, 2024
e35702e
Updated based on comments.
xipingyan Aug 14, 2024
f24bbbf
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Aug 14, 2024
26f9314
Fix build error. Remove namespace CPUTestUtils::
xipingyan Aug 14, 2024
621a4dc
regist pass to CPU_REGISTER_PASS_COMMON, work for ARM
xipingyan Aug 15, 2024
c8a16f0
Merge commit '54f58b86' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Aug 26, 2024
a7b978a
Move "CheckNumberOfNodesWithType" to src/tests/test_utils/functional_…
xipingyan Aug 26, 2024
cbb4fc2
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Sep 10, 2024
4d5f03d
Fix merge master conflict issue.
xipingyan Sep 10, 2024
121e6cb
1:Move decompose_num private
xipingyan Sep 24, 2024
cb15e8e
Move "matmul_split_decomposition.cpp" to CPU:
xipingyan Sep 24, 2024
cde2f32
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Sep 25, 2024
6f9ebad
Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather
xipingyan Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "matmul_split_decomposition.hpp"

#include <transformations/utils/utils.hpp>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

using namespace ov;
using namespace ov::pass;
using namespace ov::pass::pattern;

bool intel_cpu::MatmulGatherDecomposition::split_weights(const Output<Node>& weights,
OutputVector& new_weights,
Output<Node>* bias,
OutputVector& new_bias,
const bool transpose_b) {
// weights is static
if (weights.get_partial_shape().size() != 2u) {
return false;
}

if (bias) {
const auto& bias_rank = bias->get_partial_shape().size();
if (bias_rank != 3 && bias_rank != 1) {
return false;
}
}

// Decompose weights
auto axis = register_new_node(op::v0::Constant::create(element::i32, Shape{}, {transpose_b ? 0 : 1}));
auto split = register_new_node<op::v1::Split>(weights, axis, decompose_num);
for (auto& out : split->outputs()) {
new_weights.emplace_back(out);
}

if (bias) {
// Decompose bias
auto axis2 = register_new_node(op::v0::Constant::create(element::i32, Shape{}, {-1})); // axis -1
auto split2 = register_new_node<op::v1::Split>(*bias, axis2, decompose_num);
for (auto& out : split2->outputs()) {
new_bias.emplace_back(out);
}
}
return true;
}

// Matcher pass: detects the subgraph
//   MatMul -> [Add(bias)] -> Reshape(rank 5) -> {Transpose(order [2,0,3,1,4]) | Reshape} -> 3 x Gather(axis 0)
// (each Gather optionally preceded by a single-output FakeQuantize) and rewrites it
// into decompose_num independent MatMul[+Add] -> Reshape -> Transpose branches,
// one per Gather, with weights/bias chunks produced by split_weights().
intel_cpu::MatmulGatherDecomposition::MatmulGatherDecomposition() {
MATCHER_SCOPE(MatmulGatherDecomposition);
auto input_pattern = any_input();
// MatMul whose weights input has a static shape; single consumer only.
auto matmul_pattern = wrap_type<op::v0::MatMul>({input_pattern, any_input(pattern::has_static_shape())},
ov::pass::pattern::consumers_count(1));

// Optional bias: Add(MatMul, Constant) with a single consumer.
auto bias_pattern = wrap_type<op::v0::Constant>();
auto add_pattern = wrap_type<op::v1::Add>({matmul_pattern, bias_pattern}, ov::pass::pattern::consumers_count(1));

// The first Reshape may consume either the MatMul directly or the bias Add.
auto reshape_productor_pattern = std::make_shared<pattern::op::Or>(OutputVector{matmul_pattern, add_pattern});

// Heuristic: require rank == 5. Baichuan (presumably the Baichuan LLM — confirm)
// also matches this pattern but only with rank 4, and showed a performance
// regression there, so rank-4 cases are filtered out.
auto reshape_pattern =
wrap_type<op::v1::Reshape>({reshape_productor_pattern, any_input()}, ov::pass::pattern::rank_equals(5));

// Heuristic: the root must have exactly decompose_num (3) consumers — the Gathers to split.
auto transpose_pattern =
wrap_type<op::v1::Transpose>({reshape_pattern, ov::pass::pattern::wrap_type<op::v0::Constant>()},
ov::pass::pattern::consumers_count(decompose_num));

auto reshape2_pattern =
wrap_type<op::v1::Reshape>({reshape_pattern, any_input()}, ov::pass::pattern::consumers_count(decompose_num));

// The pattern root is either a second Reshape or a Transpose over the rank-5 Reshape.
auto reshape_or_transpose_pattern =
std::make_shared<pattern::op::Or>(OutputVector{reshape2_pattern, transpose_pattern});

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
const auto root_node = m.get_match_root();

const auto matmul = pattern_map.at(matmul_pattern).get_node_shared_ptr();
const auto weights = matmul->input_value(1);
// Non-null only when the optional bias Add branch matched.
const std::shared_ptr<ov::Node> add =
pattern_map.count(add_pattern) ? pattern_map.at(add_pattern).get_node_shared_ptr() : nullptr;

// Only transpose_a == false is supported; transpose_b is carried over to the new MatMuls.
if (as_type_ptr<op::v0::MatMul>(matmul)->get_transpose_a()) {
return false;
}
const bool& transpose_b = as_type_ptr<op::v0::MatMul>(matmul)->get_transpose_b();
const auto& reshape = pattern_map.at(reshape_pattern);
// The original rank-5 Reshape target-shape input; reused below to derive the new shape.
const auto reshape_input1 = reshape.get_node_shared_ptr()->input_value(1);

// If the root is a Transpose, its order constant must be exactly [2, 0, 3, 1, 4].
if (pattern_map.count(transpose_pattern)) {
const auto transpose = pattern_map.at(transpose_pattern).get_node_shared_ptr();
const auto transpose_order = as_type_ptr<op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
if (transpose_order) {
const std::vector<int32_t> expected_val = {2, 0, 3, 1, 4};
if (expected_val != transpose_order->cast_vector<int32_t>()) {
return false;
}
} else {
return false;
}
}

// Collect the decompose_num Gather consumers of the root (each optionally behind a
// single-output FakeQuantize). Both vectors are indexed by the Gather's constant
// index value, so branch i of the rewrite corresponds to Gather index i.
NodeVector gathers, fake_quantizes;
gathers.resize(decompose_num);
fake_quantizes.resize(decompose_num);
for (const auto& child : root_node->get_output_target_inputs(0)) {
std::shared_ptr<ov::Node> fq = nullptr;
auto gather = child.get_node()->shared_from_this();
if (ov::is_type<op::v0::FakeQuantize>(gather)) {
// Look through the FakeQuantize to its (single) consumer, expected to be a Gather.
fq = gather;
if (fq->get_output_size() != 1u) {
return false;
}
gather = gather->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
if (ov::is_type<ov::op::util::GatherBase>(gather)) {
// The Gather must select along axis 0 (the axis produced by the [2,0,3,1,4] transpose).
const auto axis_node = as_type_ptr<op::v0::Constant>(gather->get_input_node_shared_ptr(2));
if (axis_node) {
const auto& axis_val = axis_node->cast_vector<int32_t>();
if (axis_val.size() != 1u || axis_val[0] != 0) {
return false;
}
} else {
return false;
}

// The Gather index must be a scalar constant in [0, 3); it determines the branch slot.
const auto indices_node = as_type_ptr<op::v0::Constant>(gather->get_input_node_shared_ptr(1));
if (indices_node) {
const auto& indices_val = indices_node->cast_vector<int32_t>();
if (indices_val.size() != 1u) {
return false;
}
if (indices_val[0] < 0 || indices_val[0] >= 3) {
return false;
}
gathers[indices_val[0]] = gather;
fake_quantizes[indices_val[0]] = fq;
} else {
return false;
}
} else {
return false;
}
}

// Every branch slot must be filled by a distinct Gather (duplicate indices would
// leave a null entry, which is rejected here).
if (std::any_of(gathers.begin(), gathers.end(), [](const std::shared_ptr<Node> node_ptr) {
return !node_ptr || !is_type<ov::op::util::GatherBase>(node_ptr);
})) {
return false;
}

Output<Node> bias;
OutputVector new_weights, new_bias;
if (add) {
bias = pattern_map.at(bias_pattern);
}

// Split weights (and bias, when present) into decompose_num chunks.
if (!split_weights(weights, new_weights, (add != nullptr) ? &bias : nullptr, new_bias, transpose_b)) {
return false;
}

if (new_weights.size() != decompose_num || ((add != nullptr) && new_bias.size() != decompose_num)) {
return false;
}

// Heuristic: the Gather consumed axis 2 of the rank-5 target shape (the split axis),
// so the per-branch Reshape's target shape is the original 5-element shape with
// element 2 dropped — built here by gathering indices {0, 1, 3, 4}.
const auto const_indices = register_new_node(op::v0::Constant::create(element::i32, Shape{4}, {0, 1, 3, 4}));
const auto const_axis = register_new_node(op::v0::Constant::create(element::i32, Shape{}, {0}));
const auto new_shape = register_new_node<op::v8::Gather>(reshape_input1, const_indices, const_axis);
const auto& input = pattern_map.at(input_pattern);
// Build one MatMul[+Add] -> Reshape -> Transpose branch per original Gather.
for (size_t i = 0; i < decompose_num; i++) {
const auto new_mm = register_new_node<op::v0::MatMul>(input, new_weights[i], false, transpose_b);
std::shared_ptr<ov::Node> reshape_productor = new_mm;
if (add) {
reshape_productor = register_new_node<op::v1::Add>(new_mm, new_bias[i]);
}
// special_zero = true, matching the semantics of the original Reshape's target shape.
const auto new_reshape = register_new_node<op::v1::Reshape>(reshape_productor, new_shape, true);
// Nodes whose runtime info is propagated onto the replacement subgraph.
ov::NodeVector from_nodes = {gathers[i], weights.get_node_shared_ptr(), matmul};
if (add) {
from_nodes.emplace_back(add);
from_nodes.emplace_back(pattern_map.at(bias_pattern).get_node_shared_ptr());
}
if (as_type<op::v1::Transpose>(root_node.get()))
from_nodes.emplace_back(root_node);

// Original transpose order was [2,0,3,1,4]; after removing the leading (split)
// axis the equivalent per-branch order is [0,2,1,3].
const auto transpose_order =
register_new_node(op::v0::Constant::create(element::i32, Shape{4}, {0, 2, 1, 3}));
const auto new_transpose = register_new_node<op::v1::Transpose>(new_reshape, transpose_order);
new_transpose->set_friendly_name(gathers[i]->get_friendly_name());
copy_runtime_info(from_nodes, get_new_nodes());

if (fake_quantizes[i]) {
// Keep the FakeQuantize, but re-parent it onto the new Transpose: it moves
// from before the Gather to the end of the rebuilt branch.
fake_quantizes[i]->set_argument(0, new_transpose);
replace_node(gathers[i], fake_quantizes[i]);
} else {
replace_node(gathers[i], new_transpose);
}
}
return true;
};

auto m = std::make_shared<pattern::Matcher>(reshape_or_transpose_pattern, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <openvino/pass/graph_rewrite.hpp>

namespace ov {
namespace intel_cpu {

/**
* @brief MatmulGatherDecomposition transformation matches the following graph:
*
* +----------+
* | input |
* +----------+
* |
* v
* +----------+
* | MatMul |
* +----------+
* |
* v
* +------------+
* | Some nodes |
* +------------+
* |
* v
* +-----------------------+
* | Transpose |
* +-----------------------+
* | | |
* v v v
* +-------+ +-------+ +-------+
* |Gather | |Gather | |Gather |
* +-------+ +-------+ +-------+
* and replaces with:
*
* +-----------------------+
* | input |
* +-----------------------+
* | | |
* v v v
* +-------+ +-------+ +-------+
* |MatMul | |MatMul | |MatMul |
* +-------+ +-------+ +-------+
* | | |
* v v v
* +-------+ +-------+ +-------+
* |Nodes | |Nodes | |Nodes |
* +-------+ +-------+ +-------+
* | | |
* v v v
* +---------+ +---------+ +---------+
* |Transpose| |Transpose| |Transpose|
* +---------+ +---------+ +---------+
*/

class MatmulGatherDecomposition : public pass::MatcherPass {
public:
OPENVINO_RTTI("MatmulGatherDecomposition", "0");
// Registers the MatMul->Reshape->{Transpose|Reshape}->3xGather matcher and its rewrite callback.
MatmulGatherDecomposition();
// Splits the 2D `weights` (and, when `bias` is non-null, the bias constant) into
// `decompose_num` equal chunks via Split nodes; the chunks are appended to
// `new_weights` / `new_bias`. `transpose_b` selects the weights split axis
// (0 when B is transposed, 1 otherwise). Returns false if decomposition is not possible.
bool split_weights(const Output<Node>& weights,
OutputVector& new_weights,
Output<Node>* bias,
OutputVector& new_bias,
const bool transpose_b);

private:
// Number of branches the matched subgraph is decomposed into — presumably the
// three Q/K/V projections of an attention block; confirm against matched models.
const size_t decompose_num = 3;
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
#include "transformations/cpu_opset/common/pass/decompose_rms_norm.hpp"
#include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp"
#include "transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp"
#include "transformations/cpu_opset/common/pass/matmul_split_decomposition.hpp"
#include "transformations/cpu_opset/common/pass/ngram_fusion.hpp"
#include "transformations/cpu_opset/common/pass/permute_slice_n_interpolation.hpp"
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
Expand Down Expand Up @@ -355,6 +356,11 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis

ov::pass::Manager manager("Plugin:CPU");
manager.set_per_pass_validation(false);

// Decomposition
CPU_REGISTER_PASS_COMMON(manager, ov::intel_cpu::MatmulGatherDecomposition)
CPU_REGISTER_PASS_COMMON(manager, ov::pass::Validate);

if (useLpt)
CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkDequantizationSubgraph, defaultPrecisions);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void RMSNormLayerCPUTest::SetUp() {

TEST_P(RMSNormLayerCPUTest, CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "RMS", m_rms_decomposed ? 0 : 1);
utils::CheckNumberOfNodesWithType(compiledModel, "RMS", m_rms_decomposed ? 0 : 1);
}

} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ConvertToPluginSpecificNode : public testing::WithParamInterface<ConvertTo

TEST_P(ConvertToPluginSpecificNode, CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "Const", constNodeNum);
utils::CheckNumberOfNodesWithType(compiledModel, "Const", constNodeNum);
}

namespace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class GroupConvToConvTransformationCPUTest: public testing::WithParamInterface<g

TEST_P(GroupConvToConvTransformationCPUTest, CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "Split", 1);
CheckNumberOfNodesWithType(compiledModel, "Convolution", numOfGroups);
CheckNumberOfNodesWithType(compiledModel, "Concatenation", 1);
utils::CheckNumberOfNodesWithType(compiledModel, "Split", 1);
utils::CheckNumberOfNodesWithType(compiledModel, "Convolution", numOfGroups);
utils::CheckNumberOfNodesWithType(compiledModel, "Concatenation", 1);
}

namespace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class Conv1dConvertTransformationCPUTest : public testing::WithParamInterface<co

TEST_P(Conv1dConvertTransformationCPUTest, CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "Reshape", 2);
utils::CheckNumberOfNodesWithType(compiledModel, "Reshape", 2);
}

namespace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class reduceTransformationCPUTest: public testing::WithParamInterface<reduceConv

TEST_P(reduceTransformationCPUTest, CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "Reduce", numberOfExpectedReduce);
utils::CheckNumberOfNodesWithType(compiledModel, "Reduce", numberOfExpectedReduce);
}

namespace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ReorderDeconvNHWCTest : virtual public SubgraphBaseStaticTest {

TEST_F(ReorderDeconvNHWCTest, smoke_ReorderDeconvNHWC_CPU) {
run();
CheckNumberOfNodesWithType(compiledModel, "Reorder", 2);
utils::CheckNumberOfNodesWithType(compiledModel, "Reorder", 2);
}

} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ namespace {
TEST_F(AddConvertToReorderTest, smoke_TestAddReorder_CPU) {
BuildGraph(ov::element::i8);
run();
CheckNumberOfNodesWithType(compiledModel, "Convert", 0);
CheckNumberOfNodesWithType(compiledModel, "Reorder", 1);
utils::CheckNumberOfNodesWithType(compiledModel, "Convert", 0);
utils::CheckNumberOfNodesWithType(compiledModel, "Reorder", 1);
}
} // namespace
} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class AlignMatMulInputRanksTest : public testing::WithParamInterface<AlignMatMul

TEST_P(AlignMatMulInputRanksTest, CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel,
utils::CheckNumberOfNodesWithType(compiledModel,
"Reshape",
expectedNumOfReshapes); // Squeeze / Unsqueeze turns into Reshape
}
Expand Down
Loading
Loading