Skip to content

Commit

Permalink
[Snippets][CPU] Added BrgemmCopyA op
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova authored and chenhu-wang committed Oct 1, 2024
1 parent c7ab3a0 commit 305ff09
Show file tree
Hide file tree
Showing 30 changed files with 1,020 additions and 322 deletions.
1 change: 1 addition & 0 deletions src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();

manager.register_positioned_passes(backend_passes);
manager.register_pass<ov::pass::Serialize>("body.xml", "body.bin");
manager.run_passes(body_ptr());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "snippets/snippets_isa.hpp"
#include "emitters/snippets/cpu_runtime_configurator.hpp"

#include "emitters/snippets/x64/jit_brgemm_copy_a_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_emitter.hpp"
#include "emitters/snippets/x64/jit_memory_emitters.hpp"
Expand All @@ -23,6 +24,7 @@
#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
Expand Down Expand Up @@ -243,6 +245,9 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyA::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_a_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
Expand Down Expand Up @@ -356,6 +361,7 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
std::dynamic_pointer_cast<intel_cpu::tpp::op::Scalar>(op) ||
#endif
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyA>(op)||
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyB>(op))
return ov::snippets::RegType::gpr;
else if (
Expand All @@ -368,7 +374,8 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons

bool intel_cpu::CPUGenerator::uses_precompiled_kernel(const std::shared_ptr<snippets::Emitter>& e) const {
bool need = std::dynamic_pointer_cast<intel_cpu::jit_brgemm_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e);
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_a_emitter>(e);
#ifdef SNIPPETS_DEBUG_CAPS
const auto cpu_target_machine = std::dynamic_pointer_cast<intel_cpu::CPUTargetMachine>(target);
need = need || (cpu_target_machine && cpu_target_machine->debug_config.enable_segfault_detector) ||
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_brgemm_copy_a_emitter.hpp"

#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"

#include "snippets/utils/utils.hpp"

#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"


using namespace Xbyak;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;
using namespace ov::intel_cpu::brgemm_utils;
using namespace ov::snippets::utils;

namespace ov {
namespace intel_cpu {

// Emitter that delegates repacking of the Brgemm input A to a precompiled
// BrgemmCopyAKernelExecutor kernel registered in the shared kernel table.
//   h / isa                - target JIT generator and ISA
//   expr                   - lowered expression holding the BrgemmCopyA node
//   kernel_table           - executor table the kernel is registered into
//   compiled_kernel_cache  - cache shared between compiled kernels
jit_brgemm_copy_a_emitter::jit_brgemm_copy_a_emitter(jit_generator* h, cpu_isa_t isa,
                                                     const ov::snippets::lowered::ExpressionPtr& expr,
                                                     const snippets::KernelExecutorTablePtr& kernel_table,
                                                     const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
    : jit_emitter(h, isa) {
    in_out_type_ = emitter_in_out_map::gpr_to_gpr;

    const auto copy_a_node = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyA>(expr->get_node());
    OV_CPU_JIT_EMITTER_ASSERT(copy_a_node, "expects BrgemmCopyA node");

    // Note: even if the BrgemmCopyA node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
    // are performed before the BrgemmCopyAKernelExecutor registration. So we have to trigger update() manually
    // for both static and the 1st dynamic shapes.
    OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
                              "Jit emitter is called when the shapes are unknown");

    const auto& node_config = copy_a_node->get_config();
    BrgemmCopyAKernelConfig executor_config(copy_a_node->get_input_element_type(0), node_config.isa());
    m_kernel_executor =
        kernel_table->register_kernel<BrgemmCopyAKernelExecutor>(expr, compiled_kernel_cache, executor_config);

    // Byte offsets of the source/destination buffers and the buffer cluster ids
    // used to resolve dynamic offsets at runtime
    m_memory_offsets = {copy_a_node->get_offset_in(), copy_a_node->get_offset_out()};
    m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
                    utils::get_buffer_cluster_id(expr->get_output_port(0))};
}

// Sanity-check of the register indices passed to emit_impl:
// exactly one input gpr (src) and one output gpr (tr_src) are expected.
void jit_brgemm_copy_a_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
    const bool valid_io_count = (in.size() == 1) && (out.size() == 1);
    OV_CPU_JIT_EMITTER_ASSERT(valid_io_count, "expects 1 input and 1 output");
}

// Generates the binary call to BrgemmCopyAKernelExecutor::execute.
// `in[0]`/`out[0]` are gpr indices holding the source and repacked-destination pointers.
// The call ABI is: abi_param1 = executor instance, abi_param2 = pointer to call_args on the stack.
void jit_brgemm_copy_a_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
    validate_arguments(in, out);

    std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};

    // Preserve ABI registers around the external call
    EmitABIRegSpills spill(h);
    spill.preamble();

    // rbp keeps the address of the static executor entry point until the call below
    h->mov(h->rbp, reinterpret_cast<uint64_t>(BrgemmCopyAKernelExecutor::execute));
    auto reserved_stack_size = sizeof(BrgemmCopyAKernelExecutor::call_args);
    // Reserve memory on the stack
    h->sub(h->rsp, reserved_stack_size);

    // Scratch gpr that must not alias any of the memory-pointer registers
    Xbyak::Reg64 aux_reg = ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs);

    // Offsets of the src/tr_src fields inside call_args (same order as mem_ptrs_idxs)
    const std::vector<size_t> args_offsets {GET_OFF_BRGEMM_COPY_A_ARGS(src), GET_OFF_BRGEMM_COPY_A_ARGS(tr_src)};
    const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs);
    for (size_t i = 0; i < mem_ptrs.size(); i++) {
        // Dynamic offsets are fetched at runtime from the buffer_offsets table by cluster id;
        // static offsets are baked directly into the generated code
        if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]))
            utils::push_ptr_with_runtime_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg,
                                                         GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t));
        else
            utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg, m_memory_offsets[i]);
    }

    // 1st ABI arg: executor instance; 2nd ABI arg: pointer to the call_args just written (current rsp)
    h->mov(abi_param1, reinterpret_cast<uintptr_t>(m_kernel_executor.get()));
    h->mov(abi_param2, h->rsp);

    // Align rsp per ABI requirements before the call, then undo the alignment
    spill.rsp_align();
    h->call(h->rbp);
    spill.rsp_restore();

    // Release the stack storage reserved for call_args
    h->add(h->rsp, reserved_stack_size);

    spill.postamble();
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "emitters/plugin/x64/jit_emitter.hpp"

#include "kernel_executors/brgemm_copy_a.hpp"


namespace ov {
namespace intel_cpu {

// JIT emitter for the BrgemmCopyA op: repacks the Brgemm input A via a
// precompiled BrgemmCopyAKernelExecutor instead of emitting the copy inline.
class jit_brgemm_copy_a_emitter : public jit_emitter {
public:
    jit_brgemm_copy_a_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
                              const ov::snippets::lowered::ExpressionPtr& expr,
                              const snippets::KernelExecutorTablePtr& kernel_table,
                              const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

    // Single input: pointer to the source (non-repacked) A matrix
    size_t get_inputs_num() const override {return 1;}
    // Input precisions this emitter accepts
    // (presumably the ones for which A repacking is required — confirm against BrgemmCopyA op)
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr) {
        return {{element::i8}, {element::u8}, {element::bf16}};
    }

private:
    void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
    void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

    // Byte offsets of the src/dst buffers; a dynamic-value sentinel means
    // the offset is resolved at runtime via m_buffer_ids
    std::vector<size_t> m_memory_offsets{};
    // Buffer cluster ids used to index the runtime buffer_offsets table
    std::vector<size_t> m_buffer_ids{};
    // Precompiled repacking kernel registered in the shared kernel executor table
    std::shared_ptr<BrgemmCopyAKernelExecutor> m_kernel_executor {nullptr};

#ifdef SNIPPETS_DEBUG_CAPS
    friend std::string init_info_jit_brgemm_copy_a_emitter(const jit_brgemm_copy_a_emitter *emitter);
#endif
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,14 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
"Jit emitter is called when the shapes are unknown");

const auto& in_subtensor = get_projected_subtensor(expr->get_input_port(0));
const auto K_blk = *++in_subtensor.rbegin();

const auto& src_prc = brgemm_repack->get_src_element_type();
const auto& wei_prc = brgemm_repack->get_input_element_type(0);
const auto wei_N_blk = brgemm_utils::repacking::compute_inner_n_block(wei_prc);
const auto is_transposed = get_is_transposed(expr);
const auto brgemm_type = get_brgemm_type(src_prc, K_blk, is_transposed);
const auto primitive_isa = brgemm_utils::get_primitive_isa(src_prc, with_amx(brgemm_type));
m_with_comp = with_compensations(brgemm_type);
const auto& brgemm_config = brgemm_repack->get_config();
m_with_comp = brgemm_config.need_compensations();

BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, primitive_isa, m_with_comp, is_transposed, wei_N_blk);
BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, brgemm_config.isa(), m_with_comp, is_transposed, wei_N_blk);
m_kernel_executor = kernel_table->register_kernel<BrgemmCopyBKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
const auto brgemm_type = brgemm_node->get_type();
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, with_amx(brgemm_type), with_compensations(brgemm_type),
brgemm_utils::get_primitive_isa(brg0Prc, with_amx(brgemm_type)));
m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr,
compiled_kernel_cache,
kernel_config);
const auto& brgemm_config = brgemm_node->get_config();
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_config.is_amx(), brgemm_config.need_compensations(), brgemm_config.isa());
m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

// Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually
// for both static and the 1st dynamic shapes.
Expand All @@ -42,7 +40,7 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), utils::get_buffer_cluster_id(expr->get_input_port(1)),
utils::get_buffer_cluster_id(expr->get_output_port(0))};
if (with_scratchpad(brgemm_type)) {
if (brgemm_node->get_input_size() == 3) {
m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2)));
}
Expand All @@ -51,29 +49,28 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
const auto brgemm = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects BrgemmCPU node");
using brgemm_utils::BRGEMM_TYPE;
if (brgemm->get_type() == BRGEMM_TYPE::STAND_ALONE) {
return {{element::f32, element::f32}};
} else if (brgemm->get_type() == BRGEMM_TYPE::REPACKING_ONLY) {
const auto& config = brgemm->get_config();
if (config.need_compensations()) {
return {{element::i8, element::i8, element::f32}};
}
if (config.is_amx()) {
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
}
if (config.need_copy_b()) {
std::set<std::vector<element::Type>> supported_types = {{element::u8, element::i8},
{element::bf16, element::bf16},
{element::f32, element::f32}};
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2))
supported_types.insert({element::i8, element::i8});
return supported_types;
} else if (brgemm->get_type() == BRGEMM_TYPE::WITH_COMPENSATIONS) {
return {{element::i8, element::i8, element::f32}};
} else if (brgemm->get_type() == BRGEMM_TYPE::WITH_AMX) {
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
}
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
return {{element::f32, element::f32}};
}

void jit_brgemm_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == in.size() + 1 && (out.size() == 1),
"expects 3 inputs if there are compensations/wsp");
OV_CPU_JIT_EMITTER_ASSERT((m_memory_offsets.size() == in.size() + 1) && (out.size() == 1), "incorrect count of registers");
}

void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
OPENVINO_ASSERT(in_ports.size() > 1 && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
out_ports.size() == 1 && check_port(out_ports.back()),
"Incorrect Loop by Brgemm dimension M");
M = current_expanded_loop_info->get_increment();
M = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
input_pds[0]->set_subtensor_dim(1, M);
output_pds[0]->set_subtensor_dim(1, M);
}
Expand All @@ -249,7 +249,7 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
OPENVINO_ASSERT(in_ports.size() == 2 && !in_ports.front().is_incremented && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
out_ports.size() == 1 && check_port(out_ports.back()),
"Incorrect Loop by Brgemm dimension N");
N = current_expanded_loop_info->get_increment();
N = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
input_pds[1]->set_subtensor_dim(0, N);
output_pds[0]->set_subtensor_dim(0, N);
}
Expand All @@ -260,8 +260,9 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
// the most first executed Brgemm Block in Loops which iterate through dimension K (work_amount > 0).
// First of them will have `beta = 0`, other - `beta = 1`
float beta = 0;
const auto K_dim = *in0_shape.rbegin();
if (ov::snippets::utils::is_full_dim_value(K)) {
K = *in0_shape.rbegin();
K = K_dim;
} else {
const auto& current_expanded_loop_info = get_loop_info();
const auto& in_ports = current_expanded_loop_info->get_input_ports();
Expand All @@ -272,21 +273,26 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
OPENVINO_ASSERT(in_ports.size() == 2 && in_ports.front().dim_idx == 0 && in_ports.back().dim_idx == 1 &&
out_ports.size() == 1 && !out_ports.front().is_incremented,
"Incorrect Loop by Brgemm dimension K");
K = current_expanded_loop_info->get_increment();
K = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0;
input_pds[0]->set_subtensor_dim(0, K);
input_pds[1]->set_subtensor_dim(1, K);
if (K > 0)
beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
}

const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = brgemm_utils::repacking::compute_out_leading_dim(N, brgemm_node->get_input_element_type(1));
if (brgemm_node->get_config().need_copy_a()) {
const auto& src_type = brgemm_node->get_input_element_type(0);
K = rnd_up(K, brgemm_utils::compute_vnni_factor(src_type));
LDA = brgemm_utils::repacking::compute_LDA(K, src_type);
}
if (brgemm_node->get_config().need_copy_b())
LDB = brgemm_utils::repacking::compute_LDB(N, brgemm_node->get_input_element_type(1));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}
Expand Down
Loading

0 comments on commit 305ff09

Please sign in to comment.