Skip to content

Commit

Permalink
Alexandra's comments applied
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Sep 25, 2024
1 parent 00d2b90 commit 44e463d
Show file tree
Hide file tree
Showing 17 changed files with 276 additions and 300 deletions.
2 changes: 0 additions & 2 deletions src/common/snippets/src/pass/fc_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@

#include "snippets/pass/fc_tokenization.hpp"

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "snippets/itt.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/utils/tokenization_utils.hpp"

ov::snippets::pass::TokenizeFCSnippets::TokenizeFCSnippets(const SnippetsTokenization::Config& config) {
Expand Down
5 changes: 3 additions & 2 deletions src/common/snippets/src/pass/tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {

manager.register_pass<EnumerateNodes>();
manager.register_pass<ExtractReshapesFromMHA>();
// This pass mustn't be registered in GraphRewrite with other tokenization passes
// since it changes the nodes after the matched root node
// This pass mustn't be registered in GraphRewrite with other tokenization passes because of 2 reasons:
// 1. It has higher priority than other tokenization passes
// 2. It changes the nodes after the matched root node
manager.register_pass<TokenizeMHASnippets>(m_config);

auto tokenization_passes = manager.register_pass<ov::pass::GraphRewrite>();
Expand Down
16 changes: 7 additions & 9 deletions src/common/snippets/src/utils/tokenization_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
op::update_out_tensor_name(subgraph);
};

auto abort_with_strategy = [&](const std::string& message_reset, const std::string& message_abort = "") {
auto abort = [&](const std::string& message) {
remark(3) << message << std::endl;
create_single_node_subgraph(node);
return true;
};
Expand Down Expand Up @@ -203,7 +204,7 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
// todo: In principle, we can still attach the node to the subgraph if cyclic dependency is introduced during ternary merge.
// Need to support.
if (cyclicDependencyIsIntoduced(to_replace_with, currentTopoBounds))
return abort_with_strategy("Attempt to perform recurrent merge for cyclic-dependent subgraphs. Aborting.");
return abort("Attempt to perform recurrent merge for cyclic-dependent subgraphs. Aborting.");
for (const auto& output : internal_consumers) {
for (auto consumer : output.get_target_inputs()) {
auto other_body = clones[subgraph->get_input_node_shared_ptr(i)];
Expand Down Expand Up @@ -260,7 +261,7 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
}

if (!ov::is_type<ov::op::v0::Parameter>(grandparent)) {
return abort_with_strategy("Convert supports only as Input and as Result of subgraph. Aborting");
return abort("Convert supports only as Input and as Result of subgraph. Aborting");
}
}
// Result op has a single input
Expand Down Expand Up @@ -288,7 +289,7 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
fusedNames += node->get_friendly_name();
num_result_children += get_num_result_children(node);
if (num_result_children > 1)
return abort_with_strategy("New subgraph is created since too many Result children are detected");
return abort("New subgraph is created since too many Result children are detected");

auto body_node = node->copy_with_new_inputs(internal_inputs);
body_node->set_friendly_name(node->get_friendly_name());
Expand Down Expand Up @@ -380,10 +381,7 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
const std::string message_reset = "new subgraph is created. Impossible to schedule subgraph with " +
std::to_string(body_parameters.size()) + " inputs, " + std::to_string(body_results.size()) + " outputs and " +
std::to_string(hidden_data_count) + " non-scalar constants and " + std::to_string(unique_buffer_count) + "buffers.";
const std::string message_abort = "failed to continue subgraph. Impossible to schedule subgraph with " +
std::to_string(body_parameters.size()) + " inputs, " + std::to_string(body_results.size()) + " outputs and " +
std::to_string(hidden_data_count) + " non-scalar constants and " + std::to_string(unique_buffer_count) + "buffers.";
return abort_with_strategy(message_reset, message_abort);
return abort(message_reset);
}

auto body = op::create_body(node->get_friendly_name(), body_results, body_parameters);
Expand All @@ -402,7 +400,7 @@ bool tokenize_node(const std::shared_ptr<ov::Node>& node, const SnippetsTokeniza
}

if (outputs_are_not_broadcastable(subgraph))
return abort_with_strategy("New subgraph is created due to outputs of a subgraph not broadcastable.");
return abort("New subgraph is created due to outputs of a subgraph not broadcastable.");

for (size_t i = 0; i < subgraph->get_output_size(); ++i) {
for (auto target_input : subgraph_result_inputs[i]) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/matmul.hpp"

#include "common_test_utils/test_constants.hpp"
#include "openvino/runtime/system_conf.hpp"

namespace ov {
namespace test {
namespace snippets {
namespace {
// Returns the (input0, input1) precision pairs used by the INT8 MatMul tests.
// The list is empty on platforms without VNNI/AMX-int8, so the parameterized
// suites that consume it are effectively skipped there.
static inline std::vector<std::vector<element::Type>> quantized_precisions() {
    std::vector<std::vector<element::Type>> precision_pairs;
    // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
    const bool int8_platform = ov::with_cpu_x86_avx512_core_vnni() || ov::with_cpu_x86_avx512_core_amx_int8();
    if (int8_platform) {
        precision_pairs.push_back({element::i8, element::i8});
        precision_pairs.push_back({element::u8, element::i8});
    }
    return precision_pairs;
}

// Returns the precision pairs to run: fp32 always; when only_fp32 == false,
// additionally the quantized pairs and bf16 where the platform supports them.
static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
    std::vector<std::vector<element::Type>> result{{element::f32, element::f32}};
    // Note: TPP doesn't support low precisions yet
#ifndef SNIPPETS_LIBXSMM_TPP
    if (!only_fp32) {
        for (const auto& quantized_pair : quantized_precisions())
            result.push_back(quantized_pair);
        // In Snippets MatMul BF16 is supported only on bf16/AMX platforms
        if (ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16())
            result.push_back({element::bf16, element::bf16});
    }
#endif
    return result;
}

// Basic FullyConnected case: dynamic rank-4 activations with K = 2500 and a
// static 2500x256 weight input (empty PartialShape => constant second input).
std::vector<std::vector<ov::test::InputShape>> fc_input_shapes{
    {
        {PartialShape{-1, -1, -1, 2500}, {{2, 1, 32, 2500}, {1, 3, 80, 2500}}},
        {{}, {{2500, 256}}}
    },
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnected, MatMul,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_input_shapes),
                                 ::testing::ValuesIn(precisions(false)),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

// FullyConnected followed by FakeQuantize (fp32-only precisions here).
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedFQ, MatMulFQ,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_input_shapes),
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

// FullyConnected followed by a chain of eltwise ops.
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedEltwiseChain, MatMulEltwiseChain,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_input_shapes),
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

// Two cascaded FullyConnected layers: 2500x128 weights feeding 128x64 weights.
std::vector<std::vector<ov::test::InputShape>> fc_cascade_shapes{
    {
        {PartialShape{-1, -1, -1, 2500}, {{2, 1, 32, 2500}, {1, 3, 80, 2500}, {2, 1, 32, 2500}}},
        {PartialShape{}, {{2500, 128}}},
        {PartialShape{}, {{128, 64}}},
    },
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedEltwiseChainCascade, MatMulEltwiseChainCascade,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_cascade_shapes),
                                 ::testing::ValuesIn(precisions()),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // Subgraph
                                 ::testing::Values(1), // Tokenized MatMul
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

// FullyConnected with transposed weights: 256x2500 instead of 2500x256.
std::vector<std::vector<ov::test::InputShape>> fc_transpose_b_shapes{
    {
        {PartialShape{-1, -1, -1, 2500}, {{2, 1, 32, 2500}, {1, 3, 80, 2500}}},
        {{}, {{256, 2500}}}
    },
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedTransposeB, MatMulTransposeB,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_transpose_b_shapes),
                                 ::testing::ValuesIn(precisions(false)),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // MatMul
                                 ::testing::Values(1), // Tokenized MatMul
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);


// FullyConnected with bias: activations, 2500x256 weights, and a dynamic
// rank-4 bias input broadcastable over the FC output.
std::vector<std::vector<ov::test::InputShape>> fc_bias_shapes{
    {
        {PartialShape{-1, -1, -1, 2500}, {{2, 1, 32, 2500}, {1, 3, 80, 2500}}},
        {{}, {{2500, 256}}},
        {PartialShape{-1, -1, -1, 256}, {{1, 1, 32, 256}, {1, 1, 80, 256}}}
    },
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedBias, MatMulBias,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_bias_shapes),
                                 ::testing::ValuesIn(precisions(false)),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // Subgraph
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedBiasQuantized, MatMulBiasQuantized,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_bias_shapes),
                                 ::testing::ValuesIn(quantized_precisions()),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // Subgraph
                                 ::testing::Values(1), // Tokenized MatMul+Bias
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

// Quantized MatMul chains: two constant weight inputs (2500x256 and 256x64).
std::vector<std::vector<ov::test::InputShape>> fc_quantized_shapes{
    {
        {PartialShape{-1, -1, -1, 2500}, {{2, 1, 32, 2500}, {1, 3, 80, 2500}}},
        {{}, {{2500, 256}}},
        {{}, {{256, 64}}}
    },
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedsQuantized, MatMulsQuantized,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_quantized_shapes),
                                 ::testing::ValuesIn(quantized_precisions()),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // Reshape on weights is folded => only 1 Subgraph remains
                                 ::testing::Values(1), // Tokenized [MatMul+FQ+Matmul]
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FullyConnectedsQuantizedSoftmax, MatMulsQuantizedSoftmax,
                         ::testing::Combine(
                                 ::testing::ValuesIn(fc_quantized_shapes),
                                 ::testing::ValuesIn(quantized_precisions()),
                                 ::testing::Values(MatMulType::FullyConnected),
                                 ::testing::Values(1), // Reshape on weights is folded => only 1 Subgraph remains
                                 ::testing::Values(1), // Tokenized [MatMul+FQ+Matmul]
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MatMul::getTestCaseName);
} // namespace
} // namespace snippets
} // namespace test
} // namespace ov
Loading

0 comments on commit 44e463d

Please sign in to comment.