diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/split_fc.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/split_fc.cpp index 1e9bd271988cd8..9c1b4d82c79719 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/split_fc.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/split_fc.cpp @@ -92,10 +92,10 @@ ov::intel_cpu::SplitFC::SplitFC(int sub_stream_num) { bool need_to_split_convert = ov::shape_size(convert_node1_shape) > 1 && split_dim < convert_node1_shape.size() && convert_node1_shape[split_dim] == split_dim_range; - auto weights = wgt_item->get_shape(); + const auto& wgt_shape = wgt_item->get_shape(); // needn't to split fc when the dim is 0. - if (split_dim_range <= 1 || weights[0] * weights[1] < 6600000) { + if (split_dim_range <= 1 || ov::shape_size(wgt_shape) < 6600000) { return false; } diff --git a/src/plugins/intel_cpu/tests/unit/transformations/split_fc_test.cpp b/src/plugins/intel_cpu/tests/unit/transformations/split_fc_test.cpp index d2a5fa66f00a09..6eef34867ffe1d 100644 --- a/src/plugins/intel_cpu/tests/unit/transformations/split_fc_test.cpp +++ b/src/plugins/intel_cpu/tests/unit/transformations/split_fc_test.cpp @@ -26,25 +26,25 @@ TEST_F(TransformationTestsF, SplitFCTest) { disable_result_friendly_names_check(); disable_rt_info_check(); { - auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 16, 1 }); + auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 4096, 1 }); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 }); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2, 16 }, { 12.34 }); + auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2048, 4096 }, { 12.34 }); auto fc = std::make_shared(transpose_src, wgt, ov::Rank(3)); model = std::make_shared(ov::NodeVector{fc}, ov::ParameterVector{src}); manager.register_pass(1); } { - auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 16, 1 }); + auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 4096, 1 }); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 }); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2, 16 }, { 12.34 }); + auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2048, 4096 }, { 12.34 }); auto split_dim_node = std::make_shared(ov::element::i32, ov::Shape{}, 0); - auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {1, 1}); + auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {1024, 1024}); auto split_wgts = std::make_shared(wgt, split_dim_node, split_length); auto fc0 = std::make_shared(transpose_src, split_wgts->output(0), ov::Rank(3)); @@ -61,19 +61,19 @@ TEST_F(TransformationTestsF, SplitFCTest_int8_weight) { disable_result_friendly_names_check(); disable_rt_info_check(); { - auto src = std::make_shared(ov::element::f32, ov::Shape{3, 16, 1}); + auto src = std::make_shared(ov::element::f32, ov::Shape{3, 4096, 1}); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1}); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2, 16}, {123}); + auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2048, 4096}, {123}); auto cvt_wgt = std::make_shared(wgt, ov::element::f32); - auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2, 1}, {1}); + auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2048, 1}, {1}); auto cvt_zp = std::make_shared(zp, ov::element::f32); auto sub = std::make_shared(cvt_wgt, cvt_zp); - auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2}); + auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2}); auto mul = std::make_shared(sub, mul_const); auto fc = std::make_shared(transpose_src, mul, ov::Rank(3)); @@ -81,21 +81,21 @@ TEST_F(TransformationTestsF, SplitFCTest_int8_weight) { manager.register_pass(1); } { - auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 16, 1 }); + auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 4096, 1 }); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 }); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{ 2, 16 }, { 123 }); + auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{ 2048, 4096 }, { 123 }); auto cvt_wgt = std::make_shared(wgt, ov::element::f32); auto split_dim_node = std::make_shared(ov::element::i32, ov::Shape{}, 0); - auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {1, 1}); + auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {1024, 1024}); auto split_wgts = std::make_shared(wgt, split_dim_node, split_length); auto cvt_wgt0 = std::make_shared(split_wgts->output(0), ov::element::f32); auto cvt_wgt1 = std::make_shared(split_wgts->output(1), ov::element::f32); - auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2, 1}, {1}); + auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2048, 1}, {1}); auto split_zp = std::make_shared(zp, split_dim_node, split_length); auto cvt_zp0 = std::make_shared(split_zp->output(0), ov::element::f32); @@ -104,7 +104,7 @@ TEST_F(TransformationTestsF, SplitFCTest_int8_weight) { auto sub0 = std::make_shared(cvt_wgt0, cvt_zp0); auto sub1 = std::make_shared(cvt_wgt1, cvt_zp1); - auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2}); + auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2}); auto split_mul_const = std::make_shared(mul_const, split_dim_node, split_length); auto mul0 = std::make_shared(sub0, split_mul_const->output(0)); @@ -124,19 +124,19 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight) { disable_result_friendly_names_check(); disable_rt_info_check(); { - auto src = std::make_shared(ov::element::f32, ov::Shape{3, 16, 1}); + auto src = std::make_shared(ov::element::f32, ov::Shape{3, 4096, 1}); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1}); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 16}, {12}); + auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 4096}, {12}); auto cvt_wgt = std::make_shared(wgt, ov::element::f32); - auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 1}, {1}); + auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 1}, {1}); auto cvt_zp = std::make_shared(zp, ov::element::f32); auto sub = std::make_shared(cvt_wgt, cvt_zp); - auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2}); + auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2}); auto mul = std::make_shared(sub, mul_const); auto fc = std::make_shared(transpose_src, mul, ov::Rank(3)); @@ -144,15 +144,15 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight) { manager.register_pass(1); } { - auto src = std::make_shared(ov::element::f32, ov::Shape{3, 16, 1}); + auto src = std::make_shared(ov::element::f32, ov::Shape{3, 4096, 1}); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1}); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 16}, {12}); + auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 4096}, {12}); auto cvt_wgt_f32 = std::make_shared(wgt, ov::element::f32); auto split_dim_node = std::make_shared(ov::element::i32, ov::Shape{}, 0); - auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {1, 1}); + auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {1024, 1024}); auto split_wgts = std::make_shared(cvt_wgt_f32, split_dim_node, split_length); auto cvt_wgt0_u4 = std::make_shared(split_wgts->output(0), ov::element::u4); @@ -160,7 +160,7 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight) { auto cvt_wgt0_f32 = std::make_shared(cvt_wgt0_u4, ov::element::f32); auto cvt_wgt1_f32 = std::make_shared(cvt_wgt1_u4, ov::element::f32); - auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 1}, {1}); + auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 1}, {1}); auto cvt_zp_f32 = std::make_shared(zp, ov::element::f32); auto split_zp = std::make_shared(cvt_zp_f32, split_dim_node, split_length); @@ -172,7 +172,7 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight) { auto sub0 = std::make_shared(cvt_wgt0_f32, cvt_zp0_f32); auto sub1 = std::make_shared(cvt_wgt1_f32, cvt_zp1_f32); - auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2}); + auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2}); auto split_mul_const = std::make_shared(mul_const, split_dim_node, split_length); auto mul0 = std::make_shared(sub0, split_mul_const->output(0)); @@ -192,11 +192,11 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight_reshape) { disable_result_friendly_names_check(); disable_rt_info_check(); { - auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 16, 1 }); + auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 2048, 1 }); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 }); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4, 2, 16 }, { 12 }); + auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4096, 2, 1024}, { 12 }); auto cvt_wgt = std::make_shared(wgt, ov::element::f32); auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{1}, { 1 }); @@ -204,10 +204,10 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight_reshape) { auto sub = std::make_shared(cvt_wgt, cvt_zp); - auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4, 2, 1}, {0.2}); + auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4096, 2, 1}, {0.2}); auto mul = std::make_shared(sub, mul_const); - auto res_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {4, 32}); + auto res_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {4096, 2048}); auto reshape = std::make_shared(mul, res_const, false); auto fc = std::make_shared(transpose_src, reshape, ov::Rank(3)); @@ -215,15 +215,15 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight_reshape) { manager.register_pass(1); } { - auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 16, 1 }); + auto src = std::make_shared(ov::element::f32, ov::Shape{ 3, 2048, 1 }); auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 }); auto transpose_src = std::make_shared(src, transpose_constant_src); - auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4, 2, 16 }, { 12 }); + auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4096, 2, 1024 }, { 12 }); auto cvt_wgt_f32 = std::make_shared(wgt, ov::element::f32); auto split_dim_node = std::make_shared(ov::element::i32, ov::Shape{}, 0); - auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2, 2}); + auto split_length = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {2048, 2048}); auto split_wgts = std::make_shared(cvt_wgt_f32, split_dim_node, split_length); auto cvt_wgt0_u4 = std::make_shared(split_wgts->output(0), ov::element::u4); @@ -241,13 +241,13 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight_reshape) { auto sub0 = std::make_shared(cvt_wgt0_f32, cvt_zp0); auto sub1 = std::make_shared(cvt_wgt1_f32, cvt_zp1); - auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4, 2, 1}, {0.2}); + auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4096, 2, 1}, {0.2}); auto split_mul_const = std::make_shared(mul_const, split_dim_node, split_length); auto mul0 = std::make_shared(sub0, split_mul_const->output(0)); auto mul1 = std::make_shared(sub1, split_mul_const->output(1)); - std::vector reshape_pattern_vec = {2, 32}; + std::vector reshape_pattern_vec = {2048, 2048}; auto reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{2}, reshape_pattern_vec); auto reshape0 = std::make_shared(mul0, reshape_pattern, false); auto reshape1 = std::make_shared(mul1, reshape_pattern, false);