Skip to content

Commit

Permalink
[Snippets][CPU] Added BrgemmCopyA op
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Oct 2, 2024
1 parent 0b8650a commit b2e39d3
Show file tree
Hide file tree
Showing 29 changed files with 1,023 additions and 368 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "snippets/snippets_isa.hpp"
#include "emitters/snippets/cpu_runtime_configurator.hpp"

#include "emitters/snippets/x64/jit_brgemm_copy_a_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_emitter.hpp"
#include "emitters/snippets/x64/jit_memory_emitters.hpp"
Expand All @@ -23,6 +24,7 @@
#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
Expand Down Expand Up @@ -243,6 +245,9 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyA::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_a_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
Expand Down Expand Up @@ -356,6 +361,7 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
std::dynamic_pointer_cast<intel_cpu::tpp::op::Scalar>(op) ||
#endif
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyA>(op)||
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyB>(op))
return ov::snippets::RegType::gpr;
else if (
Expand All @@ -368,7 +374,8 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons

bool intel_cpu::CPUGenerator::uses_precompiled_kernel(const std::shared_ptr<snippets::Emitter>& e) const {
bool need = std::dynamic_pointer_cast<intel_cpu::jit_brgemm_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e);
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_a_emitter>(e);
#ifdef SNIPPETS_DEBUG_CAPS
const auto cpu_target_machine = std::dynamic_pointer_cast<intel_cpu::CPUTargetMachine>(target);
need = need || (cpu_target_machine && cpu_target_machine->debug_config.enable_segfault_detector) ||
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_brgemm_copy_a_emitter.hpp"

#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"

#include "snippets/utils/utils.hpp"

#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"


using namespace dnnl::impl::cpu::x64;
using namespace ov::intel_cpu::brgemm_utils;
using namespace ov::snippets::utils;

namespace ov {
namespace intel_cpu {

/**
 * Creates the emitter for a BrgemmCopyA expression.
 *
 * Registers a BrgemmCopyAKernelExecutor in `kernel_table` for `expr` and caches the
 * input/output memory offsets and buffer cluster ids that emit_impl() needs later.
 *
 * @param h                     JIT generator the code is emitted into
 * @param isa                   ISA forwarded to the base jit_emitter
 * @param expr                  lowered expression; its node must be a BrgemmCopyA
 * @param kernel_table          table in which the kernel executor is registered
 * @param compiled_kernel_cache cache handle forwarded to the kernel executor
 */
jit_brgemm_copy_a_emitter::jit_brgemm_copy_a_emitter(jit_generator* h,
                                                     cpu_isa_t isa,
                                                     const ov::snippets::lowered::ExpressionPtr& expr,
                                                     const snippets::KernelExecutorTablePtr& kernel_table,
                                                     const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
    : jit_emitter(h, isa) {
    in_out_type_ = emitter_in_out_map::gpr_to_gpr;

    const auto brgemm_repack = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyA>(expr->get_node());
    OV_CPU_JIT_EMITTER_ASSERT(brgemm_repack, "expects BrgemmCopyA node");

    // Note: the first shapeInfer and RuntimeConfigurator::update() run before the
    // BrgemmCopyAKernelExecutor is registered, even when the BrgemmCopyA node is dynamic.
    // The update() therefore has to be triggered manually for static shapes and for the
    // 1st dynamic shapes alike, which requires the input shape to be known at this point.
    OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
                              "Jit emitter is called when the shapes are unknown");

    // The kernel config is built from the source precision and the ISA chosen by the node config.
    const auto& repack_config = brgemm_repack->get_config();
    const auto& src_precision = brgemm_repack->get_input_element_type(0);
    BrgemmCopyAKernelConfig kernel_config(src_precision, repack_config.isa());
    m_kernel_executor = kernel_table->register_kernel<BrgemmCopyAKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

    // Index 0 describes the input memory, index 1 the output memory.
    const auto& src_port = expr->get_input_port(0);
    const auto& dst_port = expr->get_output_port(0);
    m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()};
    m_buffer_ids = {utils::get_buffer_cluster_id(src_port), utils::get_buffer_cluster_id(dst_port)};
}

// Checks that exactly one input and one output register index were passed to the emitter;
// the assertion fails otherwise.
void jit_brgemm_copy_a_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1 && out.size() == 1, "expects 1 input and 1 output");
}

// Emits the code that calls BrgemmCopyAKernelExecutor::execute.
// `in[0]` and `out[0]` are the register indices of the source and destination memory pointers.
// The emitted code reserves space for a BrgemmCopyAKernelExecutor::call_args structure on the stack,
// fills its `src` and `tr_src` fields and calls execute() with the kernel executor pointer as the
// first argument and the address of that stack area as the second one.
void jit_brgemm_copy_a_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);

// Index 0 - source pointer register, index 1 - destination pointer register
std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};

// NOTE(review): presumably preamble() saves the registers that the call may clobber and
// postamble() restores them - confirm against EmitABIRegSpills
EmitABIRegSpills spill(h);
spill.preamble();

// rbp holds the address of the function that is called below
h->mov(h->rbp, reinterpret_cast<uint64_t>(BrgemmCopyAKernelExecutor::execute));
auto reserved_stack_size = sizeof(BrgemmCopyAKernelExecutor::call_args);
// Reserve memory on the stack
h->sub(h->rsp, reserved_stack_size);

// An auxiliary GPR is requested only if at least one of the memory offsets is dynamic:
// it is used exclusively by the runtime-offset branch in the loop below
const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value<size_t>);
Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64();

// Offsets of the call_args fields, in the same order as mem_ptrs_idxs / m_memory_offsets / m_buffer_ids
const std::vector<size_t> args_offsets {GET_OFF_BRGEMM_COPY_A_ARGS(src), GET_OFF_BRGEMM_COPY_A_ARGS(tr_src)};
const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs);
for (size_t i = 0; i < mem_ptrs.size(); i++) {
// Dynamic offset: the helper receives the location of this buffer's entry in `buffer_offsets`
// (selected by the buffer cluster id) instead of a compile-time value.
// NOTE(review): presumably both helpers store "pointer + offset" into the reserved stack slot
// at args_offsets[i] - confirm against their definitions
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]))
utils::push_ptr_with_runtime_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg,
GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t));
else
utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], m_memory_offsets[i]);
}

// 1st argument - the kernel executor, 2nd argument - the address of the reserved stack area.
// rsp is captured before spill.rsp_align() is called
h->mov(abi_param1, reinterpret_cast<uintptr_t>(m_kernel_executor.get()));
h->mov(abi_param2, h->rsp);

spill.rsp_align();
h->call(h->rbp);
spill.rsp_restore();

// Release the stack memory reserved above
h->add(h->rsp, reserved_stack_size);

spill.postamble();
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "emitters/plugin/x64/jit_emitter.hpp"

#include "kernel_executors/brgemm_copy_a.hpp"


namespace ov {
namespace intel_cpu {

class jit_brgemm_copy_a_emitter : public jit_emitter {
public:
jit_brgemm_copy_a_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

size_t get_inputs_num() const override {return 1;}
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr) {
return {{element::i8}, {element::u8}, {element::bf16}};
}

private:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

std::vector<size_t> m_memory_offsets{};
std::vector<size_t> m_buffer_ids{};
std::shared_ptr<BrgemmCopyAKernelExecutor> m_kernel_executor {nullptr};

#ifdef SNIPPETS_DEBUG_CAPS
friend std::string init_info_jit_brgemm_copy_a_emitter(const jit_brgemm_copy_a_emitter *emitter);
#endif
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,14 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
"Jit emitter is called when the shapes are unknown");

const auto& in_subtensor = get_projected_subtensor(expr->get_input_port(0));
const auto K_blk = *++in_subtensor.rbegin();

const auto& src_prc = brgemm_repack->get_src_element_type();
const auto& wei_prc = brgemm_repack->get_input_element_type(0);
const auto wei_N_blk = brgemm_utils::repacking::compute_inner_n_block(wei_prc);
const auto is_transposed = get_is_transposed(expr);
const auto brgemm_type = get_brgemm_type(src_prc, K_blk, is_transposed);
const auto primitive_isa = brgemm_utils::get_primitive_isa(src_prc, with_amx(brgemm_type));
m_with_comp = with_compensations(brgemm_type);
const auto& brgemm_config = brgemm_repack->get_config();
m_with_comp = brgemm_config.need_compensations();

BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, primitive_isa, m_with_comp, is_transposed, wei_N_blk);
BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, brgemm_config.isa(), m_with_comp, is_transposed, wei_N_blk);
m_kernel_executor = kernel_table->register_kernel<BrgemmCopyBKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, with_amx(brgemm_type), with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, with_amx(brgemm_type)));
m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr,
compiled_kernel_cache,
kernel_config);
const auto& brgemm_config = brgemm_node->get_config();
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_config.is_amx(), brgemm_config.need_compensations(), brgemm_config.isa());
m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
// for both static and the 1st dynamic shapes.
Expand All @@ -42,7 +40,7 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
if (with_scratchpad(brgemm_type)) {
if (brgemm_node->get_input_size() == 3) {
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
}
Expand All @@ -51,29 +49,28 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
const auto brgemm = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects BrgemmCPU node");
using brgemm_utils::BRGEMM_TYPE;
if (brgemm->get_type() == BRGEMM_TYPE::STAND_ALONE) {
return {{element::f32, element::f32}};
} else if (brgemm->get_type() == BRGEMM_TYPE::REPACKING_ONLY) {
const auto& config = brgemm->get_config();
if (config.need_compensations()) {
return {{element::i8, element::i8, element::f32}};
}
if (config.is_amx()) {
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
}
if (config.need_copy_b()) {
std::set<std::vector<element::Type>> supported_types = {{element::u8, element::i8},
{element::bf16, element::bf16},
{element::f32, element::f32}};
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2))
supported_types.insert({element::i8, element::i8});
return supported_types;
} else if (brgemm->get_type() == BRGEMM_TYPE::WITH_COMPENSATIONS) {
return {{element::i8, element::i8, element::f32}};
} else if (brgemm->get_type() == BRGEMM_TYPE::WITH_AMX) {
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
}
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
return {{element::f32, element::f32}};
}

void jit_brgemm_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == in.size() + 1 && (out.size() == 1),
"expects 3 inputs if there are compensations/wsp");
OV_CPU_JIT_EMITTER_ASSERT((m_memory_offsets.size() == in.size() + 1) && (out.size() == 1), "incorrect count of registers");
}

void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
OPENVINO_ASSERT(in_ports.size() > 1 && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
out_ports.size() == 1 && check_port(out_ports.back()),
"Incorrect Loop by Brgemm dimension M");
M = current_expanded_loop_info->get_increment();
M = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
input_pds[0]->set_subtensor_dim(1, M);
output_pds[0]->set_subtensor_dim(1, M);
}
Expand All @@ -249,7 +249,7 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
OPENVINO_ASSERT(in_ports.size() == 2 && !in_ports.front().is_incremented && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
out_ports.size() == 1 && check_port(out_ports.back()),
"Incorrect Loop by Brgemm dimension N");
N = current_expanded_loop_info->get_increment();
N = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
input_pds[1]->set_subtensor_dim(0, N);
output_pds[0]->set_subtensor_dim(0, N);
}
Expand All @@ -260,8 +260,9 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
// the very first executed Brgemm Block in Loops which iterate through dimension K (work_amount > 0).
// The first of them will have `beta = 0`, the others - `beta = 1`
float beta = 0;
const auto K_dim = *in0_shape.rbegin();
if (ov::snippets::utils::is_full_dim_value(K)) {
K = *in0_shape.rbegin();
K = K_dim;
} else {
const auto& current_expanded_loop_info = get_loop_info();
const auto& in_ports = current_expanded_loop_info->get_input_ports();
Expand All @@ -272,21 +273,26 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
OPENVINO_ASSERT(in_ports.size() == 2 && in_ports.front().dim_idx == 0 && in_ports.back().dim_idx == 1 &&
out_ports.size() == 1 && !out_ports.front().is_incremented,
"Incorrect Loop by Brgemm dimension K");
K = current_expanded_loop_info->get_increment();
K = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
input_pds[0]->set_subtensor_dim(0, K);
input_pds[1]->set_subtensor_dim(1, K);
if (K > 0)
beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
}

const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = brgemm_utils::repacking::compute_out_leading_dim(N, brgemm_node->get_input_element_type(1));
if (brgemm_node->get_config().need_copy_a()) {
const auto& src_type = brgemm_node->get_input_element_type(0);
K = rnd_up(K, brgemm_utils::compute_vnni_factor(src_type));
LDA = brgemm_utils::repacking::compute_LDA(K, src_type);
}
if (brgemm_node->get_config().need_copy_b())
LDB = brgemm_utils::repacking::compute_LDB(N, brgemm_node->get_input_element_type(1));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}
Expand Down
Loading

0 comments on commit b2e39d3

Please sign in to comment.