Skip to content

Commit

Permalink
[Snippets][CPU] Added BrgemmCopyA op
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Oct 3, 2024
1 parent 0b8650a commit a07a398
Show file tree
Hide file tree
Showing 32 changed files with 1,275 additions and 463 deletions.
102 changes: 55 additions & 47 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,52 @@
#include "utils.hpp"

#include "emitters/utils.hpp"
#include "utils/general_utils.h"

namespace ov {
namespace intel_cpu {

using namespace Xbyak;
using namespace dnnl::impl::cpu::x64;

EmitABIRegSpills::EmitABIRegSpills(jit_generator* h) : h(h), isa(get_isa()) {}
// Creates a spill helper bound to the generator `h`. `type` selects which register groups
// preamble()/postamble() save and restore; the ISA is queried once through get_isa().
EmitABIRegSpills::EmitABIRegSpills(jit_generator* h, Type type) : h(h), isa(get_isa()), type(type) {
// Only the three named enumerators are accepted; any other value fails the assertion.
OPENVINO_ASSERT(one_of(type, Type::GPRS, Type::VECS, Type::ALL), "Incorrect type");
}

// Checks on destruction that both status flags are set, i.e. that the helper was not
// left in the middle of a spill or of an rsp alignment.
// NOTE(review): if OPENVINO_ASSERT throws, this raises from a destructor - confirm that is intended.
EmitABIRegSpills::~EmitABIRegSpills() {
// spill_status is set back to true at the end of postamble().
OPENVINO_ASSERT(spill_status, "postamble or preamble is missed");
// Presumably toggled by rsp_align()/rsp_restore(); their bodies are not visible here - TODO confirm.
OPENVINO_ASSERT(rsp_status, "rsp_align or rsp_restore is missed");
}

void EmitABIRegSpills::preamble() {
// gprs
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

h->sub(h->rsp, n_gprs_to_save * gpr_size);
for (size_t i = 0; i < n_gprs_to_save; ++i)
h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);

if (isa == avx512_core) {
h->sub(h->rsp, k_mask_num * k_mask_size);
for (size_t i = 0; i < k_mask_num; ++i) {
h->kmovq(h->ptr[h->rsp + i * k_mask_size], Xbyak::Opmask(static_cast<int>(i)));
}
if (type & Type::GPRS) {
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

h->sub(h->rsp, n_gprs_to_save * gpr_size);
for (size_t i = 0; i < n_gprs_to_save; ++i)
h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);
}

h->sub(h->rsp, get_max_vecs_count() * get_vec_length());
for (size_t i = 0; i < get_max_vecs_count(); ++i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(addr, Xmm(i));
} else if (isa == avx2) {
h->uni_vmovups(addr, Ymm(i));
} else {
h->uni_vmovups(addr, Zmm(i));
if (type & Type::VECS) {
if (isa == avx512_core) {
h->sub(h->rsp, k_mask_num * k_mask_size);
for (size_t i = 0; i < k_mask_num; ++i) {
h->kmovq(h->ptr[h->rsp + i * k_mask_size], Xbyak::Opmask(static_cast<int>(i)));
}
}

h->sub(h->rsp, get_max_vecs_count() * get_vec_length());
for (size_t i = 0; i < get_max_vecs_count(); ++i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(addr, Xmm(i));
} else if (isa == avx2) {
h->uni_vmovups(addr, Ymm(i));
} else {
h->uni_vmovups(addr, Zmm(i));
}
}
}

Expand All @@ -53,34 +59,36 @@ void EmitABIRegSpills::preamble() {
}

void EmitABIRegSpills::postamble() {
// restore vector registers
for (int i = static_cast<int>(get_max_vecs_count()) - 1; i >= 0; --i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(Xmm(i), addr);
} else if (isa == avx2) {
h->uni_vmovups(Ymm(i), addr);
} else {
h->uni_vmovups(Zmm(i), addr);
if (type & Type::VECS) {
for (int i = static_cast<int>(get_max_vecs_count()) - 1; i >= 0; --i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(Xmm(i), addr);
} else if (isa == avx2) {
h->uni_vmovups(Ymm(i), addr);
} else {
h->uni_vmovups(Zmm(i), addr);
}
}
}
h->add(h->rsp, (get_max_vecs_count()) * get_vec_length());
h->add(h->rsp, (get_max_vecs_count()) * get_vec_length());

// restore k reg
if (isa == avx512_core) {
for (int i = k_mask_num - 1; i >= 0; --i) {
h->kmovq(Xbyak::Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
if (isa == avx512_core) {
for (int i = k_mask_num - 1; i >= 0; --i) {
h->kmovq(Xbyak::Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
}
h->add(h->rsp, k_mask_num * k_mask_size);
}
h->add(h->rsp, k_mask_num * k_mask_size);
}

// restore gpr registers
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);
for (int i = n_gprs_to_save - 1; i >= 0; --i)
h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
h->add(h->rsp, n_gprs_to_save * gpr_size);
if (type & Type::GPRS) {
// restore gpr registers
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);
for (int i = n_gprs_to_save - 1; i >= 0; --i)
h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
h->add(h->rsp, n_gprs_to_save * gpr_size);
}

// Update the status
spill_status = true;
Expand Down
9 changes: 8 additions & 1 deletion src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ namespace intel_cpu {
// The class emit register spills for the possible call of external binary code
class EmitABIRegSpills {
public:
EmitABIRegSpills(dnnl::impl::cpu::x64::jit_generator* h);
// Bit flags selecting which register groups are spilled and restored.
// preamble()/postamble() test them with `type & Type::GPRS` and `type & Type::VECS`;
// on avx512_core the VECS group also covers the opmask (k) registers.
enum Type {
GPRS = 1 << 1, // spill only general-purpose registers
VECS = 1 << 2, // spill only vector registers
ALL = GPRS | VECS, // default, spill vector and general-purpose registers
};

EmitABIRegSpills(dnnl::impl::cpu::x64::jit_generator* h, Type type = Type::ALL);
~EmitABIRegSpills();

// push (save) all registers on the stack
Expand All @@ -35,6 +41,7 @@ class EmitABIRegSpills {

dnnl::impl::cpu::x64::jit_generator* h {nullptr};
const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::cpu_isa_t::isa_undef};
const Type type {Type::ALL};

static constexpr int k_mask_size = 8;
static constexpr int k_mask_num = 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ namespace intel_cpu {
#define SNIPPETS_MAX_DATA_PTR_COUNT 11
#endif

// Maximum count of Buffer offsets (clusters)
#define SNIPPETS_MAX_BUFFER_COUNT 16

#define GET_OFF(field) offsetof(jit_snippets_call_args, field)
#define GET_OFF_LOOP_ARGS(field) offsetof(jit_snippets_call_args::loop_args_t, field)

Expand All @@ -46,7 +49,7 @@ struct jit_snippets_call_args {
// for all non-static data members. So we can keep them public or friend all control-flow emitters
loop_args_t* loop_args = nullptr;
amx_tile_config_t amx_tile_config;
size_t buffer_offsets[SNIPPETS_MAX_DATA_PTR_COUNT] = {};
size_t buffer_offsets[SNIPPETS_MAX_BUFFER_COUNT] = {};
};

struct jit_snippets_call_args::loop_args_t {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "snippets/snippets_isa.hpp"
#include "emitters/snippets/cpu_runtime_configurator.hpp"

#include "emitters/snippets/x64/jit_brgemm_copy_a_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_emitter.hpp"
#include "emitters/snippets/x64/jit_memory_emitters.hpp"
Expand All @@ -23,6 +24,7 @@
#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
Expand Down Expand Up @@ -243,6 +245,9 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyA::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_a_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
Expand Down Expand Up @@ -356,6 +361,7 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
std::dynamic_pointer_cast<intel_cpu::tpp::op::Scalar>(op) ||
#endif
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyA>(op)||
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyB>(op))
return ov::snippets::RegType::gpr;
else if (
Expand All @@ -368,7 +374,8 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons

bool intel_cpu::CPUGenerator::uses_precompiled_kernel(const std::shared_ptr<snippets::Emitter>& e) const {
bool need = std::dynamic_pointer_cast<intel_cpu::jit_brgemm_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e);
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_a_emitter>(e);
#ifdef SNIPPETS_DEBUG_CAPS
const auto cpu_target_machine = std::dynamic_pointer_cast<intel_cpu::CPUTargetMachine>(target);
need = need || (cpu_target_machine && cpu_target_machine->debug_config.enable_segfault_detector) ||
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_brgemm_copy_a_emitter.hpp"

#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"

#include "snippets/utils/utils.hpp"

#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"


using namespace dnnl::impl::cpu::x64;
using namespace ov::intel_cpu::brgemm_utils;
using namespace ov::snippets::utils;

namespace ov {
namespace intel_cpu {

// Builds the emitter for a BrgemmCopyA expression: registers the repacking kernel executor
// in `kernel_table` and caches the memory offsets and buffer cluster ids of the single
// input and the single output that emit_impl() needs later.
jit_brgemm_copy_a_emitter::jit_brgemm_copy_a_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr,
                                                     const snippets::KernelExecutorTablePtr& kernel_table,
                                                     const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
    : jit_emitter(h, isa) {
    // Both the source and the destination are addressed through general-purpose registers.
    in_out_type_ = emitter_in_out_map::gpr_to_gpr;

    const auto brgemm_repack = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyA>(expr->get_node());
    OV_CPU_JIT_EMITTER_ASSERT(brgemm_repack, "expects BrgemmCopyA node");

    // Even for a dynamic BrgemmCopyA node the first shape inference and
    // RuntimeConfigurator::update() run before the BrgemmCopyAKernelExecutor is registered,
    // so the update has to be triggered manually for static shapes and for the first
    // dynamic shape alike. Hence the shapes must already be known at this point.
    OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
                              "Jit emitter is called when the shapes are unknown");

    // The kernel config is described by the input precision and the ISA taken from the node config.
    const auto& src_precision = brgemm_repack->get_input_element_type(0);
    const auto primitive_isa = brgemm_repack->get_config().isa();
    BrgemmCopyAKernelConfig kernel_config(src_precision, primitive_isa);
    m_kernel_executor = kernel_table->register_kernel<BrgemmCopyAKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

    // Index 0 describes the input, index 1 the output (same order as used in emit_impl()).
    const auto offset_in = brgemm_repack->get_offset_in();
    const auto offset_out = brgemm_repack->get_offset_out();
    m_memory_offsets = {offset_in, offset_out};

    const auto cluster_id_in = utils::get_buffer_cluster_id(expr->get_input_port(0));
    const auto cluster_id_out = utils::get_buffer_cluster_id(expr->get_output_port(0));
    m_buffer_ids = {cluster_id_in, cluster_id_out};
}

// Checks the register index lists passed to the emitter: exactly one input index
// and exactly one output index are expected; anything else fails the assertion.
void jit_brgemm_copy_a_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1 && out.size() == 1, "expects 1 input and 1 output");
}

void jit_brgemm_copy_a_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);

std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};

EmitABIRegSpills spill(h);
spill.preamble();

h->mov(h->rbp, reinterpret_cast<uint64_t>(BrgemmCopyAKernelExecutor::execute));
auto reserved_stack_size = sizeof(BrgemmCopyAKernel::call_args);
// Reserve memory on the stack
h->sub(h->rsp, reserved_stack_size);

const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value<size_t>);
Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64();

const std::vector<size_t> args_offsets {GET_OFF_BRGEMM_COPY_A_ARGS(src), GET_OFF_BRGEMM_COPY_A_ARGS(tr_src)};
const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs);
for (size_t i = 0; i < mem_ptrs.size(); i++) {
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]))
utils::push_ptr_with_runtime_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg,
GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t));
else
utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], m_memory_offsets[i]);
}

h->mov(abi_param1, reinterpret_cast<uintptr_t>(m_kernel_executor.get()));
h->mov(abi_param2, h->rsp);

spill.rsp_align();
h->call(h->rbp);
spill.rsp_restore();

h->add(h->rsp, reserved_stack_size);

spill.postamble();
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "emitters/plugin/x64/jit_emitter.hpp"

#include "kernel_executors/brgemm_copy_a.hpp"


namespace ov {
namespace intel_cpu {

// Snippets JIT emitter for the BrgemmCopyA op: emits a call to a BrgemmCopyAKernelExecutor
// that takes one source pointer and one destination pointer.
class jit_brgemm_copy_a_emitter : public jit_emitter {
public:
// h, isa               - generator and ISA forwarded to the jit_emitter base
// expr                 - expression whose node must be a BrgemmCopyA
// kernel_table         - table in which the kernel executor is registered
// compiled_kernel_cache - cache handle passed on to the executor registration
jit_brgemm_copy_a_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

// The emitter consumes exactly one input.
size_t get_inputs_num() const override {return 1;}
// Returns the supported single-input precision sets: i8, u8 and bf16. `node` is not used.
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr) {
return {{element::i8}, {element::u8}, {element::bf16}};
}

private:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

// {input offset, output offset} taken from the BrgemmCopyA node in the constructor
std::vector<size_t> m_memory_offsets{};
// {input, output} buffer cluster ids; they index buffer_offsets when an offset is dynamic
std::vector<size_t> m_buffer_ids{};
// Executor registered in the kernel table; its raw pointer is passed to the emitted call
std::shared_ptr<BrgemmCopyAKernelExecutor> m_kernel_executor {nullptr};

#ifdef SNIPPETS_DEBUG_CAPS
friend std::string init_info_jit_brgemm_copy_a_emitter(const jit_brgemm_copy_a_emitter *emitter);
#endif
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,14 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
"Jit emitter is called when the shapes are unknown");

const auto& in_subtensor = get_projected_subtensor(expr->get_input_port(0));
const auto K_blk = *++in_subtensor.rbegin();

const auto& src_prc = brgemm_repack->get_src_element_type();
const auto& wei_prc = brgemm_repack->get_input_element_type(0);
const auto wei_N_blk = brgemm_utils::repacking::compute_inner_n_block(wei_prc);
const auto is_transposed = get_is_transposed(expr);
const auto brgemm_type = get_brgemm_type(src_prc, K_blk, is_transposed);
const auto primitive_isa = brgemm_utils::get_primitive_isa(src_prc, with_amx(brgemm_type));
m_with_comp = with_compensations(brgemm_type);
const auto& brgemm_config = brgemm_repack->get_config();
m_with_comp = brgemm_config.need_compensations();

BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, primitive_isa, m_with_comp, is_transposed, wei_N_blk);
BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, brgemm_config.isa(), m_with_comp, is_transposed, wei_N_blk);
m_kernel_executor = kernel_table->register_kernel<BrgemmCopyBKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()};
Expand Down
Loading

0 comments on commit a07a398

Please sign in to comment.