[Snippets][CPU] Enabled dynamic INT8|BF16 MHA tokenization on non-AMX platforms #26547

Open · wants to merge 1 commit into base: master
@@ -978,9 +978,6 @@ void Transformations::MainSnippets(void) {
             return false;
         if (is_fp32)
             return true;
-        // Only FP32 dynamic MHA is supported
-        if (matmul->is_dynamic())
-            return false;
         // [114487] brgemm kernel in oneDNN requires brgemm_copy_b kernel if MatMul node has transposed_b=True
         // The current solution with ExtractExplicitMatMulTranspose pass is slower for non-f32 cases than using the brgemm_copy_b kernel
         if (matmul->get_transpose_a() || matmul->get_transpose_b())
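The three deleted lines were the last hard bail-out for dynamic shapes on the non-FP32 MHA path. A minimal standalone sketch of the callback's control flow after this patch (the struct and helper name are illustrative stand-ins, not the actual plugin code):

```cpp
#include <iostream>
#include <memory>

// Stand-in for ov::op::v0::MatMul with just the queries the callback uses.
struct MatMulLike {
    bool dynamic, transpose_a, transpose_b;
    bool is_dynamic() const { return dynamic; }
    bool get_transpose_a() const { return transpose_a; }
    bool get_transpose_b() const { return transpose_b; }
};

// Condensed view of the precision check after this PR: the
// `matmul->is_dynamic()` early return is gone, so dynamic INT8/BF16
// MatMuls fall through to the remaining checks.
bool is_supported_matmul(const std::shared_ptr<MatMulLike>& matmul, bool is_fp32) {
    if (!matmul)
        return false;
    if (is_fp32)
        return true;
    // [114487] transposed inputs still need the brgemm_copy_b kernel path
    if (matmul->get_transpose_a() || matmul->get_transpose_b())
        return false;
    return true;
}

int main() {
    auto m = std::make_shared<MatMulLike>(MatMulLike{true, false, false});
    std::cout << is_supported_matmul(m, /*is_fp32=*/false) << "\n";  // 1: dynamic is no longer rejected
}
```

With the dynamic check removed, only the transposed-input case still falls back for non-FP32 MHA, per ticket 114487.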
@@ -237,7 +237,6 @@ std::vector<std::string> disabledTestPatterns() {
         R"(.*smoke_FakeQuantize.*/FakeQuantizeLayerTest.Inference.*TS=.*3.4.2.5.*LEVELS=255.*)",
         R"(.*smoke_FakeQuantizePerChannel.*/FakeQuantizeLayerTest.Inference.*TS=.*11.10.22.19.*LEVELS=(255|256).*netPRC=f32.*)",
         R"(.*smoke_MVN_5D/Mvn6LayerTest.Inference.*TS=.*3.4.2.5.*LEVELS=255.*netPRC=f16.*)",
-        R"(.*smoke_Snippets_MHAINT8MatMul/MHAINT8MatMul.*)",
         R"(.*smoke_static/ConvertFqRnnToQuantizedRnn.*2.1.5.*2.1.1.*2.1.1.*)",
         R"(.*smoke_InterpolateBicubicPillow_Layout_Test/InterpolateLayerCPUTest.CompareWithRefs/ShapeCalcMode=sizes_IS=\[?.2..20.?.?\]_TS.*1.17.4.4.*2.3.10.12.*1.17.4.4.*Sizes.*4.4.*10.20.*10.4.*PARAMETER.*0.0.0.0.*0.0.1.1.*2.3.*)",
         R"(.*smoke_LoopForCommon/LoopLayerCPUTest.CompareWithRefs/.*_netType=bf16.*)",
@@ -560,14 +559,18 @@ std::vector<std::string> disabledTestPatterns() {
         // ignored for not supported bf16 platforms
         retVector.emplace_back(R"(.*smoke_Snippets_EnforcePrecision_bf16.*)");
         retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeEnforceBF16.*)");
-        retVector.emplace_back(R"(.*smoke_Snippets_MHAEnforceBF16.*)");
+        retVector.emplace_back(R"(.*smoke_Snippets_MHA.*EnforceBF16.*)");
     }
     // [150842] Need to support dynamic K dimension of BF16|INT8 MatMul on AMX systems
     if (ov::with_cpu_x86_avx512_core_amx()) {
         retVector.emplace_back(R"(.*smoke_Snippets_MatMul/MatMul.CompareWithRefImpl/.*IS\[0\]=\[2.2.70.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
         retVector.emplace_back(R"(.*smoke_Snippets_MatMul/MatMul.CompareWithRefImpl/.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
         retVector.emplace_back(R"(.*smoke_Snippets_MatMulTransposeB.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
         retVector.emplace_back(R"(.*smoke_Snippets_MatMulBias.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)");
+
+        retVector.emplace_back(R"(.*smoke_Snippets_MHA.*BF16.*/MHA.*IS\[0\]=\[(\?|1).(\?|4).(\?|12).(\?|64)\].*)");
+        retVector.emplace_back(R"(.*smoke_Snippets_MHA.*BF16.*/MHA.*IS\[0\]=\[\?.\?.\?\].*)");
+        retVector.emplace_back(R"(.*smoke_Snippets_(MHAINT8MatMul|MHAQuantMatMul0|MHAFQAfterMatMul_4D|smoke_Snippets_MHAFQ).*IS\[0\]=\[\?.\?.\?\.\?].*)");
     }
 #ifdef SNIPPETS_LIBXSMM_TPP
     // GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234)
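These skip entries are plain regular expressions matched against full gtest names. A self-contained sketch of how such a list is typically applied (the real harness code may differ):

```cpp
#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Returns true when the test name matches any disabled pattern.
bool is_disabled(const std::string& test_name, const std::vector<std::string>& patterns) {
    for (const auto& p : patterns) {
        if (std::regex_search(test_name, std::regex(p)))
            return true;
    }
    return false;
}

int main() {
    const std::vector<std::string> patterns = {
        R"(.*smoke_Snippets_MHA.*EnforceBF16.*)",  // the widened entry from this diff
    };
    // Prints 1: the single widened pattern also covers the old MHAEnforceBF16 suite.
    std::cout << is_disabled("smoke_Snippets_MHAEnforceBF16/MHA.CompareWithRefImpl", patterns) << "\n";
}
```

This is why the dedicated `MHAEnforceBF16` entry could be replaced by the single `MHA.*EnforceBF16` pattern.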

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/tests/functional/plugin/shared/include/snippets/mha.hpp
@@ -44,6 +44,7 @@ class MHABase : virtual public SnippetsTestsCommon {
     void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
     virtual std::shared_ptr<SnippetsFunctionBase> get_subgraph() const = 0;
    virtual void init_params(std::vector<InputShape>& input_shapes, ov::element::Type& prc, ov::AnyMap& additional_config) = 0;
+    virtual void init_thresholds();
 
     size_t m_thread_count;
     std::vector<ov::element::Type> m_input_types;
@@ -88,6 +89,7 @@ class MHATransposedB : public MHA {
 class MHAINT8MatMul : public MHA {
 protected:
     std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
+    void init_thresholds() override;
 };
 
 class MHAQuantMatMul0 : public MHA {
@@ -103,6 +105,7 @@ class MHAFQAfterMatMul : public MHA {
 class MHAFQ : public MHA {
 protected:
     std::shared_ptr<SnippetsFunctionBase> get_subgraph() const override;
+    void init_thresholds() override;
 };
 
 class MHAWithExtractedReshape : public MHA {
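The new `init_thresholds()` virtual gives each MHA test case a single hook to adjust comparison tolerances after the base defaults are applied. A compilable toy model of the hook's flow (the `*Sketch` types are illustrative stand-ins; the threshold values come from this diff):

```cpp
#include <iostream>

// Stand-in for MHABase: SetUp() ends by calling the overridable hook.
struct MHABaseSketch {
    double abs_threshold = 0.0;
    double rel_threshold = 0.0;
    bool bf16 = false;
    virtual ~MHABaseSketch() = default;
    virtual void init_thresholds() {
        if (bf16)
            rel_threshold = 0.05;  // base default for bf16 inference
    }
    void SetUp() { init_thresholds(); }
};

// Stand-in for MHAINT8MatMul: keep the base defaults, then tighten per test.
struct MHAINT8MatMulSketch : MHABaseSketch {
    void init_thresholds() override {
        MHABaseSketch::init_thresholds();
        abs_threshold = 4e-6;  // same value as the real override below
    }
};

int main() {
    MHAINT8MatMulSketch t;
    t.SetUp();
    std::cout << t.abs_threshold << "\n";  // 4e-06
}
```

The same pattern is used by `MHAFQ::init_thresholds()` with `abs_threshold = 0.016` in the .cpp diff that follows.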
20 changes: 17 additions & 3 deletions src/tests/functional/plugin/shared/src/snippets/mha.cpp
@@ -53,15 +53,19 @@ void MHABase::SetUp() {
         configuration.insert({"SNIPPETS_MODE", "IGNORE_CALLBACK"});
     }
 
-    setInferenceType(prc);
     inType = outType = prc;
+    setInferenceType(prc);
+    init_thresholds();
+}
+
+void MHABase::init_thresholds() {
     // Note: Libxsmm calculates Exp in a slightly different way, so the abs values might differ a bit. Ticket: 130699
 #ifdef SNIPPETS_LIBXSMM_TPP
     abs_threshold = 1e-6;
 #endif
-    if (prc == ov::element::bf16)
+    if (inType == ov::element::bf16)
         rel_threshold = 0.05f;
 }
 
 std::string MHA::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj) {
     std::vector<InputShape> input_shapes;
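A detail worth calling out (inferred from the diff itself): `inType` is now assigned before the hook runs because `init_thresholds()` reads `inType` to pick the bf16 relative threshold, whereas the old inline code checked the local `prc` variable.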
@@ -194,6 +198,11 @@ std::shared_ptr<SnippetsFunctionBase> MHAINT8MatMul::get_subgraph() const {
     return std::make_shared<ov::test::snippets::MHAINT8MatMulFunction>(inputDynamicShapes);
 }
 
+void MHAINT8MatMul::init_thresholds() {
+    MHABase::init_thresholds();
+    abs_threshold = 4e-6;
+}
+
 std::shared_ptr<SnippetsFunctionBase> MHAQuantMatMul0::get_subgraph() const {
     return std::make_shared<ov::test::snippets::MHAQuantMatMul0Function>(inputDynamicShapes);
 }
@@ -206,6 +215,11 @@ std::shared_ptr<SnippetsFunctionBase> MHAFQ::get_subgraph() const {
     return std::make_shared<ov::test::snippets::MHAFQFunction>(inputDynamicShapes);
 }
 
+void MHAFQ::init_thresholds() {
+    MHABase::init_thresholds();
+    abs_threshold = 0.016;
+}
+
 std::shared_ptr<SnippetsFunctionBase> MHAMulAdd::get_subgraph() const {
     return std::make_shared<ov::test::snippets::MHAMulAddFunction>(inputDynamicShapes);
 }
@@ -235,9 +235,7 @@ class MHAWOTransposeSplitMFunction : public MHAWOTransposeFunction {
 * FakeQuantize i8
 *  \ /
 *   Add
-* Reshape0
-* Softmax
-* Reshape1 Transpose2[0,2,1,3]
+* Softmax Transpose2[0,2,1,3]
 *  \ /
 *   MatMul1
 * FakeQuantize i8
@@ -261,9 +259,7 @@ class MHAFQAfterMatMulFunction : public SnippetsFunctionBase {
 * FakeQuantize i8
 *  \ /
 *   Add
-* Reshape0
-* Softmax
-* Reshape1 FakeQuantize i8
+* Softmax FakeQuantize i8
 * FakeQuantize u8 Transpose2[0,2,1,3]
 *  \ /
 *   MatMul1
@@ -281,20 +277,17 @@
 };
 
 /* Graph:
-* FakeQuantize i8 Reshape1
-* Reshape0 Transpose1[0,2,3,1]
+* FakeQuantize i8 Transpose1[0,2,3,1]
 * Transpose0[0,2,1,3] FakeQuantize i8
 *  \ /
 *   MatMul0
 *  \ /
-* Add Reshape2
+* Add
 * Softmax Transpose2[0,2,1,3]
 *  \ /
 *   MatMul1
 * FakeQuantize i8
 * Transpose3[0,2,1,3]
-* Reshape3
 * Note: Reshapes are to split Tokenization between FQs and deq Mul and MHA since Snippets::Ignore_Callback may be enabled
 */
class MHAQuantMatMul0Function : public SnippetsFunctionBase {
public: