Skip to content

Commit

Permalink
Merge pull request #5 from xczhai/fc_parallel_all
Browse files Browse the repository at this point in the history
fix split_fc unittest
  • Loading branch information
xczhai authored Mar 22, 2024
2 parents 5764f29 + f57185e commit df846d3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ ov::intel_cpu::SplitFC::SplitFC(int sub_stream_num) {
bool need_to_split_convert = ov::shape_size(convert_node1_shape) > 1 &&
split_dim < convert_node1_shape.size() &&
convert_node1_shape[split_dim] == split_dim_range;
auto weights = wgt_item->get_shape();
const auto& wgt_shape = wgt_item->get_shape();

// No need to split the FC when the split dimension has at most one element or the weights are too small.
if (split_dim_range <= 1 || weights[0] * weights[1] < 6600000) {
if (split_dim_range <= 1 || ov::shape_size(wgt_shape) < 6600000) {
return false;
}

Expand Down
64 changes: 32 additions & 32 deletions src/plugins/intel_cpu/tests/unit/transformations/split_fc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,25 @@ TEST_F(TransformationTestsF, SplitFCTest) {
disable_result_friendly_names_check();
disable_rt_info_check();
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 16, 1 });
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 4096, 1 });
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 });
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2, 16 }, { 12.34 });
auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2048, 4096 }, { 12.34 });

auto fc = std::make_shared<FullyConnectedNode>(transpose_src, wgt, ov::Rank(3));
model = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{src});
manager.register_pass<SplitFC>(1);
}
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 16, 1 });
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 4096, 1 });
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 });
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2, 16 }, { 12.34 });
auto wgt = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 2048, 4096 }, { 12.34 });

auto split_dim_node = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {1, 1});
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {1024, 1024});
auto split_wgts = std::make_shared<ov::opset1::VariadicSplit>(wgt, split_dim_node, split_length);

auto fc0 = std::make_shared<FullyConnectedNode>(transpose_src, split_wgts->output(0), ov::Rank(3));
Expand All @@ -61,41 +61,41 @@ TEST_F(TransformationTestsF, SplitFCTest_int8_weight) {
disable_result_friendly_names_check();
disable_rt_info_check();
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 16, 1});
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 4096, 1});
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1});
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2, 16}, {123});
auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2048, 4096}, {123});
auto cvt_wgt = std::make_shared<ov::opset1::Convert>(wgt, ov::element::f32);

auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2, 1}, {1});
auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2048, 1}, {1});
auto cvt_zp = std::make_shared<ov::opset1::Convert>(zp, ov::element::f32);

auto sub = std::make_shared<ov::opset1::Subtract>(cvt_wgt, cvt_zp);

auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2});
auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2});
auto mul = std::make_shared<ov::opset1::Multiply>(sub, mul_const);

auto fc = std::make_shared<FullyConnectedNode>(transpose_src, mul, ov::Rank(3));
model = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{src});
manager.register_pass<SplitFC>(1);
}
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 16, 1 });
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 4096, 1 });
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 });
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{ 2, 16 }, { 123 });
auto wgt = ov::opset1::Constant::create(ov::element::u8, ov::Shape{ 2048, 4096 }, { 123 });
auto cvt_wgt = std::make_shared<ov::opset1::Convert>(wgt, ov::element::f32);

auto split_dim_node = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {1, 1});
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {1024, 1024});

auto split_wgts = std::make_shared<ov::opset1::VariadicSplit>(wgt, split_dim_node, split_length);
auto cvt_wgt0 = std::make_shared<ov::opset1::Convert>(split_wgts->output(0), ov::element::f32);
auto cvt_wgt1 = std::make_shared<ov::opset1::Convert>(split_wgts->output(1), ov::element::f32);

auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2, 1}, {1});
auto zp = ov::opset1::Constant::create(ov::element::u8, ov::Shape{2048, 1}, {1});
auto split_zp = std::make_shared<ov::opset1::VariadicSplit>(zp, split_dim_node, split_length);

auto cvt_zp0 = std::make_shared<ov::opset1::Convert>(split_zp->output(0), ov::element::f32);
Expand All @@ -104,7 +104,7 @@ TEST_F(TransformationTestsF, SplitFCTest_int8_weight) {
auto sub0 = std::make_shared<ov::opset1::Subtract>(cvt_wgt0, cvt_zp0);
auto sub1 = std::make_shared<ov::opset1::Subtract>(cvt_wgt1, cvt_zp1);

auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2});
auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2});
auto split_mul_const = std::make_shared<ov::opset1::VariadicSplit>(mul_const, split_dim_node, split_length);

auto mul0 = std::make_shared<ov::opset1::Multiply>(sub0, split_mul_const->output(0));
Expand All @@ -124,43 +124,43 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight) {
disable_result_friendly_names_check();
disable_rt_info_check();
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 16, 1});
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 4096, 1});
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1});
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 16}, {12});
auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 4096}, {12});
auto cvt_wgt = std::make_shared<ov::opset1::Convert>(wgt, ov::element::f32);

auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 1}, {1});
auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 1}, {1});
auto cvt_zp = std::make_shared<ov::opset1::Convert>(zp, ov::element::f32);

auto sub = std::make_shared<ov::opset1::Subtract>(cvt_wgt, cvt_zp);

auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2});
auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2});
auto mul = std::make_shared<ov::opset1::Multiply>(sub, mul_const);

auto fc = std::make_shared<FullyConnectedNode>(transpose_src, mul, ov::Rank(3));
model = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{src});
manager.register_pass<SplitFC>(1);
}
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 16, 1});
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 4096, 1});
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1});
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 16}, {12});
auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 4096}, {12});
auto cvt_wgt_f32 = std::make_shared<ov::opset1::Convert>(wgt, ov::element::f32);

auto split_dim_node = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {1, 1});
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {1024, 1024});

auto split_wgts = std::make_shared<ov::opset1::VariadicSplit>(cvt_wgt_f32, split_dim_node, split_length);
auto cvt_wgt0_u4 = std::make_shared<ov::opset1::Convert>(split_wgts->output(0), ov::element::u4);
auto cvt_wgt1_u4 = std::make_shared<ov::opset1::Convert>(split_wgts->output(1), ov::element::u4);
auto cvt_wgt0_f32 = std::make_shared<ov::opset1::Convert>(cvt_wgt0_u4, ov::element::f32);
auto cvt_wgt1_f32 = std::make_shared<ov::opset1::Convert>(cvt_wgt1_u4, ov::element::f32);

auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2, 1}, {1});
auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{2048, 1}, {1});
auto cvt_zp_f32 = std::make_shared<ov::opset1::Convert>(zp, ov::element::f32);
auto split_zp = std::make_shared<ov::opset1::VariadicSplit>(cvt_zp_f32, split_dim_node, split_length);

Expand All @@ -172,7 +172,7 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight) {
auto sub0 = std::make_shared<ov::opset1::Subtract>(cvt_wgt0_f32, cvt_zp0_f32);
auto sub1 = std::make_shared<ov::opset1::Subtract>(cvt_wgt1_f32, cvt_zp1_f32);

auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2, 1}, {0.2});
auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{2048, 1}, {0.2});
auto split_mul_const = std::make_shared<ov::opset1::VariadicSplit>(mul_const, split_dim_node, split_length);

auto mul0 = std::make_shared<ov::opset1::Multiply>(sub0, split_mul_const->output(0));
Expand All @@ -192,38 +192,38 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight_reshape) {
disable_result_friendly_names_check();
disable_rt_info_check();
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 16, 1 });
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 2048, 1 });
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 });
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4, 2, 16 }, { 12 });
auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4096, 2, 1024}, { 12 });
auto cvt_wgt = std::make_shared<ov::opset1::Convert>(wgt, ov::element::f32);

auto zp = ov::opset1::Constant::create(ov::element::u4, ov::Shape{1}, { 1 });
auto cvt_zp = std::make_shared<ov::opset1::Convert>(zp, ov::element::f32);

auto sub = std::make_shared<ov::opset1::Subtract>(cvt_wgt, cvt_zp);

auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4, 2, 1}, {0.2});
auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4096, 2, 1}, {0.2});
auto mul = std::make_shared<ov::opset1::Multiply>(sub, mul_const);

auto res_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {4, 32});
auto res_const = ov::opset1::Constant::create(ov::element::i32, ov::Shape{2}, {4096, 2048});
auto reshape = std::make_shared<ov::opset1::Reshape>(mul, res_const, false);

auto fc = std::make_shared<FullyConnectedNode>(transpose_src, reshape, ov::Rank(3));
model = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{src});
manager.register_pass<SplitFC>(1);
}
{
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 16, 1 });
auto src = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{ 3, 2048, 1 });
auto transpose_constant_src = ov::opset1::Constant::create(ov::element::i32, ov::Shape{ 3 }, { 0, 2, 1 });
auto transpose_src = std::make_shared<ov::opset1::Transpose>(src, transpose_constant_src);

auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4, 2, 16 }, { 12 });
auto wgt = ov::opset1::Constant::create(ov::element::u4, ov::Shape{ 4096, 2, 1024 }, { 12 });
auto cvt_wgt_f32 = std::make_shared<ov::opset1::Convert>(wgt, ov::element::f32);

auto split_dim_node = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, 0);
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {2, 2});
auto split_length = ov::opset1::Constant::create<int32_t>(ov::element::i32, ov::Shape{2}, {2048, 2048});

auto split_wgts = std::make_shared<ov::opset1::VariadicSplit>(cvt_wgt_f32, split_dim_node, split_length);
auto cvt_wgt0_u4 = std::make_shared<ov::opset1::Convert>(split_wgts->output(0), ov::element::u4);
Expand All @@ -241,13 +241,13 @@ TEST_F(TransformationTestsF, SplitFCTest_int4_weight_reshape) {
auto sub0 = std::make_shared<ov::opset1::Subtract>(cvt_wgt0_f32, cvt_zp0);
auto sub1 = std::make_shared<ov::opset1::Subtract>(cvt_wgt1_f32, cvt_zp1);

auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4, 2, 1}, {0.2});
auto mul_const = ov::opset1::Constant::create(ov::element::f32, ov::Shape{4096, 2, 1}, {0.2});
auto split_mul_const = std::make_shared<ov::opset1::VariadicSplit>(mul_const, split_dim_node, split_length);

auto mul0 = std::make_shared<ov::opset1::Multiply>(sub0, split_mul_const->output(0));
auto mul1 = std::make_shared<ov::opset1::Multiply>(sub1, split_mul_const->output(1));

std::vector<int32_t> reshape_pattern_vec = {2, 32};
std::vector<int32_t> reshape_pattern_vec = {2048, 2048};
auto reshape_pattern = std::make_shared<ov::opset1::Constant>(ov::element::i32, ov::Shape{2}, reshape_pattern_vec);
auto reshape0 = std::make_shared<ov::opset1::Reshape>(mul0, reshape_pattern, false);
auto reshape1 = std::make_shared<ov::opset1::Reshape>(mul1, reshape_pattern, false);
Expand Down

0 comments on commit df846d3

Please sign in to comment.