Xp/decompse matmul split or matmul gather #25196

Status: Open. Wants to merge 53 commits into base: master from xp/decompse_matmul_split_or_matmul_gather. The diff below shows changes from 27 commits.

Commits (53):
6d67fc9  init (ceciliapeng2011, May 14, 2024)
2541834  support fp16 (ceciliapeng2011, May 14, 2024)
a39065f  support fp16 (ceciliapeng2011, May 14, 2024)
361b151  Add chrome trace (usstq, Nov 20, 2023)
9e2855f  fix accuracy issue, matmul transpose_B (ceciliapeng2011, May 17, 2024)
6a7f280  fork initial version from cecilia (xipingyan, Jun 12, 2024)
ccadbc5  remove profilier (xipingyan, Jun 12, 2024)
6fbaf4e  tmp version, phi verify pass. (xipingyan, Jun 18, 2024)
6ea2baf  refactor test code. (xipingyan, Jun 19, 2024)
25c5e17  Add test for MatMulSplit (xipingyan, Jun 19, 2024)
43dea54  Improve test. (xipingyan, Jun 20, 2024)
23665d0  Support without bias (xipingyan, Jun 21, 2024)
7153b31  B=1,L=1 can't match pattern, why? (xipingyan, Jun 21, 2024)
b37ca44  scalar input test pass. (xipingyan, Jun 21, 2024)
f0b38a7  Test pass, ready to review. (xipingyan, Jun 22, 2024)
350f8a8  dynamic shape should be got from partial shape. (xipingyan, Jun 24, 2024)
93c3e8a  compitalbe int8 quantize model. (xipingyan, Jun 25, 2024)
57e8b56  fix replace fq node error. (xipingyan, Jun 26, 2024)
537bb39  clang format (xipingyan, Jun 26, 2024)
f35f8b3  Add enable_fq to test. (xipingyan, Jun 26, 2024)
218027b  clange issue. (xipingyan, Jun 26, 2024)
c7528d2  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Jul 4, 2024)
efd55e7  remove debug code, and remvoe MatMulVariadicSplitDecomposition (xipingyan, Jul 4, 2024)
8760ced  fix CI clang issue. (xipingyan, Jul 4, 2024)
2e56d87  update based on comments (xipingyan, Jul 5, 2024)
3dc306e  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Jul 5, 2024)
9290aa9  Fix transpose_b = false test fail. (xipingyan, Jul 5, 2024)
a154e26  try to fix CI arm test fail (xipingyan, Jul 5, 2024)
aefb31f  Baichun also match this pattern, but it only has rank 4, and have per… (xipingyan, Jul 8, 2024)
0393c50  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Jul 8, 2024)
f574620  Updated based on comments. (xipingyan, Jul 16, 2024)
e0cc4eb  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Jul 16, 2024)
eea39c7  Fix merge conflict. (xipingyan, Jul 16, 2024)
d094d70  Fix clang issue. (xipingyan, Jul 16, 2024)
8a45461  fix CI fail: because Diff: 1.75813e-05, so set abs_threshold = 1e-4 (xipingyan, Jul 16, 2024)
1f68706  Merge remote-tracking branch 'origin/master' into xp/decompse_matmul_… (xipingyan, Jul 22, 2024)
54a5798  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Jul 23, 2024)
461b708  Updated based on comments. (xipingyan, Jul 23, 2024)
53f9eca  Merge remote-tracking branch 'origin/master' into xp/decompse_matmul_… (xipingyan, Jul 23, 2024)
c2d71f4  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Jul 24, 2024)
3dec73b  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Jul 29, 2024)
e35702e  Updated based on comments. (xipingyan, Aug 14, 2024)
f24bbbf  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Aug 14, 2024)
26f9314  Fix build error. Remove namespace CPUTestUtils:: (xipingyan, Aug 14, 2024)
621a4dc  regist pass to CPU_REGISTER_PASS_COMMON, work for ARM (xipingyan, Aug 15, 2024)
c8a16f0  Merge commit '54f58b86' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Aug 26, 2024)
a7b978a  Move "CheckNumberOfNodesWithType" to src/tests/test_utils/functional_… (xipingyan, Aug 26, 2024)
cbb4fc2  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Sep 10, 2024)
4d5f03d  Fix merge master conflict issue. (xipingyan, Sep 10, 2024)
121e6cb  1:Move decompose_num private (xipingyan, Sep 24, 2024)
cb15e8e  Move "matmul_split_decomposition.cpp" to CPU: (xipingyan, Sep 24, 2024)
cde2f32  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Sep 25, 2024)
6f9ebad  Merge branch 'master' into xp/decompse_matmul_split_or_matmul_gather (xipingyan, Sep 26, 2024)
Files changed:
@@ -0,0 +1,75 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API MatmulGatherDecomposition;

} // namespace pass
} // namespace ov

/**
 * @ingroup ov_transformation_common_api
 * @brief MatmulGatherDecomposition transformation matches the following graph:
 *
 *           +----------+
 *           |  input   |
 *           +----------+
 *                |
 *                v
 *           +----------+
 *           |  MatMul  |
 *           +----------+
 *                |
 *                v
 *          +------------+
 *          | Some nodes |
 *          +------------+
 *                |
 *                v
 *       +-----------------------+
 *       |       Transpose       |
 *       +-----------------------+
 *        |          |          |
 *        v          v          v
 *    +-------+  +-------+  +-------+
 *    |Gather |  |Gather |  |Gather |
 *    +-------+  +-------+  +-------+
 *
 * and replaces it with:
 *
 *       +-----------------------+
 *       |         input         |
 *       +-----------------------+
 *        |          |          |
 *        v          v          v
 *    +-------+  +-------+  +-------+
 *    |MatMul |  |MatMul |  |MatMul |
 *    +-------+  +-------+  +-------+
 *        |          |          |
 *        v          v          v
 *    +-------+  +-------+  +-------+
 *    | Nodes |  | Nodes |  | Nodes |
 *    +-------+  +-------+  +-------+
 *        |          |          |
 *        v          v          v
 *  +---------+ +---------+ +---------+
 *  |Transpose| |Transpose| |Transpose|
 *  +---------+ +---------+ +---------+
 */
class ov::pass::MatmulGatherDecomposition : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MatmulGatherDecomposition", "0");
MatmulGatherDecomposition();
void split_weights(const Output<Node>& weights,
OutputVector& new_weights,
Output<Node>* bias,
OutputVector& new_bias,
const bool& transpose_b);
};
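
For reference, the equivalence behind this rewrite can be written as a block-matrix identity (a sketch; the Q/K/V interpretation and the transpose_b = true layout are assumptions about the typical attention-projection case this pass targets):

\[
W = \begin{bmatrix} W_0 \\ W_1 \\ W_2 \end{bmatrix}, \qquad
b = \begin{bmatrix} b_0 & b_1 & b_2 \end{bmatrix}, \qquad
X W^{\top} + b = \begin{bmatrix} X W_0^{\top} + b_0 & X W_1^{\top} + b_1 & X W_2^{\top} + b_2 \end{bmatrix}
\]

So splitting W along its output dimension (axis 0 when transpose_b is true, axis 1 otherwise) and b along its last axis yields three independent MatMul (+ Add) branches whose concatenation equals the original fused projection; the per-branch Reshape/Transpose then consume those branches instead of the Gathers.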
@@ -34,6 +34,7 @@
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
#include "transformations/common_optimizations/mark_precision_sensitive_shapeof_subgraphs.hpp"
#include "transformations/common_optimizations/matmul_multiply_fusion.hpp"
#include "transformations/common_optimizations/matmul_split_decomposition.hpp"
#include "transformations/common_optimizations/moc_transformations.hpp"
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
#include "transformations/common_optimizations/mul_fake_quantize_fusion.hpp"
@@ -0,0 +1,207 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/matmul_split_decomposition.hpp"

#include <cstdint>
#include <limits>
#include <memory>
#include <openvino/core/rt_info.hpp>
#include <openvino/opsets/opset13.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
#include <vector>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/gather_base.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;
using namespace ov;
using namespace ov::pass::pattern;

void pass::MatmulGatherDecomposition::split_weights(const Output<Node>& weights,
OutputVector& new_weights,
Output<Node>* bias,
OutputVector& new_bias,
const bool& transpose_b) {
const auto& weights_shape = weights.get_partial_shape();
int64_t weights_rank = static_cast<int64_t>(weights_shape.rank().get_length());

if (bias) {
const auto& bias_shape = bias->get_partial_shape();
int64_t bias_rank = static_cast<int64_t>(bias_shape.rank().get_length());
if (weights_rank != 2 || (bias_rank != 3 && bias_rank != 1)) {
return;
}
}
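
// Worked example (shapes taken from the first smoke test below; illustrative only): with
// transpose_b = true, weights of shape {24, 8} are split along axis 0 into three {8, 8}
// chunks, and a bias of shape {1, 1, 24} is split along its last axis into three {1, 1, 8} chunks.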

// Decompose weights
auto axis = register_new_node(v0::Constant::create(element::i32, Shape{}, {transpose_b ? 0 : 1}));
auto split = register_new_node<opset1::Split>(weights, axis, 3);
for (auto& out : split->outputs()) {
new_weights.emplace_back(out);
}

if (bias) {
// Decompose bias
auto axis2 = register_new_node(v0::Constant::create(element::i32, Shape{}, {-1})); // axis -1
auto split2 = register_new_node<opset1::Split>(*bias, axis2, 3);
for (auto& out : split2->outputs()) {
new_bias.emplace_back(out);
}
}
}

pass::MatmulGatherDecomposition::MatmulGatherDecomposition() {
EgorDuplensky (Contributor), Sep 19, 2024:
I understand that this transformation is trying to match a very specific pattern from LLM models, but shouldn't we have some heuristic for the weights size or something? Do we expect any model with any weights size to benefit from this transformation? Also, please describe in the commit message / PR description the motivation for having this transformation, i.e. why we expect it to speed up LLMs in the first place.

xipingyan (Author):
Yes, I want to match ViT-like model structures, and I also added some heuristics: checking the rank, decompose_num, and the specific transpose order. Do you think these are not enough?

EgorDuplensky (Contributor):
Probably this will be enough most of the time. But, again, this is mostly about the reason we are getting the speed-ups. I assume we observe speed-ups not because of the ranks, decompose_num, and transpose order, but because we become less memory bound. But maybe I am wrong.

xipingyan (Author):
Yes, I think it is related to the input data size. Because the shapes are dynamic, it is hard to describe this precisely; I have tried my best.

MATCHER_SCOPE(MatmulGatherDecomposition);
auto input_pattern = any_input();
auto matmul_pattern = wrap_type<opset1::MatMul>({input_pattern, any_input()});

auto bias_pattern = wrap_type<opset1::Constant>();
auto add_pattern = wrap_type<opset1::Add>({matmul_pattern, bias_pattern});

auto reshape_productor_pattern = std::make_shared<pattern::op::Or>(OutputVector{matmul_pattern, add_pattern});

auto reshape_pattern = wrap_type<opset1::Reshape>({reshape_productor_pattern, any_input()});
auto transpose_pattern = wrap_type<opset6::Transpose>({reshape_pattern, any_input()});
auto reshape2_pattern = wrap_type<opset1::Reshape>({reshape_pattern, any_input()});

auto reshape_or_transpose_pattern =
std::make_shared<pattern::op::Or>(OutputVector{reshape2_pattern, transpose_pattern});

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

// Heuristics: there should be only 3 gathers to split
auto root_node = m.get_match_root();
bool have_transpose = as_type<opset1::Transpose>(root_node.get()) != nullptr;
auto children = root_node->get_output_target_inputs(0);
if (children.size() != 3u) {
return false;
}

auto matmul = pattern_map.at(matmul_pattern).get_node_shared_ptr();
auto weights = matmul->input_value(1);
std::shared_ptr<ov::Node> add = nullptr;
bool have_bias = false;
for (auto& consumer : matmul->get_output_target_inputs(0)) {
if (ov::is_type<opset1::Add>(consumer.get_node()->shared_from_this())) {
add = pattern_map.at(add_pattern).get_node_shared_ptr();
have_bias = true;
break;
}
}
const bool& transpose_b = as_type_ptr<opset1::MatMul>(matmul)->get_transpose_b();
const auto& reshape = pattern_map.at(reshape_pattern);
auto concat = reshape.get_node_shared_ptr()->input_value(1);

NodeVector gathers, fake_quantizes;
gathers.resize(3);
fake_quantizes.resize(3);
for (auto& child : children) {
std::shared_ptr<ov::Node> fq = nullptr;
auto gather = child.get_node()->shared_from_this();
if (ov::is_type<opset1::FakeQuantize>(gather)) {
fq = gather;
gather = gather->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
if (ov::is_type<ov::op::util::GatherBase>(gather)) {
const auto axis_node = as_type_ptr<opset6::Constant>(gather->input_value(2).get_node_shared_ptr());
if (axis_node) {
const auto& axis_val = axis_node->cast_vector<int32_t>();
if (axis_val.size() != 1u || axis_val[0] != 0) {
return false;
}
} else {
return false;
}

const auto indices_node = as_type_ptr<opset6::Constant>(gather->input_value(1).get_node_shared_ptr());
if (indices_node) {
const auto& indices_val = indices_node->cast_vector<int32_t>();
if (indices_val.size() != 1) {
return false;
}
if (indices_val[0] < 0 || indices_val[0] >= 3) {
return false;
}
gathers[indices_val[0]] = gather;
fake_quantizes[indices_val[0]] = fq;
} else {
return false;
}
} else {
return false;
}
}

if (std::any_of(gathers.begin(), gathers.end(), [](const std::shared_ptr<Node> node_ptr) {
return !node_ptr || !is_type<ov::op::util::GatherBase>(node_ptr);
})) {
return false;
}
Contributor:
This looks like an unnecessary check: we already verified that while filling the gathers vector.

xipingyan (Author):
Actually, it also checks that indices_val[0] covered all three cases (0, 1, 2).

Contributor:
I still don't understand this check: I don't see any indices-related check in this any_of, and the type check was already done a few lines above.


Output<Node> bias;
OutputVector new_weights, new_bias;
if (have_bias) {
bias = pattern_map.at(bias_pattern);
}
split_weights(weights, new_weights, have_bias ? &bias : nullptr, new_bias, transpose_b);
if (new_weights.size() != 3u || (have_bias && new_bias.size() != 3u)) {
return false;
}

auto const_indices = register_new_node(v0::Constant::create(element::i32, Shape{4}, {0, 1, 3, 4}));
auto const_axis = register_new_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto new_shape = register_new_node<v1::Gather>(concat, const_indices, const_axis);
const auto& input = pattern_map.at(input_pattern);
for (size_t i = 0; i < 3u; i++) {
auto new_mm = register_new_node<v0::MatMul>(input, new_weights[i], false, transpose_b);
Contributor:
We need to check somewhere above that matmul->get_transpose_a() == false.

xipingyan (Author):
I don't think so.
1: The real cases do not include transpose_a being true.
2: It makes the code more complicated.

Contributor:
Although there are no known cases where such patterns have transpose_a == true, it doesn't mean there will be no such cases in the future (not necessarily cases from real models; we should also take synthetic cases into account, e.g. from tests). And I would prefer to avoid debugging a MatMul shape-inference failure if such patterns do appear. Moreover, this is a one-line check that can be done right after the matmul variable creation: almost no code is spent on it. So please do that.
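
// A sketch of the one-line guard suggested above (placement per the reviewer: right after
// the `matmul` variable is obtained):
//     if (as_type_ptr<opset1::MatMul>(matmul)->get_transpose_a())
//         return false;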

std::shared_ptr<ov::Node> reshape_productor = new_mm;
if (have_bias) {
reshape_productor = register_new_node<v1::Add>(new_mm, new_bias[i]);
}
auto new_reshape = register_new_node<v1::Reshape>(reshape_productor, new_shape, true);
ov::NodeVector from_nodes = {gathers[i], weights.get_node_shared_ptr(), matmul};
if (have_bias) {
from_nodes.emplace_back(add);
from_nodes.emplace_back(pattern_map.at(bias_pattern).get_node_shared_ptr());
}
if (have_transpose)
from_nodes.emplace_back(root_node);

copy_runtime_info(from_nodes, get_new_nodes());
auto transpose_order = register_new_node(v0::Constant::create(element::i32, Shape{4}, {0, 2, 1, 3}));
auto new_transpose = register_new_node<v1::Transpose>(new_reshape, transpose_order);
new_transpose->set_friendly_name(gathers[i]->get_friendly_name());

if (fake_quantizes[i]) {
fake_quantizes[i]->set_argument(0, new_transpose);
replace_node(gathers[i], fake_quantizes[i]);
} else {
replace_node(gathers[i], new_transpose);
}
}
return true;
};

auto m = std::make_shared<pattern::Matcher>(reshape_or_transpose_pattern, matcher_name);
this->register_matcher(m, callback);
}
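
To make the matched "before" subgraph concrete, below is a minimal sketch that builds it with the public OpenVINO C++ API, using the shapes from the first smoke-test case further down ({2, 5, 8} input, {24, 8} weights, transpose_b = true, {1, 1, 24} bias, {2, 5, 3, 2, 4} reshape). The builder name and the transpose order {2, 0, 3, 1, 4} are illustrative assumptions, not taken from this PR:

#include <memory>
#include <vector>

#include "openvino/openvino.hpp"
#include "openvino/opsets/opset8.hpp"

// Builds input -> MatMul -> Add -> Reshape -> Transpose -> 3x Gather, i.e. the
// "before" graph that MatmulGatherDecomposition matches (a sketch; names are hypothetical).
std::shared_ptr<ov::Model> make_fused_qkv_model() {
    using namespace ov;
    using namespace ov::opset8;

    auto input = std::make_shared<Parameter>(element::f32, Shape{2, 5, 8});
    auto weights = Constant::create(element::f32, Shape{24, 8}, std::vector<float>(24 * 8, 0.1f));
    auto matmul = std::make_shared<MatMul>(input, weights, false, true);  // transpose_b = true
    auto bias = Constant::create(element::f32, Shape{1, 1, 24}, std::vector<float>(24, 0.0f));
    auto add = std::make_shared<Add>(matmul, bias);

    // [B, L, 3*H*d] -> [B, L, 3, H, d] -> [3, B, H, L, d]
    auto new_shape = Constant::create(element::i32, Shape{5}, {2, 5, 3, 2, 4});
    auto reshape = std::make_shared<Reshape>(add, new_shape, true);
    auto order = Constant::create(element::i32, Shape{5}, {2, 0, 3, 1, 4});  // assumed ViT-style order
    auto transpose = std::make_shared<Transpose>(reshape, order);

    // One Gather per Q/K/V branch: axis 0, indices 0, 1, 2.
    OutputVector results;
    for (int32_t i = 0; i < 3; ++i) {
        auto indices = Constant::create(element::i32, Shape{1}, {i});
        auto axis = Constant::create(element::i32, Shape{}, {0});
        results.push_back(std::make_shared<Gather>(transpose, indices, axis));
    }
    return std::make_shared<Model>(results, ParameterVector{input});
}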
@@ -44,6 +44,7 @@
#include "transformations/common_optimizations/lstm_cell_fusion.hpp"
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
#include "transformations/common_optimizations/matmul_multiply_fusion.hpp"
#include "transformations/common_optimizations/matmul_split_decomposition.hpp"
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
#include "transformations/common_optimizations/mul_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
@@ -249,6 +250,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
REGISTER_PASS(manager, ConvToBinaryConv)

auto decomp = manager.register_pass<ov::pass::GraphRewrite>();
ADD_MATCHER(decomp, MatmulGatherDecomposition)
ADD_MATCHER(decomp, BatchNormDecomposition)
ADD_MATCHER(decomp, ConvertDivideWithConstant)
ADD_MATCHER(decomp, ConvertSubtractWithConstant)
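
Besides the MOC pipeline registration above, the matcher can also be exercised on its own through ov::pass::Manager; a minimal sketch (the wrapper function is hypothetical, and the include path follows the header location in this version of the diff):

#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/matmul_split_decomposition.hpp"

// Apply only MatmulGatherDecomposition to a model (e.g. the fused-QKV sketch above).
void run_matmul_gather_decomposition(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::MatmulGatherDecomposition>();
    manager.run_passes(model);
}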
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "subgraph_tests/matmul_split_decompose.hpp"

using namespace ov::test;
namespace {

std::vector<MatMulGatherDecomposeShapeParams> mm_gather_shape_params = {
{{2, 5, 8}, {24, 8}, true, true, {1, 1, 24}, {2, 5, 3, 2, 4}},
{{1, 1, 8}, {24, 8}, true, false, {1, 1, 24}, {1, 1, 3, 2, 4}},
{{1, 2, 4}, {4, 12}, false, true, {1, 1, 12}, {1, 2, 3, 2, 2}},
};
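
// Instantiated for DEVICE_CPU; the last Combine argument is the enable-FakeQuantize flag,
// covered with both false and true.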

INSTANTIATE_TEST_SUITE_P(smoke_MatMulGatherDecompose,
MatMulGatherDecompose,
::testing::Combine(::testing::ValuesIn(mm_gather_shape_params),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(false, true)),
MatMulGatherDecompose::getTestCaseName);
} // namespace
@@ -0,0 +1,20 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "functional_test_utils/skip_tests_config.hpp"
#include "shared_test_classes/subgraph/matmul_split_decompose.hpp"

namespace ov {
namespace test {

TEST_P(MatMulGatherDecompose, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
run();
check_results();
}

} // namespace test
} // namespace ov
@@ -0,0 +1,41 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>
#include <tuple>

#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {

struct MatMulGatherDecomposeShapeParams {
ov::Shape input_shape;
ov::Shape weights_shape;
bool trans_b;
bool have_bias;
ov::Shape bias_shape;
ov::Shape reshape_shape;
};

typedef std::tuple<MatMulGatherDecomposeShapeParams,
std::string, // Device name
bool // Enable FakeQuantize
>
MatMulGatherDecomposeParams;

class MatMulGatherDecompose : public testing::WithParamInterface<MatMulGatherDecomposeParams>,
virtual public ov::test::SubgraphBaseStaticTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<MatMulGatherDecomposeParams>& obj);

protected:
void SetUp() override;
void check_results();
};

} // namespace test
} // namespace ov
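
The definitions of SetUp() and check_results() are not shown in this view. As an illustration only, check_results() could count node types in the compiled graph; the helper name matches the CheckNumberOfNodesWithType utility mentioned in the commit history, but the exact signature, node-type string, and expected count below are assumptions:

// Hypothetical sketch, not the PR's implementation: after a successful decomposition,
// no Gather nodes produced by the fused projection should remain in the compiled graph.
void MatMulGatherDecompose::check_results() {
    CheckNumberOfNodesWithType(compiledModel, "Gather", 0);
}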