Skip to content

Commit

Permalink
[CPU] [Snippets] [ARM] Support precision f16 for Snippets on ARM and …
Browse files Browse the repository at this point in the history
…separate Conversion from Load/Store emitters (#26483)

### Details:
- *Separate Conversion functionality from Load/Store emitters on ARM.
The idea is to make each emitter as simple as possible, so Load/Store
will focus on memory instructions and will not include any conversion
instructions on ARM. Necessary Convert emitter will be automatically
inserted to the subgraph to perform precision conversion.*
 - *Support inference precision f16 for Snippets on ARM.*

### Tickets:
 - *[150430](https://jira.devtools.intel.com/browse/CVS-150430)*
 - *[141292](https://jira.devtools.intel.com/browse/CVS-141292)*
  • Loading branch information
xuchen-intel authored Oct 3, 2024
1 parent 7bf0444 commit a679279
Show file tree
Hide file tree
Showing 21 changed files with 45 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,9 @@ void jit_convert_emitter::jit_convert_process(const TReg &src, const TReg &dst,
}

jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, node->get_input_element_type(0), node->get_output_element_type(0), exec_prc) {
}

jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa,
ov::element::Type input_prc,
ov::element::Type output_prc,
ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
input_type = input_prc;
output_type = output_prc;
input_type = node->get_input_element_type(0);
output_type = node->get_output_element_type(0);
}

void jit_convert_emitter::validate_types() const {
Expand All @@ -221,13 +214,6 @@ jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *ho
: jit_convert_emitter(host, host_isa, node, exec_prc) {
}

jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa,
ov::element::Type input_prc,
ov::element::Type output_prc,
ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {
}

void jit_convert_truncation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
validate_types();
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
Expand All @@ -250,13 +236,6 @@ jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *ho
: jit_convert_emitter(host, host_isa, node, exec_prc) {
}

jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa,
ov::element::Type input_prc,
ov::element::Type output_prc,
ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, input_prc, output_prc, exec_prc) {
}

void jit_convert_saturation_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
validate_types();
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ class jit_convert_emitter : public jit_emitter {
public:
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
jit_convert_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

size_t get_inputs_count() const override;

Expand Down Expand Up @@ -60,8 +58,6 @@ class jit_convert_truncation_emitter : public jit_convert_emitter {
public:
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
jit_convert_truncation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

private:
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
Expand All @@ -77,8 +73,6 @@ class jit_convert_saturation_emitter : public jit_convert_emitter {
public:
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n, ov::element::Type exec_prc = ov::element::f32);
jit_convert_saturation_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type input_prc, ov::element::Type output_prc, ov::element::Type exec_prc = ov::element::f32);

private:
void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t;
jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type src_prc, ov::element::Type dst_prc, int load_num, int byte_offset,
ov::element::Type exec_prc, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), byte_offset_(byte_offset),
src_prc_(src_prc), dst_prc_(dst_prc) {
if (src_prc_ != dst_prc_) {
convert_truncation_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc));
}
: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), byte_offset_(byte_offset), prc_(src_prc) {
OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Unsupported precision pair.");
}

void jit_load_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
Expand Down Expand Up @@ -133,30 +130,24 @@ void jit_load_emitter::load_byte(const std::vector<size_t> &in_idxs, const std::

template <cpu_isa_t isa>
void jit_load_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
bool is_supported_precision = one_of(src_prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
(src_prc_ == dst_prc_ || one_of(dst_prc_, ov::element::f32, ov::element::i32));
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");
OV_CPU_JIT_EMITTER_ASSERT(one_of(prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported precision.");
OV_CPU_JIT_EMITTER_ASSERT(load_num_ <= 4, "Unexpected number of elements to load.");

switch (src_prc_) {
switch (prc_) {
case ov::element::f32:
case ov::element::i32:
load_qbyte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
load_qbyte<isa>(in_idxs, out_idxs);
break;
case ov::element::f16:
load_dbyte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
load_dbyte<isa>(in_idxs, out_idxs);
break;
case ov::element::i8:
case ov::element::u8:
load_byte<isa>(in_idxs, src_prc_ == dst_prc_ ? out_idxs : aux_vec_idxs);
load_byte<isa>(in_idxs, out_idxs);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported input type: ", src_prc_.get_type_name());
}

if (src_prc_ != dst_prc_) {
OPENVINO_ASSERT(convert_truncation_emitter, "Invalid convert_truncation_emitter.");
convert_truncation_emitter->emit_code(aux_vec_idxs, out_idxs);
OV_CPU_JIT_EMITTER_THROW("Unsupported precision: ", prc_.get_type_name());
}
}

Expand All @@ -167,27 +158,11 @@ size_t jit_load_emitter::get_aux_gprs_count() const {
return 0;
}

size_t jit_load_emitter::get_aux_vecs_count() const {
if (src_prc_ != dst_prc_)
return 1;

return 0;
}

jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::aarch64::jit_generator *host, dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, int byte_offset,
arithmetic_mode mode, ov::element::Type exec_prc, emitter_in_out_map in_out_type)
: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset),
src_prc_(src_prc), dst_prc_(dst_prc) {
if (src_prc_ != dst_prc_) {
if (mode == arithmetic_mode::truncation) {
convert_emitter.reset(new jit_convert_truncation_emitter(host, host_isa, src_prc, dst_prc, exec_prc));
} else if (mode == arithmetic_mode::saturation) {
convert_emitter.reset(new jit_convert_saturation_emitter(host, host_isa, src_prc, dst_prc, exec_prc));
} else {
OV_CPU_JIT_EMITTER_THROW("Unsupported Convert emitter.");
}
}
: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), byte_offset_(byte_offset), prc_(dst_prc) {
OV_CPU_JIT_EMITTER_ASSERT(src_prc == dst_prc, "Unsupported precision pair.");
}

void jit_store_emitter::emit_impl(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
Expand Down Expand Up @@ -299,30 +274,24 @@ void jit_store_emitter::store_byte(const std::vector<size_t> &in_idxs, const std

template <cpu_isa_t isa>
void jit_store_emitter::emit_isa(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const {
bool is_supported_precision = one_of(dst_prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
(src_prc_ == dst_prc_ || one_of(src_prc_, ov::element::f32, ov::element::i32));
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");
OV_CPU_JIT_EMITTER_ASSERT(one_of(prc_, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
"Unsupported precision.");
OV_CPU_JIT_EMITTER_ASSERT(store_num_ <= 4, "Unexpected number of elements to store.");

if (src_prc_ != dst_prc_) {
OPENVINO_ASSERT(convert_emitter, "Invalid convert_emitter.");
convert_emitter->emit_code(in_idxs, aux_vec_idxs);
}

switch (dst_prc_) {
switch (prc_) {
case ov::element::f32:
case ov::element::i32:
store_qbyte<isa>(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs);
store_qbyte<isa>(in_idxs, out_idxs);
break;
case ov::element::f16:
store_dbyte<isa>(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs);
store_dbyte<isa>(in_idxs, out_idxs);
break;
case ov::element::i8:
case ov::element::u8:
store_byte<isa>(src_prc_ == dst_prc_ ? in_idxs : aux_vec_idxs, out_idxs);
store_byte<isa>(in_idxs, out_idxs);
break;
default:
OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", dst_prc_.get_type_name());
OV_CPU_JIT_EMITTER_THROW("Unsupported precision: ", prc_.get_type_name());
}
}

Expand All @@ -333,13 +302,6 @@ size_t jit_store_emitter::get_aux_gprs_count() const {
return 0;
}

size_t jit_store_emitter::get_aux_vecs_count() const {
if (src_prc_ != dst_prc_)
return 1;

return 0;
}

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "jit_emitter.hpp"
#include "cpu/aarch64/jit_generator.hpp"
#include "emitters/plugin/aarch64/jit_conversion_emitters.hpp"

namespace ov {
namespace intel_cpu {
Expand Down Expand Up @@ -38,15 +37,11 @@ class jit_load_emitter : public jit_emitter {
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void load_byte(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const;
size_t get_aux_gprs_count() const override;
size_t get_aux_vecs_count() const override;

std::unique_ptr<jit_convert_truncation_emitter> convert_truncation_emitter = nullptr;

std::string name_;
int load_num_; // the element number to load
int byte_offset_;
ov::element::Type src_prc_;
ov::element::Type dst_prc_;
ov::element::Type prc_;
};

class jit_store_emitter : public jit_emitter {
Expand All @@ -69,15 +64,11 @@ class jit_store_emitter : public jit_emitter {
template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void store_byte(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs) const;
size_t get_aux_gprs_count() const override;
size_t get_aux_vecs_count() const override;

std::unique_ptr<jit_convert_emitter> convert_emitter = nullptr;

std::string name_;
int store_num_; // the element number to store
int byte_offset_;
ov::element::Type src_prc_;
ov::element::Type dst_prc_;
ov::element::Type prc_;
};

} // namespace aarch64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
#include "emitters/snippets/aarch64/jit_memory_emitters.hpp"
#include "emitters/snippets/aarch64/jit_fill_emitter.hpp"

#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"

Expand Down Expand Up @@ -117,11 +115,7 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa)
// memory access
jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter);
jitters[snippets::op::BroadcastLoad::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_broadcast_emitter);
jitters[intel_cpu::LoadConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter);
jitters[intel_cpu::LoadConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_load_memory_emitter);
jitters[snippets::op::Store::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_store_memory_emitter);
jitters[intel_cpu::StoreConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_store_memory_emitter);
jitters[intel_cpu::StoreConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_store_memory_emitter);

// ternary
jitters[intel_cpu::FusedMulAdd::get_type_info_static()] = CREATE_CPU_EMITTER(jit_mul_add_emitter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
//

#include "jit_memory_emitters.hpp"
#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "emitters/utils.hpp"

using namespace Xbyak_aarch64;
Expand All @@ -25,7 +23,7 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex

jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) {
bool is_supported_precision = one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
(src_prc == dst_prc || one_of(dst_prc, ov::element::f32, ov::element::i32));
src_prc == dst_prc;
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");

const auto load = std::dynamic_pointer_cast<snippets::op::Load>(expr->get_node());
Expand Down Expand Up @@ -90,23 +88,15 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector<size_t> &in, const s

jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_memory_emitter(h, isa, expr) {
bool is_supported_precision = one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
(src_prc == dst_prc || one_of(src_prc, ov::element::f32, ov::element::i32));
src_prc == dst_prc;
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");

const auto store = ov::as_type_ptr<snippets::op::Store>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(store != nullptr, "Expects Store expression");
count = store->get_count();
byte_offset = store->get_offset();
in_out_type_ = emitter_in_out_map::vec_to_gpr;
if (ov::is_type<ov::intel_cpu::StoreConvertTruncation>(expr->get_node())) {
store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation));
} else if (ov::is_type<ov::intel_cpu::StoreConvertSaturation>(expr->get_node())) {
store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::saturation));
} else if (ov::is_type<ov::snippets::op::Store>(expr->get_node())) {
store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset));
} else {
OV_CPU_JIT_EMITTER_THROW("Expects Store node");
}
store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset));
}

void jit_store_memory_emitter::emit_impl(const std::vector<size_t>& in,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
#include "emitters/plugin/x64/jit_dnnl_ext_emitters.hpp"
#include "emitters/plugin/x64/jit_conversion_emitters.hpp"

#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/pass/lowered/fuse_load_store_and_convert.hpp"
#include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp"

#include <openvino/opsets/opset5.hpp>
#include "emitters/snippets/cpu_kernel_executor_table.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#include "jit_memory_emitters.hpp"

#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"
#include "snippets/op/buffer.hpp"


Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
#include "transformations/cpu_opset/x64/op/qkv_proj.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/x64/op/load_convert.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/x64/op/store_convert.hpp"

namespace {

Expand Down
Loading

0 comments on commit a679279

Please sign in to comment.