diff --git a/src/bindings/js/node/lib/addon.ts b/src/bindings/js/node/lib/addon.ts index 060af2cfec92e8..24c9d780aa9f7e 100644 --- a/src/bindings/js/node/lib/addon.ts +++ b/src/bindings/js/node/lib/addon.ts @@ -21,6 +21,8 @@ type elementTypeString = | 'f32' | 'string'; +type OVAny = string | number | boolean; + /** * Core represents an OpenVINO runtime Core entity. * @@ -48,7 +50,7 @@ interface Core { compileModel( model: Model, deviceName: string, - config?: { [propertyName: string]: string }, + config?: Record, ): Promise; /** * Asynchronously reads a model and creates a compiled model @@ -67,7 +69,7 @@ interface Core { compileModel( modelPath: string, deviceName: string, - config?: { [propertyName: string]: string }, + config?: Record, ): Promise; /** * A synchronous version of {@link Core.compileModel}. @@ -76,7 +78,7 @@ interface Core { compileModelSync( model: Model, deviceName: string, - config?: { [propertyName: string]: string }, + config?: Record, ): CompiledModel; /** * A synchronous version of {@link Core.compileModel}. @@ -85,7 +87,7 @@ interface Core { compileModelSync( modelPath: string, deviceName: string, - config?: { [propertyName: string]: string }, + config?: Record, ): CompiledModel; /** * It returns a list of available inference devices. @@ -101,7 +103,7 @@ interface Core { * It gets the properties dedicated to device behaviour. * @param propertyName A property name. */ - getProperty(propertyName: string): string | number | boolean; + getProperty(propertyName: string): OVAny; /** * It gets the properties dedicated to device behaviour. @@ -111,7 +113,7 @@ interface Core { getProperty( deviceName: string, propertyName: string, - ): string | number | boolean; + ): OVAny; /** * It returns information on the version of device plugins. * @param deviceName A device name to identify a plugin. 
@@ -135,7 +137,7 @@ interface Core { importModel( modelStream: Buffer, device: string, - config?: { [key: string]: string | number | boolean }, + config?: Record, ): Promise; /** * A synchronous version of {@link Core.importModel}. @@ -144,7 +146,7 @@ interface Core { importModelSync( modelStream: Buffer, device: string, - config?: { [key: string]: string | number | boolean }, + config?: Record, ): CompiledModel; /** * It reads models from the IR / ONNX / PDPD / TF and TFLite formats. @@ -197,16 +199,13 @@ interface Core { * It sets the properties. * @param properties An object with the property name - property value pairs. */ - setProperty(properties: { [key: string]: string | number | boolean }): void; + setProperty(properties: Record): void; /** * It sets the properties for a device. * @param deviceName The name of a device. * @param properties An object with the property name - property value pairs. */ - setProperty( - deviceName: string, - properties: { [key: string]: string | number | boolean }, - ): void; + setProperty(deviceName: string, properties: Record): void; /** * It queries the device if it supports specified model with the specified * properties. @@ -218,8 +217,8 @@ interface Core { queryModel( model: Model, deviceName: string, - properties?: {[key: string]: string | number | boolean}, - ): {[key: string]: string | number | boolean}; + properties?: Record, + ): { [key: string]: string }; } interface CoreConstructor { new (): Core; @@ -325,7 +324,7 @@ interface CompiledModel { * @param propertyName A string to get the property value. * @returns The property value. */ - getProperty(propertyName: string): string | number | boolean; + getProperty(propertyName: string): OVAny; /** * It creates an inference request object used to infer the compiled model. * @return {InferRequest} @@ -380,9 +379,7 @@ interface CompiledModel { * @param property An object with the key-value pairs. 
* (property name, property value) */ - setProperty(properties: { - [propertyName: string]: string | number | boolean; - }): void; + setProperty(properties: Record): void; } /** diff --git a/src/bindings/js/node/tests/unit/core.test.js b/src/bindings/js/node/tests/unit/core.test.js index 6cf431a38b5030..f62adda9f90f9c 100644 --- a/src/bindings/js/node/tests/unit/core.test.js +++ b/src/bindings/js/node/tests/unit/core.test.js @@ -12,11 +12,11 @@ describe('ov.Core tests', () => { before(async () => { await isModelAvailable(testModels.testModelFP32); }); - + beforeEach(() => { core = new ov.Core(); }); - + it('Core.setProperty()', () => { const tmpDir = '/tmp'; @@ -83,29 +83,29 @@ describe('ov.Core tests', () => { it('Core.queryModel() with empty parameters should throw an error', () => { assert.throws( () => core.queryModel().then(), - /'queryModel' method called with incorrect parameters./ - ) + /'queryModel' method called with incorrect parameters./, + ); }); it('Core.queryModel() with less arguments should throw an error', () => { assert.throws( - () => core.queryModel("Unexpected Argument").then(), - /'queryModel' method called with incorrect parameters./ - ) + () => core.queryModel('Unexpected Argument').then(), + /'queryModel' method called with incorrect parameters./, + ); }); it('Core.queryModel() with incorrect arguments should throw an error', () => { const model = core.readModelSync(getModelPath().xml); assert.throws( - () => core.queryModel(model, "arg1", "arg2").then(), - /'queryModel' method called with incorrect parameters./ - ) + () => core.queryModel(model, 'arg1', 'arg2').then(), + /'queryModel' method called with incorrect parameters./, + ); }); it('Core.queryModel() should have device in the result values', () => { const model = core.readModelSync(getModelPath().xml); const device = 'CPU'; - const query_model = core.queryModel(model, device); - assert(Object.values(query_model).includes(device)); + const queryModel = core.queryModel(model, device); 
+ assert(Object.values(queryModel).includes(device)); }); }); diff --git a/src/common/snippets/include/snippets/utils/utils.hpp b/src/common/snippets/include/snippets/utils/utils.hpp index 7bad968866e3c2..f7e584d48a905c 100644 --- a/src/common/snippets/include/snippets/utils/utils.hpp +++ b/src/common/snippets/include/snippets/utils/utils.hpp @@ -74,9 +74,21 @@ constexpr inline bool implication(bool cause, bool cond) { } template -inline T div_up(const T a, const U b) { - OPENVINO_ASSERT(b != 0, "Divider must not be zero"); - return static_cast((a + b - 1) / b); +static inline auto div_up(const T lhs, const U rhs) -> decltype((lhs + rhs - 1) / rhs) { + OPENVINO_ASSERT(rhs != 0, "Divider must not be zero"); + if (((std::is_same::value || std::is_same::value) && utils::is_dynamic_value(lhs)) || + ((std::is_same::value || std::is_same::value) && utils::is_dynamic_value(rhs))) + return utils::get_dynamic_value(); + return (lhs + rhs - 1) / rhs; +} + +template +static inline auto rnd_up(const T lhs, const U rhs) -> decltype(div_up(lhs, rhs) * rhs) { + const T div_up_res = div_up(lhs, rhs); + if (((std::is_same::value || std::is_same::value) && utils::is_dynamic_value(div_up_res)) || + ((std::is_same::value || std::is_same::value) && utils::is_dynamic_value(rhs))) + return utils::get_dynamic_value(); + return div_up_res * rhs; } inline bool is_dynamic_vdims(const VectorDims& shape) { diff --git a/src/core/src/bound_evaluate.cpp b/src/core/src/bound_evaluate.cpp index 22b91a15e3dcee..f1c6a0601eea90 100644 --- a/src/core/src/bound_evaluate.cpp +++ b/src/core/src/bound_evaluate.cpp @@ -494,14 +494,12 @@ bool ov::interval_bound_evaluator(const Node* node, vector_of_output_variants.emplace_back(output.get_element_type(), output.get_shape()); } - node->evaluate(vector_of_output_variants, input_variant); + if (!node->evaluate(vector_of_output_variants, input_variant)) { + return false; + }; TensorVector vector_of_unsqueezed_output_variants; for (const auto& output : 
vector_of_output_variants) { - if (!output) { - return false; - } - auto unsqueezed_shape = output.get_shape(); unsqueezed_shape.insert(unsqueezed_shape.begin(), 1); diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp new file mode 100644 index 00000000000000..01af9dbde7fe01 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp @@ -0,0 +1,220 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#ifdef CPU_DEBUG_CAPS + +#include "debug_capabilities.hpp" +#include +#include + +namespace ov { +namespace intel_cpu { + +using namespace Xbyak; +using namespace dnnl::impl::cpu::x64; + +template void RegPrinter::print(jit_generator &h, Xmm reg, const char *name); +template void RegPrinter::print(jit_generator &h, Xmm reg, const char *name); +template void RegPrinter::print(jit_generator &h, Ymm reg, const char *name); +template void RegPrinter::print(jit_generator &h, Ymm reg, const char *name); +template void RegPrinter::print(jit_generator &h, Zmm reg, const char *name); +template void RegPrinter::print(jit_generator &h, Zmm reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg64 reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg64 reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg32 reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg32 reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg16 reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg16 reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); +template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); + +template +void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) { + std::stringstream 
ss; + if (name) ss << name << " | "; + ss << ori_name << ": "; + if (std::is_floating_point::value) { + ss << *ptr; + } else { + if (std::is_signed::value) { + ss << static_cast(*ptr); + } else { + ss << static_cast(*ptr); + } + } + ss << std::endl; + std::cout << ss.str(); +} + +template +void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) { + std::stringstream ss; + if (name) ss << name << " | "; + ss << ori_name << ": {" << ptr[0]; + for (size_t i = 1; i < vlen / sizeof(float); i++) { + ss << ", " << ptr[i]; + } + ss << "}" << std::endl; + std::cout << ss.str(); +} +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); + +template +struct vmm_traits{}; + +template <> +struct vmm_traits { + static constexpr size_t vmm_len = 16; + static constexpr size_t vmm_cnt = 16; +}; + +template <> +struct vmm_traits { + static constexpr size_t vmm_len = 32; + static constexpr size_t vmm_cnt = 16; +}; + +template <> +struct vmm_traits { + static constexpr size_t vmm_len = 64; + static constexpr size_t vmm_cnt = 32; +}; + +template +void RegPrinter::save_vmm(jit_generator &h) { + h.sub(h.rsp, vmm_traits::vmm_len * vmm_traits::vmm_cnt); + for (size_t i = 0; i < vmm_traits::vmm_cnt; i++) { + h.uni_vmovups(h.ptr[h.rsp + i * vmm_traits::vmm_len], T(i)); + } +} + +template +void RegPrinter::restore_vmm(jit_generator &h) { + for (size_t i = 0; i < vmm_traits::vmm_cnt; i++) { + h.uni_vmovups(T(i), h.ptr[h.rsp + i * vmm_traits::vmm_len]); + } + h.add(h.rsp, 
vmm_traits::vmm_len * vmm_traits::vmm_cnt); +} + +void RegPrinter::save_reg(jit_generator &h) { + h.sub(h.rsp, reg_len * reg_cnt); + for (size_t i = 0; i < reg_cnt; i++) { + h.mov(h.ptr[h.rsp + i * reg_len], Reg64(i)); + } +} + +void RegPrinter::restore_reg(jit_generator &h) { + for (size_t i = 0; i < reg_cnt; i++) { + h.mov(Reg64(i), h.ptr[h.rsp + i * reg_len]); + } + h.add(h.rsp, reg_len * reg_cnt); +} + +void RegPrinter::preamble(jit_generator &h) { + save_reg(h); + mayiuse(cpu_isa_t::avx512_core) ? save_vmm(h) : (mayiuse(cpu_isa_t::avx2) ? + save_vmm(h) : save_vmm(h)); +} + +void RegPrinter::postamble(jit_generator &h) { + mayiuse(cpu_isa_t::avx512_core) ? restore_vmm(h) : (mayiuse(cpu_isa_t::avx2) ? + restore_vmm(h) : restore_vmm(h)); + restore_reg(h); +} + +// ABI requires 16-bype stack alignment before a call +void RegPrinter::align_rsp(jit_generator &h) { + constexpr int alignment = 16; + h.mov(h.r15, h.rsp); + h.and_(h.rsp, ~(alignment - 1)); +} + +void RegPrinter::restore_rsp(jit_generator &h) { + h.mov(h.rsp, h.r15); +} + +template +void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { + preamble(h); + + h.push(h.rax); + h.push(abi_param1); + h.push(abi_param2); + h.push(abi_param3); + { + const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 
32 : 16); + h.sub(h.rsp, vlen); + h.uni_vmovups(h.ptr[h.rsp], vmm); + + h.mov(abi_param3, h.rsp); + h.mov(abi_param2, reinterpret_cast(vmm.toString())); + h.mov(abi_param1, reinterpret_cast(name)); + if (vmm.isZMM()) { + auto p = &print_vmm_prc; + h.mov(h.rax, reinterpret_cast(p)); + } else if (vmm.isYMM()) { + auto p = &print_vmm_prc; + h.mov(h.rax, reinterpret_cast(p)); + } else { + auto p = &print_vmm_prc; + h.mov(h.rax, reinterpret_cast(p)); + } + align_rsp(h); + h.call(h.rax); + restore_rsp(h); + + h.add(h.rsp, vlen); + } + + h.pop(abi_param3); + h.pop(abi_param2); + h.pop(abi_param1); + h.pop(h.rax); + + postamble(h); +} + +template +void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) { + preamble(h); + + h.push(h.rax); + h.push(abi_param1); + h.push(abi_param2); + h.push(abi_param3); + { + const int rlen = reg.getBit() / 8; + h.sub(h.rsp, rlen); + h.mov(h.ptr[h.rsp], reg); + + h.mov(abi_param3, h.rsp); + h.mov(abi_param2, reinterpret_cast(reg.toString())); + h.mov(abi_param1, reinterpret_cast(name)); + auto p = &print_reg_prc; + h.mov(h.rax, reinterpret_cast(p)); + align_rsp(h); + h.call(h.rax); + restore_rsp(h); + + h.add(h.rsp, rlen); + } + + h.pop(abi_param3); + h.pop(abi_param2); + h.pop(abi_param1); + h.pop(h.rax); + + postamble(h); +} + +} // namespace intel_cpu +} // namespace ov + + +#endif // CPU_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp new file mode 100644 index 00000000000000..fd7135b17bf5b9 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp @@ -0,0 +1,97 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#ifdef CPU_DEBUG_CAPS + +#include "cpu/x64/jit_generator.hpp" + +namespace ov { +namespace intel_cpu { + +// Usage +// 1. 
Include this header file where JIT kernels of the CPU plugin are implemented, to enable register printing +// 2. Invoke the RegPrinter::print method. Here are some examples. Note that a user-friendly register name +// will be printed if it has been set. The current implementation doesn't buffer the name, so if you +// choose to set a name for the register, do not pass it via a local variable; just pass a +// string literal to the interface, as in the examples. The original Xbyak register name is always +// printed. +// Example 1: +// Invocation: RegPrinter::print(*this, vmm_val, "vmm_val"); +// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23} +// +// Example 2: +// Invocation: RegPrinter::print(*this, vmm_val); +// Console: ymm0: {30, 20, 25, 29, 24, 31, 27, 23} +// +// Example 3: +// Invocation: RegPrinter::print(*this, vmm_idx, "vmm_idx"); +// Console: vmm_idx | ymm1: {5, 6, 0, 2, 0, 6, 6, 6} +// +// Example 4: +// Invocation: RegPrinter::print(*this, reg_work_amount, "reg_work_amount"); +// Console: reg_work_amount | r13: 8 +// +// Example 5: +// Invocation: RegPrinter::print(*this, reg_work_amount); +// Console: r13: 8 +// +// Example 6: +// Invocation: RegPrinter::print(*this, reg_tmp_64, "reg_tmp_64"); +// Console: reg_tmp_64 | r15: 1 +// +// Parameter +// The following combinations of Register types and precisions are supported. 
+// fp32 int32 int8 u8 +// Xmm Yes Yes No No +// Ymm Yes Yes No No +// Zmm Yes Yes No No +// Reg64 Yes Yes No No +// Reg32 Yes Yes No No +// Reg16 No No Yes Yes +// Reg8 No No Yes Yes + +class RegPrinter { +public: + using jit_generator = dnnl::impl::cpu::x64::jit_generator; + template ::value, int>::type = 0> + static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { + print_vmm(h, reg, name); + } + template ::value, int>::type = 0> + static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { + print_reg(h, reg, name); + } + +private: + RegPrinter() {} + template + static void print_vmm(jit_generator &h, REG_T vmm, const char *name); + template + static void print_reg(jit_generator &h, REG_T reg, const char *name); + template + static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr); + template + static void print_reg_prc(const char *name, const char *ori_name, T *val); + static void preamble(jit_generator &h); + static void postamble(jit_generator &h); + template + static void save_vmm(jit_generator &h); + template + static void restore_vmm(jit_generator &h); + static void save_reg(jit_generator &h); + static void restore_reg(jit_generator &h); + static void align_rsp(jit_generator &h); + static void restore_rsp(jit_generator &h); + static constexpr size_t reg_len = 8; + static constexpr size_t reg_cnt = 16; +}; + +} // namespace intel_cpu +} // namespace ov + +#endif // CPU_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp index bc1bdfadae8808..acbb04ea01af80 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp @@ -214,81 +214,5 @@ void jit_emitter::emit_code(const std::vector &in_idxs, const std::vecto emitter_postamble(); } -void jit_emitter::internal_call_preamble() const { - // gprs - Xbyak::Operand 
gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, - h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp}; - size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); - - h->sub(h->rsp, n_gprs_to_save * gpr_size); - for (size_t i = 0; i < n_gprs_to_save; ++i) - h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]); - - // mask regs - // need preserve based on cpu capability, instead of host isa. - // in case there are possibilty that different isa emitters exist in one subgraph KernelEmitter from perf standpoint in the future. - // e.g. other emitters isa is avx512, while this emitter isa is avx2, and internal call is used. Internal call may use avx512 and spoil k-reg. - // do not care about platform w/ avx512_common but w/o avx512_core(knight landing), which is obsoleted. - if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) { - h->sub(h->rsp, k_mask_num * k_mask_size); - for (size_t i = 0; i < k_mask_num; ++i) { - h->kmovq(h->ptr[h->rsp + i * k_mask_size], Xbyak::Opmask(static_cast(i))); - } - } - - // vector regs - // 1. Caller obligation to save vector registers as callee may use them. - // 2. There is an implicit assumption that the host code uses the same - // `isa` as the injector. Once the assumption is wrong, `vecs_count` and - // `vlen` should be replaced with `host_isa::vlen` and - // `host_isa::vecs_count`. 
- h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); - for (size_t i = 0; i < get_max_vecs_count(); ++i) { - push_vec(h->ptr[h->rsp + i * get_vec_length()], i); - } -} - -void jit_emitter::internal_call_postamble() const { - // restore vector registers - for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { - pop_vec(static_cast(i), h->ptr[h->rsp + i * get_vec_length()]); - } - h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); - - // restore k reg - if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) { - for (int i = k_mask_num - 1; i >= 0; --i) { - h->kmovq(Xbyak::Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - } - h->add(h->rsp, k_mask_num * k_mask_size); - } - - // restore gpr registers - Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, - h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp}; - size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); - for (int i = n_gprs_to_save - 1; i >= 0; --i) - h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); - h->add(h->rsp, n_gprs_to_save * gpr_size); -} - -void jit_emitter::internal_call_rsp_align() const { - h->mov(h->rbx, h->rsp); - h->and_(h->rbx, 0xf); - h->sub(h->rsp, h->rbx); -#ifdef _WIN32 - // Allocate shadow space (home space) according to ABI - h->sub(h->rsp, 32); -#endif -} - -void jit_emitter::internal_call_rsp_restore() const { -#ifdef _WIN32 - // Release shadow space (home space) - h->add(h->rsp, 32); -#endif - h->add(h->rsp, h->rbx); -} - } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp index ea12593ece1ab6..c5729613f1bfe5 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp @@ -144,13 +144,6 @@ class jit_emitter : public ov::snippets::Emitter { } } - void 
internal_call_preamble() const; - void internal_call_postamble() const; - // align stack on 16-byte and allocate shadow space as ABI reqiures - // callee is responsible to save and restore rbx. rbx must not be changed after call callee. - void internal_call_rsp_align() const; - void internal_call_rsp_restore() const; - virtual void validate_arguments(const std::vector&, const std::vector&) const {} #ifdef SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp index 44130010b551cd..ea16122f2f9793 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp @@ -3,8 +3,8 @@ // #include "utils.hpp" -#include -#include + +#include "emitters/utils.hpp" namespace ov { namespace intel_cpu { @@ -12,203 +12,113 @@ namespace intel_cpu { using namespace Xbyak; using namespace dnnl::impl::cpu::x64; -template void RegPrinter::print(jit_generator &h, Xmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Xmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Ymm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Ymm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Zmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Zmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg64 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg64 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg32 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg32 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg16 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg16 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); 
-template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); - -template -void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) { - std::stringstream ss; - if (name) ss << name << " | "; - ss << ori_name << ": "; - if (std::is_floating_point::value) { - ss << *ptr; - } else { - if (std::is_signed::value) { - ss << static_cast(*ptr); - } else { - ss << static_cast(*ptr); - } - } - ss << std::endl; - std::cout << ss.str(); -} +EmitABIRegSpills::EmitABIRegSpills(jit_generator* h) : h(h), isa(get_isa()) {} -template -void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) { - std::stringstream ss; - if (name) ss << name << " | "; - ss << ori_name << ": {" << ptr[0]; - for (size_t i = 1; i < vlen / sizeof(float); i++) { - ss << ", " << ptr[i]; - } - ss << "}" << std::endl; - std::cout << ss.str(); -} -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); - -template -struct vmm_traits{}; - -template <> -struct vmm_traits { - static constexpr size_t vmm_len = 16; - static constexpr size_t vmm_cnt = 16; -}; - -template <> -struct vmm_traits { - static constexpr size_t vmm_len = 32; - static constexpr size_t vmm_cnt = 16; -}; - -template <> -struct vmm_traits { - static constexpr size_t vmm_len = 64; - static constexpr size_t vmm_cnt = 32; -}; - -template -void RegPrinter::save_vmm(jit_generator &h) { - h.sub(h.rsp, vmm_traits::vmm_len * vmm_traits::vmm_cnt); - for (size_t i = 0; i < vmm_traits::vmm_cnt; i++) 
{ - h.uni_vmovups(h.ptr[h.rsp + i * vmm_traits::vmm_len], T(i)); - } +EmitABIRegSpills::~EmitABIRegSpills() { + OPENVINO_ASSERT(spill_status, "postamble or preamble is missed"); + OPENVINO_ASSERT(rsp_status, "rsp_align or rsp_restore is missed"); } -template -void RegPrinter::restore_vmm(jit_generator &h) { - for (size_t i = 0; i < vmm_traits::vmm_cnt; i++) { - h.uni_vmovups(T(i), h.ptr[h.rsp + i * vmm_traits::vmm_len]); - } - h.add(h.rsp, vmm_traits::vmm_len * vmm_traits::vmm_cnt); -} +void EmitABIRegSpills::preamble() { + // gprs + Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, + h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp}; + size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); -void RegPrinter::save_reg(jit_generator &h) { - h.sub(h.rsp, reg_len * reg_cnt); - for (size_t i = 0; i < reg_cnt; i++) { - h.mov(h.ptr[h.rsp + i * reg_len], Reg64(i)); - } -} + h->sub(h->rsp, n_gprs_to_save * gpr_size); + for (size_t i = 0; i < n_gprs_to_save; ++i) + h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]); -void RegPrinter::restore_reg(jit_generator &h) { - for (size_t i = 0; i < reg_cnt; i++) { - h.mov(Reg64(i), h.ptr[h.rsp + i * reg_len]); + if (isa == avx512_core) { + h->sub(h->rsp, k_mask_num * k_mask_size); + for (size_t i = 0; i < k_mask_num; ++i) { + h->kmovq(h->ptr[h->rsp + i * k_mask_size], Xbyak::Opmask(static_cast(i))); + } } - h.add(h.rsp, reg_len * reg_cnt); -} -void RegPrinter::preamble(jit_generator &h) { - save_reg(h); - mayiuse(cpu_isa_t::avx512_core) ? save_vmm(h) : (mayiuse(cpu_isa_t::avx2) ? - save_vmm(h) : save_vmm(h)); -} - -void RegPrinter::postamble(jit_generator &h) { - mayiuse(cpu_isa_t::avx512_core) ? restore_vmm(h) : (mayiuse(cpu_isa_t::avx2) ? 
- restore_vmm(h) : restore_vmm(h)); - restore_reg(h); -} - -// ABI requires 16-bype stack alignment before a call -void RegPrinter::align_rsp(jit_generator &h) { - constexpr int alignment = 16; - h.mov(h.r15, h.rsp); - h.and_(h.rsp, ~(alignment - 1)); -} + h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); + for (size_t i = 0; i < get_max_vecs_count(); ++i) { + const auto addr = h->ptr[h->rsp + i * get_vec_length()]; + if (isa == sse41) { + h->uni_vmovups(addr, Xmm(i)); + } else if (isa == avx2) { + h->uni_vmovups(addr, Ymm(i)); + } else { + h->uni_vmovups(addr, Zmm(i)); + } + } -void RegPrinter::restore_rsp(jit_generator &h) { - h.mov(h.rsp, h.r15); + // Update the status + spill_status = false; } -template -void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { - preamble(h); - - h.push(h.rax); - h.push(abi_param1); - h.push(abi_param2); - h.push(abi_param3); - { - const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16); - h.sub(h.rsp, vlen); - h.uni_vmovups(h.ptr[h.rsp], vmm); - - h.mov(abi_param3, h.rsp); - h.mov(abi_param2, reinterpret_cast(vmm.toString())); - h.mov(abi_param1, reinterpret_cast(name)); - if (vmm.isZMM()) { - auto p = &print_vmm_prc; - h.mov(h.rax, reinterpret_cast(p)); - } else if (vmm.isYMM()) { - auto p = &print_vmm_prc; - h.mov(h.rax, reinterpret_cast(p)); +void EmitABIRegSpills::postamble() { + // restore vector registers + for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { + const auto addr = h->ptr[h->rsp + i * get_vec_length()]; + if (isa == sse41) { + h->uni_vmovups(Xmm(i), addr); + } else if (isa == avx2) { + h->uni_vmovups(Ymm(i), addr); } else { - auto p = &print_vmm_prc; - h.mov(h.rax, reinterpret_cast(p)); + h->uni_vmovups(Zmm(i), addr); } - align_rsp(h); - h.call(h.rax); - restore_rsp(h); + } + h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); - h.add(h.rsp, vlen); + // restore k reg + if (isa == avx512_core) { + for (int i = k_mask_num - 1; i >= 0; --i) { + 
h->kmovq(Xbyak::Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + } + h->add(h->rsp, k_mask_num * k_mask_size); } - h.pop(abi_param3); - h.pop(abi_param2); - h.pop(abi_param1); - h.pop(h.rax); + // restore gpr registers + Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, + h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp}; + size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); + for (int i = n_gprs_to_save - 1; i >= 0; --i) + h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); + h->add(h->rsp, n_gprs_to_save * gpr_size); - postamble(h); + // Update the status + spill_status = true; } -template -void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) { - preamble(h); - - h.push(h.rax); - h.push(abi_param1); - h.push(abi_param2); - h.push(abi_param3); - { - const int rlen = reg.getBit() / 8; - h.sub(h.rsp, rlen); - h.mov(h.ptr[h.rsp], reg); - - h.mov(abi_param3, h.rsp); - h.mov(abi_param2, reinterpret_cast(reg.toString())); - h.mov(abi_param1, reinterpret_cast(name)); - auto p = &print_reg_prc; - h.mov(h.rax, reinterpret_cast(p)); - align_rsp(h); - h.call(h.rax); - restore_rsp(h); - - h.add(h.rsp, rlen); - } +void EmitABIRegSpills::rsp_align() { + h->mov(h->rbx, h->rsp); + h->and_(h->rbx, 0xf); + h->sub(h->rsp, h->rbx); +#ifdef _WIN32 + // Allocate shadow space (home space) according to ABI + h->sub(h->rsp, 32); +#endif + + // Update the status + rsp_status = false; +} + +void EmitABIRegSpills::rsp_restore() { +#ifdef _WIN32 + // Release shadow space (home space) + h->add(h->rsp, 32); +#endif + h->add(h->rsp, h->rbx); - h.pop(abi_param3); - h.pop(abi_param2); - h.pop(abi_param1); - h.pop(h.rax); + // Update the status + rsp_status = true; +} - postamble(h); +cpu_isa_t EmitABIRegSpills::get_isa() { + // need preserve based on cpu capability, instead of host isa. + // in case there are possibilty that different isa emitters exist in one kernel from perf standpoint in the future. 
+ // e.g. other emitters isa is avx512, while this emitter isa is avx2, and internal call is used. Internal call may use avx512 and spoil k-reg, ZMM. + // do not care about platform w/ avx512_common but w/o avx512_core(knight landing), which is obsoleted. + if (mayiuse(avx512_core)) return avx512_core; + if (mayiuse(avx2)) return avx2; + if (mayiuse(sse41)) return sse41; + OV_CPU_JIT_EMITTER_THROW("unsupported isa"); } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp index a74ad6af3df5c2..16a66beba7a536 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp @@ -9,84 +9,39 @@ namespace ov { namespace intel_cpu { -// Usage -// 1. Include this headfile where JIT kennels of CPU plugin are implemented for Register printing -// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name -// will be printed, if it has been set. Current implementation doesn't buffer the name. So if you -// choose to set a name for the register, do not use local variable to pass the name, just pass a -// direct string to the interface like examples. While Original Xbyak register name will always be -// printed. 
-// Example 1: -// Invocation: RegPrinter::print(*this, vmm_val, "vmm_val"); -// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23} -// -// Example 2: -// Invocation: RegPrinter::print(*this, vmm_val); -// Console: ymm0: {30, 20, 25, 29, 24, 31, 27, 23} -// -// Example 3: -// Invocation: RegPrinter::print(*this, vmm_idx, "vmm_idx"); -// Console: vmm_idx | ymm1: {5, 6, 0, 2, 0, 6, 6, 6} -// -// Example 4: -// Invocation: RegPrinter::print(*this, reg_work_amount, "reg_work_amount"); -// Console: reg_work_amount | r13: 8 -// -// Example 5: -// Invocation: RegPrinter::print(*this, reg_work_amount); -// Console: r13: 8 -// -// Example 6: -// Invocation: RegPrinter::print(*this, reg_tmp_64, "reg_tmp_64"); -// Console: reg_tmp_64 | r15: 1 -// -// Parameter -// The following combinations of Register types and precisions are supported. -// fp32 int32 int8 u8 -// Xmm Yes Yes No No -// Ymm Yes Yes No No -// Zmm Yes Yes No No -// Reg64 Yes Yes No No -// Reg32 Yes Yes No No -// Reg16 No No Yes Yes -// Reg8 No No Yes Yes - -class RegPrinter { +// The class emit register spills for the possible call of external binary code +class EmitABIRegSpills { public: - using jit_generator = dnnl::impl::cpu::x64::jit_generator; - template ::value, int>::type = 0> - static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { - print_vmm(h, reg, name); - } - template ::value, int>::type = 0> - static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { - print_reg(h, reg, name); - } + EmitABIRegSpills(dnnl::impl::cpu::x64::jit_generator* h); + ~EmitABIRegSpills(); + + // push (save) all registers on the stack + void preamble(); + // pop (take) all registers from the stack + void postamble(); + + // align stack on 16-byte and allocate shadow space as ABI reqiures + // callee is responsible to save and restore `rbx`. `rbx` must not be changed after call callee. 
+ void rsp_align(); + void rsp_restore(); private: - RegPrinter() {} - template - static void print_vmm(jit_generator &h, REG_T vmm, const char *name); - template - static void print_reg(jit_generator &h, REG_T reg, const char *name); - template - static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr); - template - static void print_reg_prc(const char *name, const char *ori_name, T *val); - static void preamble(jit_generator &h); - static void postamble(jit_generator &h); - template - static void save_vmm(jit_generator &h); - template - static void restore_vmm(jit_generator &h); - static void save_reg(jit_generator &h); - static void restore_reg(jit_generator &h); - static void align_rsp(jit_generator &h); - static void restore_rsp(jit_generator &h); - static constexpr size_t reg_len = 8; - static constexpr size_t reg_cnt = 16; + EmitABIRegSpills() = default; + + static dnnl::impl::cpu::x64::cpu_isa_t get_isa(); + + inline size_t get_max_vecs_count() const { return dnnl::impl::cpu::x64::isa_num_vregs(isa); } + inline size_t get_vec_length() const { return dnnl::impl::cpu::x64::isa_max_vlen(isa); } + + dnnl::impl::cpu::x64::jit_generator* h {nullptr}; + const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::cpu_isa_t::isa_undef}; + + static constexpr int k_mask_size = 8; + static constexpr int k_mask_num = 8; + static constexpr int gpr_size = 8; + + bool spill_status = true; + bool rsp_status = true; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp index cc640f89554ef2..2cfd6e714e1dd8 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp @@ -239,11 +239,13 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho jitters[snippets::op::KernelDynamic::get_type_info_static()] = 
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_dynamic_emitter); jitters[snippets::op::LoopBegin::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_begin_emitter); jitters[snippets::op::LoopEnd::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_end_emitter); - // Note: jit_brgemm_emitter supports runtime recompilation, so its constructor takes additional arguments + // Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes additional arguments jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter, configurator->get_kernel_executor_table(), compiled_kernel_cache); - jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter); + jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter, + configurator->get_kernel_executor_table(), + compiled_kernel_cache); jitters[snippets::op::ReduceMax::get_type_info_static()] = CREATE_UNDEFINED_EMITTER({{ov::element::f32}}); jitters[snippets::op::ReduceSum::get_type_info_static()] = CREATE_UNDEFINED_EMITTER({{ov::element::f32}}); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp index 8d2f33b18513f6..e68ab224407c7b 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp @@ -4,7 +4,8 @@ #include "jit_brgemm_copy_b_emitter.hpp" -#include "jit_brgemm_emitter.hpp" +#include "emitters/plugin/x64/utils.hpp" +#include "emitters/snippets/x64/utils.hpp" #include "snippets/utils/utils.hpp" #include "snippets/lowered/expression.hpp" @@ -23,103 +24,50 @@ using namespace ov::snippets::utils; namespace ov { namespace intel_cpu { 
-jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) +namespace { +bool get_is_transposed(const ov::snippets::lowered::ExpressionPtr& expr) { + const auto& layout = expr->get_input_port_descriptor(0)->get_layout(); + const auto is_transposed = !layout.empty() && layout.back() != layout.size() - 1; + OV_CPU_JIT_EMITTER_ASSERT(IMPLICATION(is_transposed, (layout[layout.size() - 2] == layout.size() - 1)), + "supports only N dim placed as last or pre last dimension"); + return is_transposed; +} +} // namespace + +jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) : jit_emitter(h, isa) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; const auto brgemm_repack = ov::as_type_ptr(expr->get_node()); - if (!brgemm_repack) - OV_CPU_JIT_EMITTER_THROW("expects BrgemmCopyB node"); - OV_CPU_JIT_EMITTER_ASSERT(is_superset(host_isa_, cpu::x64::avx2), "host_isa must be at least avx2"); - m_with_comp = with_compensations(brgemm_repack->get_type()); - m_in_offset = brgemm_repack->get_offset_in(); - m_out_offset = brgemm_repack->get_offset_out(); - if (m_with_comp) - m_comp_offset = brgemm_repack->get_offset_compensations(); - - const auto& in_desc = expr->get_input_port_descriptor(0); - const auto& original_shape = in_desc->get_shape(); - const auto& layout = in_desc->get_layout(); - m_transpose = !layout.empty() && layout.back() != layout.size() - 1; - if (m_transpose) - OPENVINO_ASSERT(layout[layout.size() - 2] == layout.size() - 1, "supports only N dim placed as last or pre last dimension"); - - const auto planar_shape = get_planar_vdims(original_shape, layout); - const size_t N = *planar_shape.rbegin(); - m_K = *++planar_shape.rbegin(); - OV_CPU_JIT_EMITTER_ASSERT(!is_dynamic_value(N) && 
!is_dynamic_value(m_K), "K and N dims must be static"); + OV_CPU_JIT_EMITTER_ASSERT(brgemm_repack, "expects BrgemmCopyB node"); - const auto& in_subtensor = get_projected_subtensor(expr->get_input_port(0)); - m_N_blk = *in_subtensor.rbegin(); - m_K_blk = *++in_subtensor.rbegin(); - OV_CPU_JIT_EMITTER_ASSERT(m_N_blk <= N && m_K_blk <= m_K, "BrgemmCopyB has incompatible subtensor dimensions"); - m_brg_weight_etype = brgemm_repack->get_input_element_type(0); - m_inner_N_block = repacking::compute_inner_n_block(m_brg_weight_etype); - m_inner_N_tail = m_N_blk % m_inner_N_block; - m_brgemmVNNIFactor = compute_vnni_factor(m_brg_weight_etype); - - OV_CPU_JIT_EMITTER_ASSERT(m_brgemmVNNIFactor > 0, "brgemmVNNIFactor value must be positive."); - OV_CPU_JIT_EMITTER_ASSERT(m_K_blk == m_K || m_K_blk % m_brgemmVNNIFactor == 0, - "K Block size (", m_K_blk, "), which is not divisible by brgemmVNNIFactor (", - m_brgemmVNNIFactor, ") and not equal to K dimension (", m_K, - "), is not supported for brgemm data repacking."); - - OV_CPU_JIT_EMITTER_ASSERT(get_projected_subtensor(expr->get_output_port(0)) == in_subtensor, - "output and input subtensors must be equal"); - if (m_with_comp) { - const auto& compensations_subtensor = get_projected_subtensor(expr->get_output_port(1)); - const auto& compensations_n = *compensations_subtensor.rbegin(); - const auto& compensations_k = *++compensations_subtensor.rbegin(); - OV_CPU_JIT_EMITTER_ASSERT(compensations_n == m_N_blk && compensations_k == 1, - "compensations subtensor must be {1, m_N_blk}"); - } + // Note: even if the BrgemmCopyB node is dynamic, the first shapeInfer and RuntimeConfigurator::update() + // are performed before the BrgemmCopyBKernelExecutor registration. So we have to trigger update() manually + // for both static and the 1st dynamic shapes. 
+ OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()), + "Jit emitter is called when the shapes are unknown"); - const auto& brg_src_etype = brgemm_repack->get_src_element_type(); - OV_CPU_JIT_EMITTER_ASSERT(one_of(m_brg_weight_etype, element::f32, element::bf16, element::i8), "doesn't support precision ", m_brg_weight_etype); + const auto& in_subtensor = get_projected_subtensor(expr->get_input_port(0)); + const auto K_blk = *++in_subtensor.rbegin(); - const auto brgemm_type = get_brgemm_type(brg_src_etype, m_K_blk, m_N_blk, m_transpose); + const auto& src_prc = brgemm_repack->get_src_element_type(); + const auto& wei_prc = brgemm_repack->get_input_element_type(0); + const auto wei_N_blk = brgemm_utils::repacking::compute_inner_n_block(wei_prc); + const auto is_transposed = get_is_transposed(expr); + const auto brgemm_type = get_brgemm_type(src_prc, K_blk, is_transposed); + const auto primitive_isa = brgemm_utils::get_primitive_isa(src_prc, with_amx(brgemm_type)); + m_with_comp = with_compensations(brgemm_type); - const auto ldb = repacking::compute_out_leading_dim(m_N_blk, m_brg_weight_etype); - const auto wei_stride = get_dim_stride(expr->get_input_port(0), m_transpose ? 0 : 1) * m_brg_weight_etype.size(); - // Note: 2D format tags are used just to force the needed OneDNN primitive creation. - // However, the generated primitive can be also applied to tensors with other ranks - const auto format = m_transpose ? 
dnnl_ba : dnnl_ab; - init_brgemm_copy(m_kernel, N, m_inner_N_block, m_inner_N_tail, ldb, m_K_blk, brgemm_type, brg_src_etype, m_brg_weight_etype, wei_stride, format); -} + BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, primitive_isa, m_with_comp, is_transposed, wei_N_blk); + m_kernel_executor = kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); -void jit_brgemm_copy_b_emitter::init_brgemm_copy(std::unique_ptr& kernel, - size_t N, size_t N_blk, size_t N_tail, size_t out_leading_dim, size_t K_blk, BRGEMM_TYPE brgemm_type, - const ov::element::Type& src_dt, const ov::element::Type& wei_dt, size_t wei_stride, - dnnl_format_tag_t format) const { - matmul::brgemm_matmul_conf_t brgCopyKernelConf; - brgCopyKernelConf.src_dt = static_cast(DnnlExtensionUtils::ElementTypeToDataType(src_dt)); - brgCopyKernelConf.wei_dt = static_cast(DnnlExtensionUtils::ElementTypeToDataType(wei_dt)); - brgCopyKernelConf.orig_wei_dt = brgCopyKernelConf.wei_dt; - brgCopyKernelConf.wei_n_blk = static_cast(N_blk); - brgCopyKernelConf.wei_tag = format; - brgCopyKernelConf.transposed_B = m_transpose; - brgCopyKernelConf.copy_B_wei_stride = wei_stride; - brgCopyKernelConf.LDB = static_cast(out_leading_dim); - brgCopyKernelConf.N = static_cast(N); - brgCopyKernelConf.N_tail = static_cast(N_tail); - brgCopyKernelConf.N_blk = static_cast(N_blk); - brgCopyKernelConf.K = static_cast(K_blk); - brgCopyKernelConf.K_blk = static_cast(K_blk); - brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; - brgCopyKernelConf.b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); - brgCopyKernelConf.tr_b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); - - brgCopyKernelConf.req_wei_vnni_downconvert = false; - - - brgCopyKernelConf.isa = get_primitive_isa(src_dt, with_amx(brgemm_type)); - brgCopyKernelConf.s8s8_compensation_required = with_compensations(brgemm_type); - - brgCopyKernelConf.has_zero_point_a = false; - 
brgCopyKernelConf.has_zero_point_b = false; - brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none; - - auto status = matmul::create_brgemm_matmul_copy_b(kernel, &brgCopyKernelConf); - OV_CPU_JIT_EMITTER_ASSERT(status == dnnl_success, "cannot create kernel due to invalid params"); + m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()}; + m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), utils::get_buffer_cluster_id(expr->get_output_port(0))}; + if (m_with_comp) { + m_memory_offsets.push_back(brgemm_repack->get_offset_compensations()); + m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_output_port(1))); + } } void jit_brgemm_copy_b_emitter::validate_arguments(const std::vector &in, const std::vector &out) const { @@ -130,107 +78,45 @@ void jit_brgemm_copy_b_emitter::validate_arguments(const std::vector &in void jit_brgemm_copy_b_emitter::emit_impl(const std::vector& in, const std::vector& out) const { validate_arguments(in, out); - - Xbyak::Reg64 src(static_cast(in[0])); - Xbyak::Reg64 dst(static_cast(out[0])); - Xbyak::Reg64 comp(static_cast(m_with_comp ? out[1] : 0)); - - const size_t data_size = m_brg_weight_etype.size(); - size_t start_in = m_in_offset; - size_t start_out = m_out_offset; - size_t start_comp = m_comp_offset; - - // OneDNN requires tail handling before main iterations - if (m_inner_N_tail != 0) { - emit_kernel_call(m_kernel.get(), src, dst, comp, m_inner_N_tail, m_K_blk, start_in, start_out, start_comp); - start_in += m_transpose ? 
m_K * m_inner_N_tail * data_size : m_inner_N_tail * data_size; - start_out += m_inner_N_tail * m_brgemmVNNIFactor * data_size; - start_comp += m_inner_N_tail * sizeof(int32_t); + std::vector mem_ptrs_idxs{in[0], out[0]}; + if (out.size() > 1) + mem_ptrs_idxs.emplace_back(out[1]); + + EmitABIRegSpills spill(h); + spill.preamble(); + + h->mov(h->rbp, reinterpret_cast(BrgemmCopyBKernelExecutor::execute)); + auto reserved_stack_size = sizeof(BrgemmCopyBKernel::call_args); + // Reserve memory on the stack + h->sub(h->rsp, reserved_stack_size); + + const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); + Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64(); + + const std::vector args_offsets {GET_OFF_BRGEMM_COPY_B_ARGS(src), GET_OFF_BRGEMM_COPY_B_ARGS(tr_src), GET_OFF_BRGEMM_COPY_B_ARGS(compensation_ptr)}; + const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs); + for (size_t i = 0; i < mem_ptrs.size(); i++) { + if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) + utils::push_ptr_with_runtime_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg, + GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t)); + else + utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], m_memory_offsets[i]); } - const size_t in_ld = m_transpose ? m_K * m_inner_N_block * data_size : m_inner_N_block * data_size; - const size_t out_ld = m_inner_N_block * m_brgemmVNNIFactor * data_size; - const size_t comp_ld = m_inner_N_block * sizeof(int32_t); - for (size_t nb = 0; nb < m_N_blk / m_inner_N_block; nb++) { - const size_t offset_in = start_in + nb * in_ld; - const size_t offset_out = start_out + nb * out_ld; - const size_t offset_comp = m_with_comp ? 
start_comp + nb * comp_ld : 0; - emit_kernel_call(m_kernel.get(), src, dst, comp, m_inner_N_block, m_K_blk, offset_in, offset_out, offset_comp); - } -} + // No scratchpad => need to write nullptr manually + if (!m_with_comp) + h->mov(h->qword[h->rsp + args_offsets.back()], reinterpret_cast(nullptr)); -void jit_brgemm_copy_b_emitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp, - size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const { - const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { - h->uni_vmovq(reg, xmm); - if (bytes_offset) h->add(reg, bytes_offset); - }; - - internal_call_preamble(); - // save function address in gpr to pass in call instruction - const auto &kernel_overload = static_cast(execute); - h->mov(h->rbp, reinterpret_cast(kernel_overload)); - // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted - // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. - // It's likely that a more efficient solution exists. - h->uni_vmovq(Xmm(0), src); - h->uni_vmovq(Xmm(1), dst); - if (m_with_comp) - h->uni_vmovq(Xmm(2), comp); - // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. 
- h->mov(abi_param1, reinterpret_cast(kernel)); - - data_ptr(Xmm(0), abi_param2, offset_in); - data_ptr(Xmm(1), abi_param3, offset_out); - if (m_with_comp) { - data_ptr(Xmm(2), abi_param4, offset_comp); - } else { - h->mov(abi_param4, reinterpret_cast(nullptr)); - } + h->mov(abi_param1, reinterpret_cast(m_kernel_executor.get())); + h->mov(abi_param2, h->rsp); -#ifdef _WIN32 - // Note: ABI requires that the remaining parameters (except the first for) are pushed to the stack in right-to-left order - // Shadow space will be allocated inside internal_call_rsp_align() - h->push(K); - h->push(N); -#else - h->mov(abi_param5, N); - h->mov(abi_param6, K); -#endif - - internal_call_rsp_align(); + spill.rsp_align(); h->call(h->rbp); - internal_call_rsp_restore(); + spill.rsp_restore(); -#ifdef _WIN32 - h->add(h->rsp, gpr_size * 2); -#endif - internal_call_postamble(); -} + h->add(h->rsp, reserved_stack_size); -void jit_brgemm_copy_b_emitter::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, - const void* src, - const void* dst, - const void* comp, - size_t N, - size_t K) { - auto ctx = dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); - ctx.current_N_blk = N; - ctx.src = src; - ctx.tr_src = dst; - ctx.compensation_ptr = comp; - ctx.zp_a_compensation_ptr = nullptr; - ctx.zp_a_neg_value_ptr = nullptr; - ctx.current_K_start = 0; - ctx.current_K_iters = K; - - OV_CPU_JIT_EMITTER_ASSERT(kernel, "Kernel hasn't been created"); - (*kernel)(&ctx); + spill.postamble(); } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp index a9743012da5b7f..ef53efe6081217 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp @@ -6,10 +6,8 @@ #include "emitters/plugin/x64/jit_emitter.hpp" -#include +#include 
"kernel_executors/brgemm_copy_b.hpp" -#include "emitters/plugin/x64/jit_emitter.hpp" -#include "transformations/snippets/x64/op/brgemm_copy_b.hpp" namespace ov { namespace intel_cpu { @@ -17,7 +15,9 @@ namespace intel_cpu { class jit_brgemm_copy_b_emitter : public jit_emitter { public: jit_brgemm_copy_b_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + const ov::snippets::lowered::ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); size_t get_inputs_num() const override {return 1;} static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { @@ -28,36 +28,10 @@ class jit_brgemm_copy_b_emitter : public jit_emitter { void validate_arguments(const std::vector &in, const std::vector &out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; - void init_brgemm_copy(std::unique_ptr& kernel, - size_t N, size_t N_blk, size_t N_tail, size_t out_leading_dim, size_t K_blk, brgemm_utils::BRGEMM_TYPE brgemm_type, - const ov::element::Type& dt_in0, const ov::element::Type& dt_in1, size_t wei_stride, dnnl_format_tag_t format) const; - void emit_kernel_call(const dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, - Xbyak::Reg64 src, Xbyak::Reg64 dst, Xbyak::Reg64 comp, size_t N, size_t K, - size_t offset_in, size_t offset_out, size_t offset_comp) const; - - static void execute(dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, - const void* src, const void* dst, const void* comp, size_t N, size_t K); - - std::unique_ptr m_kernel; - ov::element::Type m_brg_weight_etype; - - // Block size which is set by snippets: it is usually shared between brgemm and brgemm_copy_b nodes - size_t m_N_blk = 0lu; - // Block size which is used by the internal OneDNN implementation. 
- // It is used in snippets emitter to iterate through input/output data and call OneDNN kernel - size_t m_inner_N_block = 0lu; - size_t m_inner_N_tail = 0lu; - - size_t m_K = 0lu; - size_t m_K_blk = 0lu; - size_t m_brgemmVNNIFactor = 0lu; - - size_t m_in_offset = 0lu; - size_t m_out_offset = 0lu; - size_t m_comp_offset = 0lu; - - bool m_with_comp = false; - bool m_transpose = false; + std::vector m_memory_offsets{}; + std::vector m_buffer_ids{}; + std::shared_ptr m_kernel_executor {nullptr}; + bool m_with_comp {false}; #ifdef SNIPPETS_DEBUG_CAPS friend std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter *emitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp index 4c36aa3b21ab35..057a3687ab8d16 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp @@ -5,9 +5,8 @@ #include "jit_brgemm_emitter.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" -#include -#include #include "snippets/utils/utils.hpp" +#include "emitters/plugin/x64/utils.hpp" #include "utils.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" @@ -39,32 +38,13 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) && !snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()), "Jit emitter is called when the shapes are unknown"); - auto get_cluster_id = [](const snippets::lowered::ExpressionPort& p) { - // Note: NewMemoryBuffer is used as a scratchpad and can't be dynamic, so we don't need to account for them here - if (const auto buffer = ov::as_type_ptr(p.get_expr())) - return buffer->get_cluster_id(); - else - return SIZE_MAX; - }; + m_memory_offsets = {brgemm_node->get_offset_a(), 
brgemm_node->get_offset_b(), brgemm_node->get_offset_c()}; - if (with_scratchpad(brgemm_type)) + m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), utils::get_buffer_cluster_id(expr->get_input_port(1)), + utils::get_buffer_cluster_id(expr->get_output_port(0))}; + if (with_scratchpad(brgemm_type)) { m_memory_offsets.push_back(brgemm_node->get_offset_scratch()); - - m_buffer_ids.assign(m_memory_offsets.size(), SIZE_MAX); - for (size_t i = 0; i < m_memory_offsets.size(); i++) { - if (snippets::utils::is_dynamic_value(m_memory_offsets[i])) { - switch (i) { - case 0: - case 1: - m_buffer_ids[i] = get_cluster_id(expr->get_input_port_connector(i)->get_source()); - break; - case 2: - for (const auto& child : expr->get_output_port_connector(0)->get_consumers()) - if (!ov::is_type(child.get_expr()->get_node())) - m_buffer_ids[i] = get_cluster_id(child); - } - OV_CPU_JIT_EMITTER_ASSERT(m_buffer_ids[i] != SIZE_MAX, "Dynamic offset requires a valid buffer ID"); - } + m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2))); } } @@ -101,39 +81,27 @@ void jit_brgemm_emitter::emit_impl(const std::vector& in, const std::vec std::vector mem_ptrs_idxs{in[0], in[1], out[0]}; if (in.size() > 2) mem_ptrs_idxs.emplace_back(in[2]); - emit_brgemm_kernel_call(mem_ptrs_idxs, m_memory_offsets); -} -void jit_brgemm_emitter::emit_brgemm_kernel_call(const std::vector& mem_ptrs_idxs, const std::vector& mem_offsets) const { - internal_call_preamble(); + + EmitABIRegSpills spill(h); + spill.preamble(); + h->mov(h->rbp, reinterpret_cast(BrgemmKernelExecutor::execute)); auto reserved_stack_size = sizeof(BrgemmKernelExecutor::call_args); // Reserve memory on the stack h->sub(h->rsp, reserved_stack_size); - Xbyak::Reg64 aux_reg = [this, &mem_ptrs_idxs]() { - std::set used(mem_ptrs_idxs.begin(), mem_ptrs_idxs.end()); - std::vector spilled_gprs {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, - h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp}; 
- for (const auto& reg : spilled_gprs) - if (used.count(reg.getIdx()) == 0) - return reg; - OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux register"); - }(); - - auto write_addr_on_stack = [&](size_t arg_offset, Reg64 addr, size_t addr_offset, size_t buffer_id) { - const auto stack_frame = h->qword[h->rsp + arg_offset]; - h->mov(aux_reg, addr); - if (snippets::utils::is_dynamic_value(addr_offset)) - h->add(aux_reg, h->ptr[abi_param1 + GET_OFF(buffer_offsets) + buffer_id * sizeof(size_t)]); - else if (addr_offset != 0) - h->add(aux_reg, addr_offset); - h->mov(stack_frame, aux_reg); - }; - const std::vector brgemm_args_offsets {GET_OFF_BRGEMM_ARGS(A), GET_OFF_BRGEMM_ARGS(B), GET_OFF_BRGEMM_ARGS(C), - GET_OFF_BRGEMM_ARGS(scratch)}; + const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); + Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64(); + + const std::vector brgemm_args_offsets {GET_OFF_BRGEMM_ARGS(A), GET_OFF_BRGEMM_ARGS(B), GET_OFF_BRGEMM_ARGS(C), GET_OFF_BRGEMM_ARGS(scratch)}; const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs); - for (size_t i = 0; i < mem_ptrs.size(); i++) - write_addr_on_stack(brgemm_args_offsets[i], mem_ptrs[i], mem_offsets[i], m_buffer_ids[i]); + for (size_t i = 0; i < mem_ptrs.size(); i++) { + if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) + utils::push_ptr_with_runtime_offset_on_stack(h, brgemm_args_offsets[i], mem_ptrs[i], aux_reg, + GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t)); + else + utils::push_ptr_with_static_offset_on_stack(h, brgemm_args_offsets[i], mem_ptrs[i], m_memory_offsets[i]); + } // No scratchpad => need to write nullptr manually if (mem_ptrs.size() < 4) @@ -146,12 +114,13 @@ void jit_brgemm_emitter::emit_brgemm_kernel_call(const std::vector& mem_ h->mov(abi_param1, reinterpret_cast(m_kernel_executor.get())); h->mov(abi_param2, h->rsp); - 
internal_call_rsp_align(); + spill.rsp_align(); h->call(h->rbp); - internal_call_rsp_restore(); + spill.rsp_restore(); h->add(h->rsp, reserved_stack_size); - internal_call_postamble(); + + spill.postamble(); } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp index 2387aca72c4479..baa6ed95473034 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp @@ -24,9 +24,6 @@ class jit_brgemm_emitter : public jit_emitter { void validate_arguments(const std::vector &in, const std::vector &out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; - // Note: expected arguments order: A, B, C (+ scratchpad, if needed) - void emit_brgemm_kernel_call(const std::vector& mem_ptrs_idxs, const std::vector& mem_offsets) const; - // Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in runtime std::vector m_memory_offsets{}; // Note: cluster ids order: A, B, C (+ scratchpad, if needed). 
Values can be dynamic_value if there is no buffer diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp index 74d0106696141b..f3151d0df4ccb1 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp @@ -5,6 +5,7 @@ #include "jit_loop_emitters.hpp" #include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/snippets/x64/utils.hpp" #include "snippets/utils/utils.hpp" using namespace Xbyak; @@ -21,7 +22,7 @@ class jit_aux_gpr_holder { : m_h(host), m_pool_gpr_idxs(pool_gpr_idxs) { // If the pool is empty, let's manualy allocate the gpr and push original vlaue on stack if (m_pool_gpr_idxs.empty()) { - m_aux_gpr_idx = Reg64(static_cast(allocate_aux_gpr(used_gpr_idxs))); + m_aux_gpr_idx = ov::intel_cpu::utils::get_aux_gpr(used_gpr_idxs); m_is_preserved = true; m_h->push(m_aux_gpr_idx); } else { @@ -41,18 +42,6 @@ class jit_aux_gpr_holder { const Reg64& get_reg() const { return m_aux_gpr_idx; } private: - size_t allocate_aux_gpr(const std::vector& used_gpr_idxs) const { - // RSP, RBP - stack-related registers, abi_param1 - runtime parameter register in the kernel - static std::set blakclist_gpr_idxs = { Operand::RSP, Operand::RBP, static_cast(abi_param1.getIdx()) }; - for (size_t gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) { - size_t _idx = Operand::R15 - gpr_idx; // we allocate from the end - if (std::find(used_gpr_idxs.cbegin(), used_gpr_idxs.cend(), _idx) != used_gpr_idxs.cend()) continue; - if (std::find(blakclist_gpr_idxs.cbegin(), blakclist_gpr_idxs.cend(), _idx) != blakclist_gpr_idxs.cend()) continue; - return _idx; - } - OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR"); - } - dnnl::impl::cpu::x64::jit_generator* m_h; std::vector& m_pool_gpr_idxs; Reg64 m_aux_gpr_idx {}; diff --git 
a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp index 653412f5e68a67..f89e906ce57593 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp @@ -5,6 +5,8 @@ #include "jit_perf_count_chrono_emitters.hpp" +#include "emitters/plugin/x64/utils.hpp" + using namespace dnnl::impl; using namespace dnnl::impl::utils; using namespace dnnl::impl::cpu; @@ -29,16 +31,18 @@ void jit_perf_count_chrono_start_emitter::set_start_time(snippets::op::PerfCount } void jit_perf_count_chrono_start_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { - internal_call_preamble(); + EmitABIRegSpills spill(h); + spill.preamble(); const auto &set_start_time_overload = static_cast(set_start_time); h->mov(h->rax, reinterpret_cast(set_start_time_overload)); h->mov(abi_param1, reinterpret_cast(m_start_node.get())); - internal_call_rsp_align(); + + spill.rsp_align(); h->call(h->rax); - internal_call_rsp_restore(); + spill.rsp_restore(); - internal_call_postamble(); + spill.postamble(); } ///////////////////jit_perf_count_chrono_end_emitter//////////////////////////////////// @@ -56,16 +60,18 @@ void jit_perf_count_chrono_end_emitter::set_accumulated_time(snippets::op::PerfC } void jit_perf_count_chrono_end_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { - internal_call_preamble(); + EmitABIRegSpills spill(h); + spill.preamble(); const auto &set_accumulated_time_overload = static_cast(set_accumulated_time); h->mov(h->rax, reinterpret_cast(set_accumulated_time_overload)); h->mov(abi_param1, reinterpret_cast(m_end_node.get())); - internal_call_rsp_align(); + + spill.rsp_align(); h->call(h->rax); - internal_call_rsp_restore(); + spill.rsp_restore(); - internal_call_postamble(); + 
spill.postamble(); } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp index 109950dd3a668e..f88c345ff055b5 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp @@ -5,6 +5,7 @@ #ifdef SNIPPETS_DEBUG_CAPS #include "jit_segfault_detector_emitter.hpp" +#include "emitters/plugin/x64/utils.hpp" using namespace dnnl::impl::utils; using namespace dnnl::impl; @@ -43,16 +44,18 @@ void jit_uni_segfault_detector_emitter::emit_impl(const std::vector& in_ void jit_uni_segfault_detector_emitter::save_target_emitter() const { // use internal call as "->local" shoule be the execution thread. Otherwise always compilation thread. - internal_call_preamble(); + EmitABIRegSpills spill(h); + spill.preamble(); const auto &set_local_handler_overload = static_cast(set_local_handler); h->mov(h->rax, reinterpret_cast(set_local_handler_overload)); h->mov(abi_param1, reinterpret_cast(this)); - internal_call_rsp_align(); + + spill.rsp_align(); h->call(h->rax); - internal_call_rsp_restore(); + spill.rsp_restore(); - internal_call_postamble(); + spill.postamble(); } void jit_uni_segfault_detector_emitter::set_local_handler(jit_uni_segfault_detector_emitter* emitter_address) { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index e46de866990005..6f1f4ab93aeda9 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -210,7 +210,7 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression const auto& loop_ids = expr->get_loop_ids(); const auto& loop_manager = 
linear_ir->get_loop_manager(); auto get_loop_info = [&](){ - OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop by dimension M is missed"); + OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed"); return loop_manager->get_loop_info(loop_ids[loop_idx++]); }; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp new file mode 100644 index 00000000000000..17f8923ae9867b --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp @@ -0,0 +1,322 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "brgemm_copy_b.hpp" + +#include "snippets/lowered/loop_manager.hpp" +#include "emitters/plugin/x64/utils.hpp" +#include "transformations/snippets/x64/op/brgemm_utils.hpp" + +#define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) + +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::x64; + +namespace ov { +namespace intel_cpu { + +BrgemmCopyBKernelConfig::BrgemmCopyBKernelConfig(const element::Type& src_dt, const element::Type& wei_dt, cpu_isa_t isa, + bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk) + : m_static_params(std::make_shared(src_dt, wei_dt, isa, is_with_comp, is_transposed_B, wei_N_blk)) { + m_hash = compute_hash(); +} + +bool BrgemmCopyBKernelConfig::is_completed() const { + return !utils::one_of(0, m_N, m_K, m_copy_B_wei_stride, m_LDB) || is_empty(); +} + +bool BrgemmCopyBKernelConfig::is_empty() const { + return everyone_is(0, m_N, m_N_blk, m_K, m_K_blk, m_copy_B_wei_stride, m_LDB); +} + +bool BrgemmCopyBKernelConfig::operator==(const BrgemmCopyBKernelConfig& rhs) const { +#define EQ(X) X == rhs.X + return EQ(m_hash) && EQ(m_N) && EQ(m_N_blk) && EQ(m_K) && EQ(m_K_blk) && EQ(m_LDB) && EQ(m_copy_B_wei_stride) && + (EQ(m_static_params.get()) || *m_static_params == *(rhs.m_static_params)); +#undef 
EQ +} + +void BrgemmCopyBKernelConfig::update(dnnl_dim_t N, dnnl_dim_t N_blk, dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t copy_B_wei_stride, dnnl_dim_t LDB) { + // If one of the dims is zero, it means that BrgemmCopyB won't be executed (in Loop with work_amount = 0, for example) + // To process this case, we have to make this Config as empty (nullify runtime parameters) + if (utils::one_of(0, N, K)) { + m_N = 0; m_N_blk = 0; + m_K = 0; m_K_blk = 0; + m_copy_B_wei_stride = 0; m_LDB = 0; + } else { + m_N = N; m_N_blk = N_blk; + m_K = K; m_K_blk = K_blk; + m_copy_B_wei_stride = copy_B_wei_stride; m_LDB = LDB; + } + m_hash = compute_hash(); +} + +size_t BrgemmCopyBKernelConfig::compute_hash() const { + size_t seed = m_static_params->hash; +#define HASH(X) seed = hash_combine(seed, X) + HASH(m_N); HASH(m_N_blk); + HASH(m_K); HASH(m_K_blk); + HASH(m_copy_B_wei_stride); HASH(m_LDB); +#undef HASH + return seed; +} + +BrgemmCopyBKernelConfig::StaticParams::StaticParams(const element::Type& src_type, const element::Type& wei_type, cpu_isa_t isa, + bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_n_blk) + : src_dt(DTYPE_CAST(src_type)), wei_dt(DTYPE_CAST(wei_type)), isa(isa), + is_with_comp(is_with_comp), is_transposed_B(is_transposed_B), wei_N_blk(wei_n_blk), + hash(init_hash(src_dt, wei_dt, isa, is_with_comp, is_transposed_B, wei_N_blk)) {} + +bool BrgemmCopyBKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { +#define EQ(X) X == rhs.X + return EQ(hash) && EQ(src_dt) && EQ(wei_dt)&& EQ(isa) && EQ(is_with_comp) && EQ(is_transposed_B) && EQ(wei_N_blk); +#undef EQ +} + +size_t BrgemmCopyBKernelConfig::StaticParams::init_hash(const dnnl_data_type_t& src_dt, const dnnl_data_type_t& wei_dt, cpu_isa_t isa, + bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk) { + size_t seed = 0; +#define HASH(X) seed = hash_combine(seed, X) + HASH(src_dt); HASH(wei_dt); HASH(isa); + HASH(is_with_comp); HASH(is_transposed_B); HASH(wei_N_blk); +#undef HASH + 
return seed; +} + +#ifdef SNIPPETS_DEBUG_CAPS +#define PRINT(X) ss << #X << " = " << X << "\n" +std::string BrgemmCopyBKernelConfig::to_string() const { + std::stringstream ss; + ss << m_static_params->to_string() << "\n"; + PRINT(m_hash); PRINT(m_N); PRINT(m_N_blk); + PRINT(m_K); PRINT(m_K_blk); PRINT(m_LDB); PRINT(m_copy_B_wei_stride); + return ss.str(); +} +std::string BrgemmCopyBKernelConfig::StaticParams::to_string() const { + std::stringstream ss; + PRINT(src_dt); PRINT(wei_dt); PRINT(isa); + PRINT(is_with_comp); PRINT(is_transposed_B); PRINT(wei_N_blk); + return ss.str(); +} +#undef PRINT +#endif + +BrgemmCopyBKernel::BrgemmCopyBKernel() : jit_generator(jit_name()), ker_(nullptr) {} + +BrgemmCopyBKernel::BrgemmCopyBKernel(const BrgemmCopyBKernelConfig& conf) + : jit_generator(jit_name()), is_with_comp(conf.is_with_comp()), is_transpose(conf.is_transposed_B()), + wei_data_size(dnnl_data_type_size(conf.get_wei_dt())), vnni_factor(data_type_vnni_granularity(conf.get_wei_dt())), + K(conf.get_K()), N_blk(conf.get_N_blk()), wei_N_blk(conf.get_wei_N_blk()), wei_N_tail(conf.get_wei_N_tail()), ker_(nullptr) { + init_brgemm_copy_b_kernel(dnnl_brgemm_copy_b_kernel, conf); + OV_CPU_JIT_EMITTER_ASSERT(dnnl_brgemm_copy_b_kernel, "Kernel is missed!"); +} + +status_t BrgemmCopyBKernel::create_kernel() { + const auto code = jit_generator::create_kernel(); + OV_CPU_JIT_EMITTER_ASSERT(code == status::success, "Failed to create kernel"); + ker_ = (decltype(ker_))jit_ker(); + return code; +} + +void BrgemmCopyBKernel::operator()(const call_args* args) const { + OV_CPU_JIT_EMITTER_ASSERT(ker_, "Kernel is nullptr"); + ker_(args); +} + +void BrgemmCopyBKernel::init_brgemm_copy_b_kernel(std::unique_ptr& kernel, + const BrgemmCopyBKernelConfig& conf) const { + matmul::brgemm_matmul_conf_t brgCopyKernelConf; + brgCopyKernelConf.src_dt = conf.get_src_dt(); + brgCopyKernelConf.wei_dt = conf.get_wei_dt(); + brgCopyKernelConf.orig_wei_dt = brgCopyKernelConf.wei_dt; + 
brgCopyKernelConf.wei_n_blk = static_cast(conf.get_wei_N_blk()); + // Note: 2D format tags are used just to force the needed OneDNN primitive creation. + // However, the generated primitive can be also applied to tensors with other ranks + brgCopyKernelConf.wei_tag = conf.is_transposed_B() ? dnnl_ba : dnnl_ab; + brgCopyKernelConf.transposed_B = conf.is_transposed_B(); + brgCopyKernelConf.copy_B_wei_stride = conf.get_copy_B_wei_stride(); + brgCopyKernelConf.LDB = conf.get_LDB(); + brgCopyKernelConf.N = conf.get_N(); + brgCopyKernelConf.N_tail = conf.get_wei_N_tail(); + brgCopyKernelConf.N_blk = conf.get_wei_N_blk(); + brgCopyKernelConf.K = conf.get_K_blk(); + brgCopyKernelConf.K_blk = conf.get_K_blk(); + brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; + brgCopyKernelConf.b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.wei_dt)); + brgCopyKernelConf.tr_b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.wei_dt)); + + brgCopyKernelConf.req_wei_vnni_downconvert = false; + + brgCopyKernelConf.isa = conf.get_isa(); + brgCopyKernelConf.s8s8_compensation_required = conf.is_with_comp(); + + brgCopyKernelConf.has_zero_point_a = false; + brgCopyKernelConf.has_zero_point_b = false; + brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none; + + OV_CPU_JIT_EMITTER_ASSERT(matmul::create_brgemm_matmul_copy_b(kernel, &brgCopyKernelConf) == dnnl_success, + "cannot create kernel due to invalid params"); +} + +void BrgemmCopyBKernel::generate() { + preamble(); + + mov(src_reg, ptr[abi_param1 + GET_OFF_BRGEMM_COPY_B_ARGS(src)]); + mov(tr_src_reg, ptr[abi_param1 + GET_OFF_BRGEMM_COPY_B_ARGS(tr_src)]); + if (is_with_comp) + mov(comp_reg, ptr[abi_param1 + GET_OFF_BRGEMM_COPY_B_ARGS(compensation_ptr)]); + + size_t start_in = 0; + size_t start_out = 0; + size_t start_comp = 0; + + auto add_ptr_increments = [&](size_t current_N) { + start_in += is_transpose ? 
K * current_N * wei_data_size : current_N * wei_data_size; + start_out += current_N * vnni_factor * wei_data_size; + start_comp += is_with_comp ? current_N * sizeof(int32_t) : 0; + }; + + // OneDNN requires tail handling before main iterations + if (wei_N_tail != 0) { + emit_brgemm_copy_b_kernel_call(wei_N_tail, K, start_in, start_out, start_comp); + add_ptr_increments(wei_N_tail); + } + + for (auto nb = wei_N_tail; nb < N_blk; nb += wei_N_blk) { + emit_brgemm_copy_b_kernel_call(wei_N_blk, K, start_in, start_out, start_comp); + add_ptr_increments(wei_N_blk); + } + + postamble(); +} + +void BrgemmCopyBKernel::emit_brgemm_copy_b_kernel_call(size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) { + EmitABIRegSpills spill(this); + spill.preamble(); + + const auto add_offset = [&](Xbyak::Reg64 reg, size_t bytes_offset) { + if (bytes_offset) add(reg, bytes_offset); + }; + + // save function address in gpr to pass in call instruction + const auto& kernel_overload = static_cast(execute); + mov(rbp, reinterpret_cast(kernel_overload)); + mov(abi_param1, reinterpret_cast(dnnl_brgemm_copy_b_kernel.get())); + + add_offset(src_reg, offset_in); // abi_param2 + add_offset(tr_src_reg, offset_out); // abi_param3 + if (is_with_comp) // abi_param4 + add_offset(comp_reg, offset_comp); + else + mov(comp_reg, reinterpret_cast(nullptr)); + +#ifdef _WIN32 + // Note: ABI requires that the remaining parameters (except the first four) are pushed to the stack in right-to-left order + // Shadow space will be allocated inside EmitABIRegSpills::rsp_align() + push(K); + push(N); +#else + mov(abi_param5, N); + mov(abi_param6, K); +#endif + + spill.rsp_align(); + call(rbp); + spill.rsp_restore(); + +#ifdef _WIN32 + static constexpr int gpr_size = 8; + add(rsp, gpr_size * 2); +#endif + + spill.postamble(); +} + +void BrgemmCopyBKernel::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, size_t N, size_t K) { + auto ctx = 
matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); + ctx.current_N_blk = N; + ctx.src = src; + ctx.tr_src = dst; + ctx.compensation_ptr = comp; + ctx.zp_a_compensation_ptr = nullptr; + ctx.zp_a_neg_value_ptr = nullptr; + ctx.current_K_start = 0; + ctx.current_K_iters = K; + + OV_CPU_JIT_EMITTER_ASSERT(kernel, "Kernel hasn't been created"); + (*kernel)(&ctx); +} + +BrgemmCopyBKernelExecutor::BrgemmCopyBKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmCopyBKernelConfig config) + : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) { } + +std::shared_ptr BrgemmCopyBKernelExecutor::compile_kernel(const BrgemmCopyBKernelConfig& config) const { + std::shared_ptr compiled_kernel = std::make_shared(); + // BrgemmCopyB is not executable - nothing to compile + if (!config.is_empty()) { + compiled_kernel = std::make_shared(config); + OV_CPU_JIT_EMITTER_ASSERT(compiled_kernel, "compiled kernel is nullptr"); + compiled_kernel->create_kernel(); + } + + return compiled_kernel; +} + +void BrgemmCopyBKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmCopyBKernelConfig& config) const { + const auto& input_desc = expr->get_input_port_descriptor(0); + const auto& output_desc = expr->get_output_port_descriptor(0); + + // Need to update K, N + // 1. If the original value in subtensor is `FULL_DIM`, it means that + // BrgemmCopyB block should process full tensor by this dim -> take dimension from shape + // 2. 
Otherwise, BrgemmCopyB block processes part of the tensor by this dim + // (there is blocking by this dimension) -> take from Loop increment + + const auto planar_shape = ov::snippets::utils::get_planar_vdims(expr->get_input_port(0)); + const auto& in_subtensor = input_desc->get_subtensor(); + + size_t loop_idx = 0; + const auto& loop_ids = expr->get_loop_ids(); + const auto& loop_manager = linear_ir->get_loop_manager(); + + auto init = [&](size_t& dim, size_t& blk, size_t idx) { + OPENVINO_ASSERT(idx < planar_shape.size() && idx < in_subtensor.size(), "Index must be less than shape/subtensor rank!"); + dim = *(planar_shape.rbegin() + idx); + blk = *(in_subtensor.rbegin() + idx); + if (ov::snippets::utils::is_full_dim_value(blk)) { + blk = dim; + } else { + OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed"); + const auto& current_expanded_loop_info = loop_manager->get_loop_info(loop_ids[loop_idx++]); + blk = current_expanded_loop_info->get_increment(); + input_desc->set_subtensor_dim(idx, blk); + output_desc->set_subtensor_dim(idx, blk); + OV_CPU_JIT_EMITTER_ASSERT(blk <= dim, "BrgemmCopyB has incompatible subtensor dimensions"); + } + }; + + size_t K_dim, K_blk, N_dim, N_blk; + // Dimension K + init(K_dim, K_blk, 1); + // Dimension N + init(N_dim, N_blk, 0); + + const auto& brg_weight_etype = expr->get_node()->get_input_element_type(0); + const auto LDB = brgemm_utils::repacking::compute_out_leading_dim(N_dim, brg_weight_etype); + const auto copy_B_wei_stride = ov::snippets::utils::get_dim_stride(expr->get_input_port(0), config.is_transposed_B() ? 
0 : 1) * brg_weight_etype.size(); + + config.update(N_dim, N_blk, K_dim, K_blk, copy_B_wei_stride, LDB); +} + +void BrgemmCopyBKernelExecutor::execute(const BrgemmCopyBKernelExecutor* executor, BrgemmCopyBKernel::call_args* args) { + auto kernel = executor->get_kernel(); + OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr kernel"); + OV_CPU_JIT_EMITTER_ASSERT(args, "has nullptr call args"); + (*kernel)(args); +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp new file mode 100644 index 00000000000000..c4e3f3622ad88f --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp @@ -0,0 +1,155 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "emitters/plugin/x64/jit_emitter.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" + +#include +#include + + +namespace ov { +namespace intel_cpu { + +struct BrgemmCopyBKernelConfig : public snippets::KernelExecutorBase::GenericConfig { +public: + BrgemmCopyBKernelConfig() = default; + BrgemmCopyBKernelConfig(const element::Type& src_dt, const element::Type& wei_dt, dnnl::impl::cpu::x64::cpu_isa_t isa, + bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk); + + bool operator==(const BrgemmCopyBKernelConfig& rhs) const; + bool operator!=(const BrgemmCopyBKernelConfig& rhs) const {return !(*this == rhs);} + + std::unique_ptr get_clone_ptr() const override { + return std::unique_ptr(new BrgemmCopyBKernelConfig(*this)); + } + + bool is_empty() const; + bool is_completed() const override; + + void update(dnnl_dim_t N, dnnl_dim_t N_blk, dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t copy_B_wei_stride, dnnl_dim_t LDB); + + size_t hash() const override { return m_hash; } + + 
dnnl_data_type_t get_src_dt() const { return m_static_params->src_dt; } + dnnl_data_type_t get_wei_dt() const { return m_static_params->wei_dt; } + + dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return m_static_params->isa; } + bool is_with_comp() const { return m_static_params->is_with_comp; } + bool is_transposed_B() const { return m_static_params->is_transposed_B; } + + dnnl_dim_t get_N() const { return m_N; } + dnnl_dim_t get_N_blk() const { return m_N_blk; } + dnnl_dim_t get_N_tail() const { return m_N % m_N_blk; } + dnnl_dim_t get_wei_N_blk() const { return m_static_params->wei_N_blk; } + dnnl_dim_t get_wei_N_tail() const { return m_N_blk % m_static_params->wei_N_blk; } + dnnl_dim_t get_K() const { return m_K; } + dnnl_dim_t get_K_blk() const { return m_K_blk; } + dnnl_dim_t get_copy_B_wei_stride() const { return m_copy_B_wei_stride; } + dnnl_dim_t get_LDB() const { return m_LDB; } + +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const override; +#endif + +private: + struct StaticParams { + StaticParams(const element::Type& src_dt, const element::Type& wei_dt, dnnl::impl::cpu::x64::cpu_isa_t isa, + bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk); + + const dnnl_data_type_t src_dt {dnnl_data_type_undef}, wei_dt {dnnl_data_type_undef}; + const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::isa_undef}; + const bool is_with_comp {false}; + const bool is_transposed_B {false}; + const dnnl_dim_t wei_N_blk {0}; + const size_t hash {0}; + + bool operator==(const StaticParams& rhs) const; + bool operator!=(const StaticParams& rhs) const { return !(*this == rhs); } + +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const; +#endif + + private: + static size_t init_hash(const dnnl_data_type_t& src_dt, const dnnl_data_type_t& wei_dt, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, + bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk); + }; + + size_t compute_hash() const; + + std::shared_ptr m_static_params; + 
dnnl_dim_t m_N {0}, m_N_blk {0}; + dnnl_dim_t m_K {0}, m_K_blk {0}; + dnnl_dim_t m_copy_B_wei_stride {0}, m_LDB {0}; + size_t m_hash {SIZE_MAX}; +}; + +struct BrgemmCopyBKernel : public dnnl::impl::cpu::x64::jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(BrgemmCopyBKernel) + struct call_args { + const void* src = nullptr; + void* tr_src = nullptr; + void* compensation_ptr = nullptr; + }; + + BrgemmCopyBKernel(); + BrgemmCopyBKernel(const BrgemmCopyBKernelConfig& conf); + + dnnl::impl::status_t create_kernel() override; + + void operator()(const call_args* args) const; + +private: + void generate() override; + + void emit_brgemm_copy_b_kernel_call(size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp); + + static void execute(dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, + size_t N, size_t K); + + void init_brgemm_copy_b_kernel(std::unique_ptr& kernel, + const BrgemmCopyBKernelConfig& conf) const; + + static constexpr auto abi_param_regs = dnnl::impl::cpu::x64::abi_param_regs; + const Xbyak::Reg64 src_reg = abi_param2; + const Xbyak::Reg64 tr_src_reg = abi_param3; + const Xbyak::Reg64 comp_reg = abi_param4; + + const bool is_with_comp = false; + const bool is_transpose = false; + const size_t wei_data_size = 1u; + const size_t vnni_factor = 1u; + const size_t K = 0; + const size_t N_blk = 0; + const size_t wei_N_blk = 0; + const size_t wei_N_tail = 0; + + // JIT kernel code of the current BrgemmCopyBKernel + void (*ker_)(const call_args*); + + // JIT kernel dnnl Brgemm copy b which is called in the current snippets BrgemmCopyBKernel + std::unique_ptr dnnl_brgemm_copy_b_kernel = nullptr; +}; + +class BrgemmCopyBKernelExecutor : public CPUKernelExecutor { +public: + BrgemmCopyBKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmCopyBKernelConfig config); + + static void execute(const BrgemmCopyBKernelExecutor* executor, BrgemmCopyBKernel::call_args* args); + 
+protected: + std::shared_ptr compile_kernel(const BrgemmCopyBKernelConfig& c) const override; + + void update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmCopyBKernelConfig& config) const override; +}; +#define GET_OFF_BRGEMM_COPY_B_ARGS(field) offsetof(BrgemmCopyBKernel::call_args, field) + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp new file mode 100644 index 00000000000000..687a397addc208 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "utils.hpp" + +#include "snippets/lowered/expressions/buffer_expression.hpp" +#include "snippets/op/memory_access.hpp" +#include "snippets/op/loop.hpp" + +#include "emitters/utils.hpp" + +using namespace dnnl::impl::cpu::x64; + +namespace ov { +namespace intel_cpu { +namespace utils { + +size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port) { + auto get_cluster_id = [](const snippets::lowered::ExpressionPort& p) { + const auto buffer = ov::as_type_ptr(p.get_expr()); + return buffer ? 
buffer->get_cluster_id() : SIZE_MAX; + }; + const auto& ma_op = std::dynamic_pointer_cast(port.get_expr()->get_node()); + OPENVINO_ASSERT(ma_op, "Expected MemoryAccess op!"); + size_t offset = ov::snippets::utils::get_dynamic_value(); + size_t id = SIZE_MAX; + switch (port.get_type()) { + case ov::snippets::lowered::ExpressionPort::Type::Input: + offset = ma_op->get_input_offset(port.get_index()); + id = get_cluster_id(port.get_port_connector_ptr()->get_source()); + break; + case ov::snippets::lowered::ExpressionPort::Type::Output: + offset = ma_op->get_output_offset(port.get_index()); + for (const auto& child : port.get_connected_ports()) + if (!ov::is_type(child.get_expr()->get_node())) + id = get_cluster_id(child); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Uknown type of expression port!"); + } + OV_CPU_JIT_EMITTER_ASSERT(IMPLICATION(ov::snippets::utils::is_dynamic_value(offset), id != SIZE_MAX), + "In dynamic case Buffer Cluster ID must be known!"); + return id; +} + +Xbyak::Reg64 get_aux_gpr(const std::vector& used_gpr_idxs) { + // RSP, RBP - stack-related registers, abi_param1 - runtime parameter register in the kernel + static std::unordered_set blacklist_gpr_idxs = { Xbyak::Operand::RSP, Xbyak::Operand::RBP, static_cast(abi_param1.getIdx()) }; + for (size_t gpr_idx = 0; gpr_idx <= Xbyak::Operand::R15; ++gpr_idx) { + size_t _idx = Xbyak::Operand::R15 - gpr_idx; // we allocate from the end + if (std::find(used_gpr_idxs.cbegin(), used_gpr_idxs.cend(), _idx) != used_gpr_idxs.cend()) continue; + if (blacklist_gpr_idxs.count(_idx) > 0) continue; + return Xbyak::Reg64(_idx); + } + OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR"); +} + +void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, + Xbyak::Reg64 ptr_reg, Xbyak::Reg64 aux_reg, size_t runtime_offset) { + const auto stack_frame = h->qword[h->rsp + stack_offset]; + h->mov(aux_reg, ptr_reg); + h->add(aux_reg, h->ptr[abi_param1 + runtime_offset]); + 
h->mov(stack_frame, aux_reg); +} + +void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, + Xbyak::Reg64 ptr_reg, size_t ptr_offset) { + const auto stack_frame = h->qword[h->rsp + stack_offset]; + h->mov(stack_frame, ptr_reg); + if (ptr_offset != 0) h->add(stack_frame, ptr_offset); +} + +} // namespace utils +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp index bd1ddb5ad7957f..97ea86f404fd67 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp @@ -1,8 +1,11 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // + +#pragma once + #include "cpu/x64/jit_generator.hpp" -#include "snippets/emitter.hpp" +#include "snippets/lowered/expression_port.hpp" namespace ov { namespace intel_cpu { @@ -20,6 +23,42 @@ inline static std::vector transform_snippets_regs_to_idxs(const std::vec return idxs; } +/** + * @brief If the passed `port` is connected to a Buffer, return its cluster ID. + * Otherwise returns SIZE_MAX + * @param port expression port of memory access op + * @return cluster ID of the connected Buffer or SIZE_MAX + */ +size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port); + +/** + * @brief Find the available register from the pool excepting: abi_param1, RSP, RBP and `used_gpr_idxs` + * @param used_gpr_idxs current used gpr register indexes + * @return register + */ +Xbyak::Reg64 get_aux_gpr(const std::vector& used_gpr_idxs); + +/** + * @brief Push data pointer on stack adding offset. 
The offset is taken from runtime params `abi_param1` + * @param h generator + * @param stack_offset stack offset + * @param ptr_reg register contains data pointer + * @param aux_reg aux register + * @param runtime_offset offset in runtime params `abi_param1` + */ +void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, + Xbyak::Reg64 ptr_reg, Xbyak::Reg64 aux_reg, size_t runtime_offset); + +/** + * @brief Push data pointer on stack adding static offset `ptr_offset` + * @param h generator + * @param stack_offset stack offset + * @param ptr_reg register contains data pointer + * @param ptr_offset offset which will be added to data pointer + */ +void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, + Xbyak::Reg64 ptr_reg, size_t ptr_offset); + } // namespace utils } // namespace intel_cpu } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp index e95d4447303185..78563bc00aa228 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp @@ -97,16 +97,9 @@ std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter) { std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter *emitter) { std::stringstream ss; ss << "Emitter_type_name:jit_brgemm_copy_b_emitter" - << " m_brg_weight_etype:" << emitter->m_brg_weight_etype - << " m_N_blk:" << emitter->m_N_blk - << " m_inner_N_block:" << emitter->m_inner_N_block - << " m_inner_N_tail:" << emitter->m_inner_N_tail - << " m_K_blk:" << emitter->m_K_blk - << " m_brgemmVNNIFactor:" << emitter->m_brgemmVNNIFactor - << " m_in_offset:" << emitter->m_in_offset - << " m_out_offset:" << emitter->m_out_offset - << " m_comp_offset:" << emitter->m_comp_offset - << " m_with_comp:" << emitter->m_with_comp; + 
<< emitter->m_kernel_executor->to_string() + << " m_memory_offset:" << vector_to_string(emitter->m_memory_offsets) + << " m_buffer_ids:" << vector_to_string(emitter->m_buffer_ids); return ss.str(); } diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_equation_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_equation_emitter.cpp index a1c43efa7d1849..1efa9d850e31de 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_equation_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_equation_emitter.cpp @@ -4,6 +4,7 @@ #include "jit_equation_emitter.hpp" #include "transformations/tpp/x64/op/equation.hpp" +#include "emitters/plugin/x64/utils.hpp" using namespace Xbyak; using namespace dnnl::impl; @@ -90,7 +91,8 @@ void EquationTppEmitter::validate_arguments(const std::vector &in, const } void EquationTppEmitter::emit_impl(const std::vector& in, const std::vector& out) const { - internal_call_preamble(); + EmitABIRegSpills spill(h); + spill.preamble(); // save function address in gpr to pass in call instruction h->mov(h->rbp, get_execute_function_ptr()); @@ -115,13 +117,13 @@ void EquationTppEmitter::emit_impl(const std::vector& in, const std::vec h->mov(abi_param2, num_kernel_args); h->mov(abi_param3, h->rsp); - internal_call_rsp_align(); + spill.rsp_align(); h->call(h->rbp); - internal_call_rsp_restore(); + spill.rsp_restore(); // Free allocated memory on the stack h->add(h->rsp, num_kernel_args * sizeof(void*)); - internal_call_postamble(); + spill.postamble(); } void EquationTppEmitter::execute_kernel(libxsmm_meqn_function equation_kernel, int argc, void **argv) { diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp index 70ddbb3d79ee21..cb18f69082e1b2 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp @@ -5,6 +5,7 @@ #include "jit_tpp_emitter.hpp" 
#include "snippets/lowered/port_descriptor.hpp" #include "transformations/tpp/x64/op/eltwise.hpp" +#include "emitters/plugin/x64/utils.hpp" using namespace Xbyak; using namespace dnnl::impl; @@ -77,7 +78,9 @@ void TppEmitter::emit_code(const std::vector &in, const std::vector& in, const std::vector& out) const { - internal_call_preamble(); + EmitABIRegSpills spill(h); + spill.preamble(); + // Note: 4 args is currently enough for unary and binary ops. // To enable ternary ops, we will have to pass extra regs on stack for Windows, std::array abi_params{abi_param1, abi_param2, abi_param3, abi_param4}; @@ -104,11 +107,11 @@ void TppEmitter::emit_impl(const std::vector& in, const std::vectorcall(h->rbp); - internal_call_rsp_restore(); + spill.rsp_restore(); - internal_call_postamble(); + spill.postamble(); } libxsmm_datatype TppEmitter::ov_to_xsmm_dtype(ov::element::Type_t elemet_type) { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp index dfe4441de90699..b40bd88f31726b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp @@ -96,12 +96,11 @@ void BrgemmCPU::validate_and_infer_types() { void BrgemmCPU::validate_with_scratchpad() const { // Additional check for 3rd input - if (one_of(m_type, BRGEMM_TYPE::WITH_COMPENSATIONS, BRGEMM_TYPE::WITH_AMX)) { - const auto& pshape = get_input_partial_shape(2); - OPENVINO_ASSERT(pshape.is_static(), "BRGEMM Scratch must have static shape"); - if (with_compensations(m_type)) { - OPENVINO_ASSERT(get_input_element_type(2) == ov::element::f32, "BRGEMM Scratch with compensations must have FP32 element type"); - } + if (with_compensations(m_type)) { + OPENVINO_ASSERT(get_input_element_type(2) == ov::element::f32, "BRGEMM Scratch with compensations must have FP32 element type"); + } else if (with_amx(m_type)) { + 
OPENVINO_ASSERT(get_input_partial_shape(2).is_static(), "BRGEMM Scratch must have static shape"); + OPENVINO_ASSERT(get_input_element_type(2) == ov::element::u8, "BRGEMM Scratch must have U8 element type"); } } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp index af70218ce0635f..844ec338b8a83b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp @@ -42,7 +42,7 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) { #undef SUPPORT } -BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, const Dimension& K_dim, const Dimension& N_dim, bool transpose_b) { +BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, const Dimension& K_dim, bool transpose_b) { if (element_type_a == element::f32) return transpose_b ? BRGEMM_TYPE::REPACKING_ONLY : BRGEMM_TYPE::STAND_ALONE; @@ -52,8 +52,7 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, const Dimen const auto brgemmVNNIFactor = 4 / element_type_a.size(); if (one_of(element_type_a, element::u8, element::i8, element::bf16) && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) && - K_dim.is_static() && K_dim.get_length() % brgemmVNNIFactor == 0 && - N_dim.is_static() && N_dim.get_length() % brgemmVNNIFactor == 0) + K_dim.is_static() && K_dim.get_length() % brgemmVNNIFactor == 0) return BRGEMM_TYPE::WITH_AMX; // Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the // backend requirements. 
More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp index d0360e45a62e18..bc627c59920c4b 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp @@ -22,7 +22,7 @@ enum class BRGEMM_TYPE { dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx); -BRGEMM_TYPE get_brgemm_type(const element::Type& element_type_a, const Dimension& K_dim, const Dimension& N_dim, bool transpose_b); +BRGEMM_TYPE get_brgemm_type(const element::Type& element_type_a, const Dimension& K_dim, bool transpose_b); inline bool stand_alone(BRGEMM_TYPE type) { return type == BRGEMM_TYPE::STAND_ALONE; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp index 6dda47e47326aa..abb6147bac3588 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp @@ -32,9 +32,8 @@ using namespace snippets::lowered; namespace { template void set_full_port_desc(const T& port) { - const auto& shape = port.get_shape(); - static const std::vector full_dim_subtensor(std::min(shape.size(), size_t(2)), - ov::snippets::utils::get_full_dim_value()); + const auto& shape_rank = port.get_partial_shape().size(); + static const std::vector full_dim_subtensor(std::min(shape_rank, size_t(2)), ov::snippets::utils::get_full_dim_value()); PortDescriptorUtils::set_port_descriptor(port, full_dim_subtensor); } } // namespace @@ -62,7 +61,6 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto dimsMatMulIn1 = 
snippets::utils::get_planar_pshape(brgemm->input(1)); const auto K = *dimsMatMulIn0.rbegin(); - const auto N = *dimsMatMulIn1.rbegin(); const auto& layout_a = brgemm_in0_desc->get_layout(); const auto& layout_b = brgemm_in1_desc->get_layout(); @@ -70,7 +68,7 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto element_type_a = brgemm->get_input_element_type(0); const bool transpose_b = !layout_b.empty() && layout_b.back() != layout_b.size() - 1; - const auto brgemm_type = brgemm_utils::get_brgemm_type(element_type_a, K, N, transpose_b); + const auto brgemm_type = brgemm_utils::get_brgemm_type(element_type_a, K, transpose_b); const auto offset_a = brgemm->get_offset_a(); const auto offset_b = brgemm->get_offset_b(); const auto offset_c = brgemm->get_offset_c(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp index 9d7adab2fdc09b..638cd8a1005e12 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp @@ -39,8 +39,6 @@ void RepackedWeightsBufferExpression::init_allocation_size(const std::shared_ptr const size_t n_blk = *in_subtensor.rbegin(); const size_t k_blk = *++in_subtensor.rbegin(); - OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(n_blk) && !ov::snippets::utils::is_dynamic_value(k_blk), - "RepackedWeightsBufferExpression supports only static subtensor values"); const auto& precision = get_node()->get_input_element_type(0); // Repacking buffer shape is set in accordance to OneDNN requirements @@ -49,14 +47,14 @@ void RepackedWeightsBufferExpression::init_allocation_size(const std::shared_ptr // In case of transpose, K dimension must be rounded-up to number 
of elems in vector register // For the details, please see 'transpose16x8' and 'fixup16x16' implementations and usage in onednn/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp const auto elems_in_vec = brgemm_utils::get_elems_in_vec(precision); - m_allocation_size = N_dim * rnd_up(k_blk, elems_in_vec); + m_allocation_size = snippets::utils::dynamic_safe_mul(N_dim, snippets::utils::rnd_up(k_blk, elems_in_vec)); } else { // Low precision repacking writes the result by m_brgemmVNNIFactor * m_inner_n_block blocks // despite the actual size of the input data. Because of that we have to round-up the allocation shape to always have enough memory allocated. // For the details, please see 'copy_4x64' and 'copy_2x32' implementations and usage in onednn/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp const auto brgemmVNNIFactor = brgemm_utils::compute_vnni_factor(precision); OPENVINO_ASSERT(brgemmVNNIFactor > 0, "brgemmVNNIFactor value must be positive."); - m_allocation_size = N_dim * rnd_up(k_blk, brgemmVNNIFactor); + m_allocation_size = snippets::utils::dynamic_safe_mul(N_dim, snippets::utils::rnd_up(k_blk, brgemmVNNIFactor)); } } @@ -77,14 +75,16 @@ void CompensationsBufferExpression::validate() const { void CompensationsBufferExpression::init_allocation_size(const std::shared_ptr& loop_manager, size_t allocation_rank) { const auto& parent_expr = get_input_port_connector(0)->get_source().get_expr(); - const auto& in_subtensor = ov::snippets::utils::get_projected_subtensor(parent_expr->get_input_port(0)); - const size_t n_blk = *in_subtensor.rbegin(); - OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(n_blk), "CompensationsBufferExpression supports only static subtensor values"); - const auto& precision = parent_expr->get_node()->get_input_element_type(0); // Compensations are computed during repacking, so we need to round-up allocation shape according to m_inner_n_block // because of OneDNN implementation nuances (as in get_repacking_buffer_size). 
// However, the compensations are computed by N dimension, so K dimension doesn't affect the compensations buffer - m_allocation_size = std::max(n_blk, compute_inner_n_block(precision)); + const size_t n_blk = *ov::snippets::utils::get_projected_subtensor(parent_expr->get_input_port(0)).rbegin(); + if (snippets::utils::is_dynamic_value(n_blk)) { + m_allocation_size = snippets::utils::get_dynamic_value(); + } else { + const auto& precision = parent_expr->get_node()->get_input_element_type(0); + m_allocation_size = std::max(n_blk, compute_inner_n_block(precision)); + } } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 6be213f9d066da..ff496aeffeefa9 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -968,9 +968,15 @@ void Transformations::MainSnippets(void) { return false; const auto in_type0 = matmul->get_input_element_type(0); const auto in_type1 = matmul->get_input_element_type(1); - if (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) + const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 && + one_of(config.inferencePrecision, element::f32, element::undefined)); + const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16); + const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) || + ((in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::bf16)); + const auto is_int8 = in_type0 == ov::element::i8; + if (is_fp16) return false; - if (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 && one_of(config.inferencePrecision, element::f32, element::undefined)) + if (is_fp32) return true; // Only FP32 dynamic MHA is supported if (matmul->is_dynamic()) @@ -979,23 +985,19 @@ 
void Transformations::MainSnippets(void) { // The current solution with ExtractExplicitMatMulTranspose pass is slower for non-f32 cases than using of brgemm_copy_b kernel if (matmul->get_transpose_a() || matmul->get_transpose_b()) return false; - if (in_type0 == ov::element::i8) + // [150842] The execution of Brgemm INT8/BF16 on AMX platforms depends on the value of "K % VNNIFactor". + // For more details, please take a look at the ticket 150842 + if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) { + const auto& b_shape = matmul->get_input_partial_shape(1); + const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin(); + if (is_bf16) return K.is_static() && (K.get_length() % 2 == 0); + if (is_int8) return K.is_static(); + } + if (is_int8) return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni) || dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni); - if ((in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) || - ((in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::bf16))) { - // Implementation calls AMX BF16 brgemm only for tensors with K and N aligned on 2, otherwise fallbacks on vector impl - // Vector madd BF16 instruction on SPR has reduced performance on HW level, which results in overall perf degradation - size_t bf16Factor = 2; - if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) { - const auto& b_shape = matmul->get_input_partial_shape(1); - const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin(); - const auto N = matmul->get_transpose_b() ?
*++b_shape.rbegin() : *b_shape.rbegin(); - return K.is_static() && (K.get_length() % bf16Factor == 0) && - N.is_static() && (N.get_length() % bf16Factor == 0); - } + if (is_bf16) return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16); - } return true; }; auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr& n, const ov::PartialShape& shape) { @@ -1135,17 +1137,19 @@ void Transformations::MainSnippets(void) { auto mm_supports_transpose_b = [this, ignoreCallback](const std::shared_ptr& n) { MAYBE_UNUSED(config.inferencePrecision); - const auto& b_shape = n->get_input_partial_shape(1); - if (!ignoreCallback || b_shape.is_dynamic()) + if (!ignoreCallback) return false; // Note: BrgemmTPP doesn't support transposed KN natively // so we should extract transposes for the corresponding matmul nodes #if defined(SNIPPETS_LIBXSMM_TPP) + // TPP doesn't support dynamic shapes -> there will be BrgemmCPU node + if (n->is_dynamic()) + return true; std::vector> layouts(3); const auto matmul = ov::as_type_ptr(n); OPENVINO_ASSERT(matmul, "ExplicitTransposeMatMulInputs callback must be called for matmul node"); if (matmul->get_transpose_b()) { - std::vector transposed_layout(b_shape.size()); + std::vector transposed_layout(n->get_input_partial_shape(1).size()); std::iota(transposed_layout.begin(), transposed_layout.end(), 0); std::swap(*transposed_layout.rbegin(), *(transposed_layout.rbegin() + 1)); layouts[1] = std::move(transposed_layout); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 1c7fd22e018eb6..c4e5af875323ae 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -562,6 +562,13 @@ std::vector disabledTestPatterns() { 
retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeEnforceBF16.*)"); retVector.emplace_back(R"(.*smoke_Snippets_MHAEnforceBF16.*)"); } + // [150842] Need to support dynamic K dimension of BF16|INT8 MatMul on AMX systems + if (ov::with_cpu_x86_avx512_core_amx()) { + retVector.emplace_back(R"(.*smoke_Snippets_MatMul/MatMul.CompareWithRefImpl/.*IS\[0\]=\[2.2.70.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)"); + retVector.emplace_back(R"(.*smoke_Snippets_MatMul/MatMul.CompareWithRefImpl/.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)"); + retVector.emplace_back(R"(.*smoke_Snippets_MatMulTransposeB.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)"); + retVector.emplace_back(R"(.*smoke_Snippets_MatMulBias.*IS\[0\]=\[\?.\?.\?.\?\].*T\[0\]=(u8|i8|bf16)_T\[1\]=(i8|bf16).*)"); + } #ifdef SNIPPETS_LIBXSMM_TPP // GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234) retVector.emplace_back(R"(.*smoke_Snippets_GroupNormalization.*)"); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index b0e8d58da2f0b2..f5057137f9b65c 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -14,18 +14,6 @@ namespace snippets { #define STATIC_SHAPES(...) 
static_shapes_to_test_representation(std::vector>{__VA_ARGS__}) namespace { -const auto& input_shapes = STATIC_SHAPES( - {{2, 1, 3, 5}, {1, 3, 5, 3}}, - {{3, 1, 32, 14}, {1, 2, 14, 32}}, - {{1, 2, 37, 23}, {2, 1, 23, 37}}, - {{1, 1, 37, 23}, {1, 2, 23, 33}}, - {{1, 1, 32, 23}, {1, 1, 23, 68}}, - {{1, 16, 384, 64}, {1, 16, 64, 384}}, - {{1, 1, 100, 700}, {1, 1, 700, 100}}, - {{1, 1, 100, 1024}, {1, 1, 1024, 100}}, - {{1, 1, 100, 2500}, {1, 1, 2500, 100}}, - {{1, 1, 100, 4500}, {1, 1, 4500, 100}}, -); static inline std::vector> quantized_precisions() { std::vector> prc = {}; @@ -37,164 +25,130 @@ static inline std::vector> quantized_precisions() { return prc; } -static inline std::vector> precisions(bool only_fp32 = true) { +static inline std::vector> precisions() { std::vector> prc = { - {element::f32, element::f32}, + {element::f32, element::f32}, }; // Note: TPP doesn't support low precisions yet #ifndef SNIPPETS_LIBXSMM_TPP - if (!only_fp32) { - auto quant = quantized_precisions(); - std::copy(quant.begin(), quant.end(), std::back_inserter(prc)); - // In Snippets MatMul BF16 is supported only on bf16/AMX platforms - if (ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16()) { - prc.emplace_back(std::vector{element::bf16, element::bf16}); - } + auto quant = quantized_precisions(); + std::copy(quant.begin(), quant.end(), std::back_inserter(prc)); + // In Snippets MatMul BF16 is supported only on bf16/AMX platforms + if (ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16()) { + prc.emplace_back(std::vector{element::bf16, element::bf16}); } #endif return prc; } -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul, - ::testing::Combine( - ::testing::ValuesIn(input_shapes), - ::testing::ValuesIn(precisions(false)), - ::testing::Values(MatMulType::MatMul), - ::testing::Values(1), // MatMul - ::testing::Values(1), // Tokenized MatMul - ::testing::Values(ov::test::utils::DEVICE_CPU)), - MatMul::getTestCaseName); - -std::vector> 
input_shapes_dynamic{ - // All dimensions are dynamic - { - {PartialShape{-1, -1, -1, -1}, {{2, 1, 32, 64}, {2, 2, 10, 20}, {2, 2, 100, 80}, - {2, 2, 10, 20}, {2, 1, 32, 64}, {2, 3, 64, 55}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 3, 64, 128}, {2, 2, 20, 30}, {2, 2, 80, 120}, - {2, 2, 20, 30}, {1, 3, 64, 128}, {2, 3, 55, 128}}} - }, - // Only M dimension is dynamic + one one loop by M - { - {PartialShape{-1, 2, -1, 64}, {{2, 2, 64, 64}, {2, 2, 64, 64}, {2, 2, 35, 64}, - {2, 2, 120, 64}, {2, 2, 15, 64}, {2, 2, 35, 64}}}, - {PartialShape{-1, 2, 64, 32}, {{2, 2, 64, 32}, {2, 2, 64, 32}, {1, 2, 64, 32}, - {1, 2, 64, 32}, {2, 2, 64, 32}, {1, 2, 64, 32}}} - }, - // Only M dimension is dynamic + all Loops (by M, N, K) - { - {PartialShape{2, 2, -1, 550}, {{2, 2, 64, 550}, {2, 2, 16, 550}, {2, 2, 35, 550}, - {2, 2, 16, 550}, {2, 2, 70, 550}, {2, 2, 64, 550}}}, - {PartialShape{2, 1, 550, 70}, {{2, 1, 550, 70}, {2, 1, 550, 70}, {2, 1, 550, 70}, - {2, 1, 550, 70}, {2, 1, 550, 70}, {2, 1, 550, 70}}} - }, - // Only K dimension is dynamic - { - {PartialShape{2, 2, 70, -1}, {{2, 2, 70, 512}, {2, 2, 70, 10}, {2, 2, 70, 33}, {2, 2, 70, 2000}, {2, 2, 70, 35}, {2, 2, 70, 600}}}, - {PartialShape{2, 2, -1, 70}, {{2, 2, 512, 70}, {2, 2, 10, 70}, {2, 2, 33, 70}, {2, 2, 2000, 70}, {2, 2, 35, 70}, {2, 2, 600, 70}}} - }, - // Only N dimension is dynamic - { - {PartialShape{}, {{2, 2, 65, 550}}}, - {PartialShape{2, 2, 550, -1}, {{2, 2, 550, 70}, {2, 2, 550, 12}, {2, 2, 550, 70}, - {2, 2, 550, 12}, {2, 2, 550, 10}, {2, 2, 550, 64} }} - }, +std::vector> input_shapes{ + { {{}, {{2, 1, 3, 5}}}, {{}, {{1, 3, 5, 3}}} }, + { {{}, {{3, 1, 32, 14}}}, {{}, {{1, 3, 14, 37}}} }, + { {{}, {{1, 2, 37, 23}}}, {{}, {{2, 1, 23, 37}}} }, + { {{}, {{1, 1, 32, 23}}}, {{}, {{1, 1, 23, 68}}} }, + { {{}, {{1, 16, 384, 64}}}, {{}, {{1, 16, 64, 384}}} }, + { {{}, {{1, 1, 100, 700}}}, {{}, {{1, 1, 700, 100}}} }, + { {{}, {{1, 1, 100, 1024}}}, {{}, {{1, 1, 1024, 100}}} }, + { {{}, {{1, 1, 100, 2500}}}, {{}, {{1, 1, 
2500, 100}}} }, + { {{}, {{1, 1, 100, 4500}}}, {{}, {{1, 1, 4500, 100}}} }, + // Only M dimension is dynamic + one loop by M + { + {PartialShape{-1, 2, -1, 64}, {{2, 2, 64, 64}, {2, 2, 64, 64}, {2, 2, 35, 64}, + {2, 2, 120, 64}, {2, 2, 15, 64}, {2, 2, 35, 64}}}, + {PartialShape{-1, 2, 64, 32}, {{2, 2, 64, 32}, {2, 2, 64, 32}, {1, 2, 64, 32}, + {1, 2, 64, 32}, {2, 2, 64, 32}, {1, 2, 64, 32}}} + }, + // Only M dimension is dynamic + all Loops (by M, N, K) + { + {PartialShape{2, 2, -1, 550}, {{2, 2, 64, 550}, {2, 2, 16, 550}, {2, 2, 35, 550}, + {2, 2, 16, 550}, {2, 2, 70, 550}, {2, 2, 64, 550}}}, + {PartialShape{2, 1, 550, 70}, {{2, 1, 550, 70}, {2, 1, 550, 70}, {2, 1, 550, 70}, + {2, 1, 550, 70}, {2, 1, 550, 70}, {2, 1, 550, 70}}} + }, + // All dimensions are dynamic + { + {PartialShape{-1, -1, -1, -1}, {{2, 1, 32, 64}, {2, 2, 10, 20}, {2, 2, 100, 80}, + {2, 2, 10, 20}, {2, 1, 32, 64}, {2, 3, 64, 55}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 3, 64, 128}, {2, 2, 20, 30}, {2, 2, 80, 120}, + {2, 2, 20, 30}, {1, 3, 64, 128}, {2, 3, 55, 128}}} + }, + // Only K dimension is dynamic + { + {PartialShape{2, 2, 70, -1}, {{2, 2, 70, 512}, {2, 2, 70, 10}, {2, 2, 70, 33}, {2, 2, 70, 2000}, {2, 2, 70, 35}, {2, 2, 70, 600}}}, + {PartialShape{2, 2, -1, 70}, {{2, 2, 512, 70}, {2, 2, 10, 70}, {2, 2, 33, 70}, {2, 2, 2000, 70}, {2, 2, 35, 70}, {2, 2, 600, 70}}} + }, + // Only N dimension is dynamic + { + {PartialShape{}, {{2, 2, 65, 550}}}, + {PartialShape{2, 2, 550, -1}, {{2, 2, 550, 70}, {2, 2, 550, 12}, {2, 2, 550, 70}, + {2, 2, 550, 12}, {2, 2, 550, 10}, {2, 2, 550, 64} }} + } }; -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMatMul, MatMul, - ::testing::Combine( - ::testing::ValuesIn(input_shapes_dynamic), - ::testing::ValuesIn(precisions(true)), - ::testing::Values(MatMulType::MatMul), - ::testing::Values(1), // MatMul - ::testing::Values(1), // Tokenized MatMul - ::testing::Values(ov::test::utils::DEVICE_CPU)), - MatMul::getTestCaseName); -
-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulFQ, MatMulFQ, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMul, MatMul, ::testing::Combine( ::testing::ValuesIn(input_shapes), ::testing::ValuesIn(precisions()), ::testing::Values(MatMulType::MatMul), - ::testing::Values(1), // MatMul; + ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul ::testing::Values(ov::test::utils::DEVICE_CPU)), MatMul::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulEltwiseChain, MatMulEltwiseChain, - ::testing::Combine( - ::testing::ValuesIn(input_shapes), - ::testing::ValuesIn(precisions()), - ::testing::Values(MatMulType::MatMul), - ::testing::Values(1), // MatMul - ::testing::Values(1), // Tokenized MatMul - ::testing::Values(ov::test::utils::DEVICE_CPU)), - MatMul::getTestCaseName); - -std::vector> matmul_cascade_shapes{ +std::vector> transpose_b_shapes{ + { {{}, {{3, 3, 64, 64}}}, {{}, {{3, 3, 64, 64}}} }, + { {{}, {{1, 1, 32, 128}}}, {{}, {{1, 1, 64, 128}}} }, + { {{}, {{1, 1, 32, 128}}}, {{}, {{1, 1, 384, 128}}} }, + { {{}, {{1, 1, 64, 1500}}}, {{}, {{1, 1, 420, 1500}}} }, + { {{}, {{1, 1, 64, 1024}}}, {{}, {{1, 1, 420, 1024}}} }, + { {{}, {{4, 8, 32, 1024}}}, {{}, {{4, 8, 420, 1024}}} }, + // All dimensions are dynamic + { + {PartialShape{-1, -1, -1, -1}, {{2, 1, 32, 64}, {2, 2, 10, 20}, {2, 2, 100, 600}, {2, 1, 32, 64}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 3, 128, 64}, {2, 2, 30, 20}, {2, 2, 120, 600}, {1, 3, 128, 64}}} + }, + // Only M is dynamic { - {PartialShape{-1, -1, -1, -1}, {{2, 1, 32, 2500}, {1, 3, 80, 700}, {2, 1, 32, 2500}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 2, 2500, 128}, {1, 3, 700, 150}, {1, 2, 2500, 128}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 1, 128, 64}, {1, 3, 150, 128}, {1, 1, 128, 64}}}, + {PartialShape{2, 2, -1, 64}, {{2, 2, 40, 64}, {2, 2, 16, 64}}}, + {PartialShape{2, 2, 100, 64}, {{2, 2, 100, 64}, {2, 2, 100, 64}}} }, + // Only K is static + { + {PartialShape{2, 2, -1, 100}, {{2, 2, 32, 100}, {2, 2, 10, 100}, {2,
2, 10, 100}}}, + {PartialShape{2, 2, -1, 100}, {{2, 2, 100, 100}, {2, 2, 64, 100}, {2, 2, 100, 100}}} + } }; -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulEltwiseChainCascade, MatMulEltwiseChainCascade, - ::testing::Combine( - ::testing::ValuesIn(matmul_cascade_shapes), - ::testing::ValuesIn(precisions()), - ::testing::Values(MatMulType::MatMul), - ::testing::Values(1), // MatMul - ::testing::Values(1), // Tokenized MatMul - ::testing::Values(ov::test::utils::DEVICE_CPU)), - MatMul::getTestCaseName); - -const auto& transpose_b_shapes = STATIC_SHAPES( - {{3, 3, 64, 64}, {3, 3, 64, 64}}, - {{1, 1, 32, 128}, {1, 1, 64, 128}}, - {{1, 1, 32, 128}, {1, 1, 384, 128}}, - {{1, 1, 64, 1500}, {1, 1, 420, 1500}}, - {{1, 1, 64, 1024}, {1, 1, 420, 1024}}, - {{4, 8, 32, 1024}, {4, 8, 420, 1024}}, -); - INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulTransposeB, MatMulTransposeB, ::testing::Combine( ::testing::ValuesIn(transpose_b_shapes), - ::testing::ValuesIn(precisions(false)), + ::testing::ValuesIn(precisions()), ::testing::Values(MatMulType::MatMul), ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul ::testing::Values(ov::test::utils::DEVICE_CPU)), MatMul::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, - ::testing::Combine( - ::testing::ValuesIn(STATIC_SHAPES({{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 1, 69, 49}}, - {{1, 2, 95, 1023}, {1, 2, 1023, 255}, {1, 2, 95, 255}})), - ::testing::ValuesIn(precisions(false)), - ::testing::Values(MatMulType::MatMul), - ::testing::Values(1), // Subgraph; - ::testing::Values(1), // Tokenized MatMul+Bias - ::testing::Values(ov::test::utils::DEVICE_CPU)), - MatMul::getTestCaseName); - -std::vector> input_shapes_dynamic_bias{ - { - {PartialShape{-1, -1, -1, -1}, {{1, 2, 69, 43}, {1, 2, 95, 1023}, {1, 2, 69, 43}}}, - {PartialShape{-1, -1, -1, -1}, {{2, 1, 43, 49}, {1, 2, 1023, 255}, {2, 1, 43, 49}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 1, 69, 49}, {1, 2, 95, 255}, {1, 1, 69, 49}}} - }, - { - 
{PartialShape{-1, -1, -1, -1}, {{2, 2, 16, 32}, {2, 2, 16, 32}, {2, 2, 16, 32}, {2, 2, 16, 32}}}, - {PartialShape{-1, -1, -1, -1}, {{2, 2, 32, 18}, {2, 2, 32, 18}, {2, 2, 32, 1}, {2, 2, 32, 1}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 1, 16, 18}, {1, 1, 16, 1}, {1, 1, 16, 18}, {1, 1, 16, 1}}} - }, +std::vector> input_shapes_bias{ + { {{}, {{1, 2, 69, 43}}}, {{}, {{2, 1, 43, 49}}}, {{}, {{1, 1, 69, 49}}} }, + { {{}, {{1, 2, 95, 1023}}}, {{}, {{1, 2, 1023, 255}}}, {{}, {{1, 2, 95, 255}}} }, + { + {PartialShape{-1, -1, -1, -1}, {{1, 2, 69, 43}, {1, 2, 95, 1023}, {1, 2, 69, 43}}}, + {PartialShape{-1, -1, -1, -1}, {{2, 1, 43, 49}, {1, 2, 1023, 255}, {2, 1, 43, 49}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 1, 69, 49}, {1, 2, 95, 255}, {1, 1, 69, 49}}} + }, + { + {PartialShape{-1, -1, -1, -1}, {{2, 2, 16, 32}, {2, 2, 16, 32}, {2, 2, 16, 32}, {2, 2, 16, 32}}}, + {PartialShape{-1, -1, -1, -1}, {{2, 2, 32, 18}, {2, 2, 32, 18}, {2, 2, 32, 1}, {2, 2, 32, 1}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 1, 16, 18}, {1, 1, 16, 1}, {1, 1, 16, 18}, {1, 1, 16, 1}}} + } }; -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMatMulBias, MatMulBias, + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, ::testing::Combine( - ::testing::ValuesIn(input_shapes_dynamic_bias), - ::testing::ValuesIn(precisions(true)), + ::testing::ValuesIn(input_shapes_bias), + ::testing::ValuesIn(precisions()), ::testing::Values(MatMulType::MatMul), ::testing::Values(1), // Subgraph; ::testing::Values(1), // Tokenized MatMul+Bias @@ -203,8 +157,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMatMulBias, MatMulBias, INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBiasQuantized, MatMulBiasQuantized, ::testing::Combine( - ::testing::ValuesIn(STATIC_SHAPES({{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 2, 1, 1}}, - {{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 2, 69, 49}})), + ::testing::ValuesIn(input_shapes_bias), ::testing::ValuesIn(quantized_precisions()), ::testing::Values(MatMulType::MatMul), ::testing::Values(1), // Subgraph 
@@ -231,7 +184,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulsQuantizedSoftmax, MatMulsQuantize ::testing::Values(2), // Tokenized [MatMul+FQ+Matmul] and [FQ] ::testing::Values(ov::test::utils::DEVICE_CPU)), MatMul::getTestCaseName); -} // namespace + +} // namespace } // namespace snippets } // namespace test } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 62edcba0de74e3..79db0b1546b2a8 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -463,7 +463,23 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(CPUTestUtils::empty_plugin_config)), MHA::getTestCaseName); -const auto& inputShapesTransposedB = STATIC_SHAPES({{1, 12, 12, 64}, {1, 12, 48, 64}, {1, 12, 48, 64}}); +std::vector> inputShapesTransposedB { + { + {{}, {{1, 12, 12, 64}}}, + {{}, {{1, 12, 48, 64}}}, + {{}, {{1, 12, 48, 64}}} + }, + { + {PartialShape{-1, 3, -1, 64}, {{1, 3, 12, 64}, {2, 3, 36, 64}}}, + {PartialShape{-1, 3, -1, 64}, {{1, 3, 14, 64}, {2, 3, 42, 64}}}, + {PartialShape{-1, 3, -1, -1}, {{1, 3, 14, 36}, {2, 3, 42, 36}}}, + }, + { + {PartialShape{2, -1, 32, -1}, {{2, 1, 32, 70}, {2, 2, 32, 96}}}, + {PartialShape{2, -1, 49, -1}, {{2, 3, 49, 70}, {2, 1, 49, 96}}}, + {PartialShape{2, -1, 49, -1}, {{2, 1, 49, 17}, {2, 2, 49, 81}}}, + }, +}; INSTANTIATE_TEST_SUITE_P( smoke_Snippets_MHATransposedB, @@ -479,32 +495,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(CPUTestUtils::empty_plugin_config)), MHA::getTestCaseName); -std::vector> inputShapesTransposedB_dynamic{ - { - {PartialShape{-1, 3, -1, 64}, {{1, 3, 12, 64}, {2, 3, 36, 64}}}, - {PartialShape{-1, 3, -1, 64}, {{1, 3, 14, 64}, {2, 3, 42, 64}}}, - {PartialShape{-1, 3, -1, -1}, {{1, 3, 14, 36}, {2, 3, 42, 36}}}, - }, - { - {PartialShape{2, -1, 
32, -1}, {{2, 1, 32, 70}, {2, 2, 32, 96}}}, - {PartialShape{2, -1, 49, -1}, {{2, 3, 49, 70}, {2, 1, 49, 96}}}, - {PartialShape{2, -1, 49, -1}, {{2, 1, 49, 17}, {2, 2, 49, 81}}}, - }, -}; -INSTANTIATE_TEST_SUITE_P( - smoke_Snippets_DynMHATransposedB, - MHATransposedB, - ::testing::Combine(::testing::ValuesIn(inputShapesTransposedB_dynamic), - ::testing::Values(std::vector{}), - ::testing::Values(ov::element::f32), - ::testing::ValuesIn({true}), // Need to support False for graph builder in tests - ::testing::Values(MHA::default_thread_count), - ::testing::Values(2), - ::testing::Values(1), - ::testing::Values(ov::test::utils::DEVICE_CPU), - ::testing::Values(CPUTestUtils::empty_plugin_config)), - MHA::getTestCaseName); - const auto& inputShapesExtractedReshape = STATIC_SHAPES( {{2, 196, 64}, {2, 64, 196}, {2, 14, 14, 14, 1}, {2, 14, 14, 1, 14}, {2, 196, 64}}, {{1, 16, 10}, {1, 10, 16}, {1, 4, 4, 4, 1}, {1, 4, 4, 1, 4}, {1, 16, 10}},