diff --git a/src/common/transformations/src/transformations/common_optimizations/matmul_multiply_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/matmul_multiply_fusion.cpp index 2d6976d326a7a5..37d9888fc3f98f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/matmul_multiply_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/matmul_multiply_fusion.cpp @@ -151,7 +151,7 @@ pass::MatMulMultiplyFusion::MatMulMultiplyFusion() { auto input_pattern = pattern::any_input(); auto weights_pattern = pattern::any_input(pattern::has_static_rank()); auto mul_const_pattern = pattern::wrap_type(); - auto matmul_pattern = pattern::wrap_type({input_pattern, weights_pattern}); + auto matmul_pattern = pattern::wrap_type({input_pattern, weights_pattern}, pattern::consumers_count(1)); auto mul_pattern = pattern::wrap_type({matmul_pattern, mul_const_pattern}); matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) { diff --git a/src/common/transformations/tests/common_optimizations/matmul_multiply_fusion.cpp b/src/common/transformations/tests/common_optimizations/matmul_multiply_fusion.cpp index 6bac9562bee0a5..ecb403d051d93d 100644 --- a/src/common/transformations/tests/common_optimizations/matmul_multiply_fusion.cpp +++ b/src/common/transformations/tests/common_optimizations/matmul_multiply_fusion.cpp @@ -99,6 +99,19 @@ TEST_F(TransformationTestsF, MatMulMultiplyFusionNonConstantTransposedWeightsNon comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); } +TEST_F(TransformationTestsF, MatMulMultiplyFusionNonSingleConsumer) { + auto data = std::make_shared(element::f32, Shape{2, 3}); + auto weights = opset8::Constant::create(element::f32, Shape{2, 3}, {2, 6, 6, 12, 10, 18}); + auto matmul = std::make_shared(data, weights, false, true); + auto mul_const = opset8::Constant::create(element::f32, Shape{1, 2}, {4, 5}); + auto mul = std::make_shared(matmul, mul_const); + auto add = std::make_shared(matmul, mul); + model = std::make_shared(NodeVector{add}, ParameterVector{data}); + + manager.register_pass(); + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); +} + using MatMulMultiplyFusionParams = std::tuple; class MatMulMultiplyFusionDynamicShapes : public testing::WithParamInterface,