feat(meq): use portable_simd to improve performance of meq
rymnc committed Dec 31, 2024
1 parent 9e3a4ea commit f1b162f
Showing 4 changed files with 371 additions and 2 deletions.
6 changes: 6 additions & 0 deletions fuel-vm/Cargo.toml
@@ -15,6 +15,11 @@ name = "execution"
harness = false
required-features = ["std"]

[[bench]]
name = "meq_performance"
harness = false
required-features = ["std"]

[dependencies]
anyhow = { version = "1.0", optional = true }
async-trait = "0.1"
@@ -110,6 +115,7 @@ test-helpers = [
"tai64",
"fuel-crypto/test-helpers",
]
experimental = []

[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] }
95 changes: 95 additions & 0 deletions fuel-vm/benches/meq_performance.rs
@@ -0,0 +1,95 @@
use criterion::{
criterion_group,
criterion_main,
Criterion,
};
use fuel_asm::{
op,
Instruction,
RegId,
};
use fuel_tx::{
Finalizable,
GasCosts,
Script,
TransactionBuilder,
};
use fuel_types::{
Immediate12,
Word,
};
use fuel_vm::{
interpreter::{
Interpreter,
InterpreterParams,
},
prelude::{
IntoChecked,
MemoryInstance,
MemoryStorage,
},
};

/// From: fuel-vm/src/tests/test_helpers.rs
/// Set a register `r` to a Word-sized number value using left-shifts
pub fn set_full_word(r: RegId, v: Word) -> Vec<Instruction> {
let r = r.to_u8();
let mut ops = vec![op::movi(r, 0)];
for byte in v.to_be_bytes() {
ops.push(op::ori(r, r, byte as Immediate12));
ops.push(op::slli(r, r, 8));
}
ops.pop().unwrap(); // Remove last shift
ops
}
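
// Illustration only (not part of the original diff): a hypothetical host-side
// mirror of the ori/slli decomposition emitted by `set_full_word`, showing
// that omitting the final shift (the `ops.pop()` above) reconstructs the
// original word exactly.
#[allow(dead_code)]
fn set_full_word_decomposition_sketch(v: Word) -> Word {
    let bytes = v.to_be_bytes();
    let mut acc: Word = 0;
    for (idx, byte) in bytes.iter().enumerate() {
        acc |= Word::from(*byte);
        if idx + 1 < bytes.len() {
            acc <<= 8; // the last shift is dropped, mirroring ops.pop()
        }
    }
    debug_assert_eq!(acc, v);
    acc
}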

fn meq_performance(c: &mut Criterion) {
let benchmark_matrix = [
1, 10, 100, 1000, 10_000, 50_000, 100_000, 500_000, 1_000_000, 2_000_000,
2_500_000, 5_000_000, 10_000_000, 15_000_000, 20_000_000,
// some exact multiples of 8 to verify alignment perf
8, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072,
262144, 524288, 1048576, 2097152, 4194304, 8388608,
];

for size in benchmark_matrix.iter() {
let mut interpreter = Interpreter::<_, _, Script>::with_storage(
MemoryInstance::new(),
MemoryStorage::default(),
InterpreterParams {
gas_costs: GasCosts::free(),
..Default::default()
},
);

let reg_len = RegId::new_checked(0x13).unwrap();

let mut script = set_full_word(reg_len, *size as Word);
script.extend(vec![
op::cfe(0x13),
op::meq(RegId::WRITABLE, RegId::ZERO, RegId::ZERO, reg_len),
op::jmpb(RegId::ZERO, 0),
]);

let tx_builder_script =
TransactionBuilder::script(script.into_iter().collect(), vec![])
.max_fee_limit(0)
.add_fee_input()
.finalize();
let script = tx_builder_script
.into_checked_basic(Default::default(), &Default::default())
.unwrap();
let script = script.test_into_ready();

interpreter.init_script(script).unwrap();

c.bench_function(&format!("meq_performance_{}", size), |b| {
b.iter(|| {
interpreter.execute().unwrap();
});
});
}
}

criterion_group!(benches, meq_performance);
criterion_main!(benches);
270 changes: 269 additions & 1 deletion fuel-vm/src/interpreter/memory.rs
@@ -1023,6 +1023,272 @@ pub(crate) fn memcopy(
Ok(inc_pc(pc)?)
}

#[cfg(feature = "experimental")]
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
fn slices_equal_neon(a: &[u8], b: &[u8]) -> bool {
use std::arch::aarch64::*;

if a.len() != b.len() {
return false;
}

let len = a.len();
let mut i = 0;
const CHUNK: usize = 96;
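    // 96 bytes = six 16-byte NEON vectors compared per loop iteration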

    // for slices shorter than one chunk, the SIMD setup overhead
    // outweighs the benefit, so use the scalar fallback
if a.len() < CHUNK {
return slices_equal_fallback(a, b);
}

unsafe {
while i + CHUNK <= len {
let mut cmp =
vceqq_u8(vld1q_u8(a.as_ptr().add(i)), vld1q_u8(b.as_ptr().add(i)));

cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 16)),
vld1q_u8(b.as_ptr().add(i + 16)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 32)),
vld1q_u8(b.as_ptr().add(i + 32)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 48)),
vld1q_u8(b.as_ptr().add(i + 48)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 64)),
vld1q_u8(b.as_ptr().add(i + 64)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 80)),
vld1q_u8(b.as_ptr().add(i + 80)),
),
);

if vmaxvq_u8(cmp) != 0xFF {
return false;
}

i += CHUNK;
}

// Scalar comparison for the remainder
a[i..] == b[i..]
}
}

#[cfg(feature = "experimental")]
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
fn slices_equal_avx2(a: &[u8], b: &[u8]) -> bool {
use std::arch::x86_64::*;

if a.len() != b.len() {
return false;
}

let len = a.len();
let mut i = 0;
const CHUNK: usize = 256;
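    // 256 bytes = eight 32-byte AVX2 vectors compared per loop iteration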

    // for slices shorter than one chunk, the SIMD setup overhead
    // outweighs the benefit, so use the scalar fallback
if a.len() < CHUNK {
return slices_equal_fallback(a, b);
}

unsafe {
let mut aggregate_mask_a = -1i32;
let mut aggregate_mask_b = -1i32;
let mut aggregate_mask_c = -1i32;
let mut aggregate_mask_d = -1i32;
let mut aggregate_mask_a_b = -1i32;
let mut aggregate_mask_c_d = -1i32;

while i + CHUNK <= len {
let simd_a1 = _mm256_loadu_si256(a.as_ptr().add(i) as *const _);
let simd_b1 = _mm256_loadu_si256(b.as_ptr().add(i) as *const _);

let simd_a2 = _mm256_loadu_si256(a.as_ptr().add(i + 32) as *const _);
let simd_b2 = _mm256_loadu_si256(b.as_ptr().add(i + 32) as *const _);

let simd_a3 = _mm256_loadu_si256(a.as_ptr().add(i + 64) as *const _);
let simd_b3 = _mm256_loadu_si256(b.as_ptr().add(i + 64) as *const _);

let simd_a4 = _mm256_loadu_si256(a.as_ptr().add(i + 96) as *const _);
let simd_b4 = _mm256_loadu_si256(b.as_ptr().add(i + 96) as *const _);

let simd_a5 = _mm256_loadu_si256(a.as_ptr().add(i + 128) as *const _);
let simd_b5 = _mm256_loadu_si256(b.as_ptr().add(i + 128) as *const _);

let simd_a6 = _mm256_loadu_si256(a.as_ptr().add(i + 160) as *const _);
let simd_b6 = _mm256_loadu_si256(b.as_ptr().add(i + 160) as *const _);

let simd_a7 = _mm256_loadu_si256(a.as_ptr().add(i + 192) as *const _);
let simd_b7 = _mm256_loadu_si256(b.as_ptr().add(i + 192) as *const _);

let simd_a8 = _mm256_loadu_si256(a.as_ptr().add(i + 224) as *const _);
let simd_b8 = _mm256_loadu_si256(b.as_ptr().add(i + 224) as *const _);

let cmp1 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a1, simd_b1));
let cmp2 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a2, simd_b2));
let cmp3 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a3, simd_b3));
let cmp4 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a4, simd_b4));
let cmp5 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a5, simd_b5));
let cmp6 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a6, simd_b6));
let cmp7 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a7, simd_b7));
let cmp8 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a8, simd_b8));

aggregate_mask_a &= cmp1 & cmp2;
aggregate_mask_b &= cmp3 & cmp4;
aggregate_mask_c &= cmp5 & cmp6;
aggregate_mask_d &= cmp7 & cmp8;

aggregate_mask_a_b &= aggregate_mask_a & aggregate_mask_b;
aggregate_mask_c_d &= aggregate_mask_c & aggregate_mask_d;

if aggregate_mask_a_b & aggregate_mask_c_d != -1i32 {
return false;
}

i += CHUNK;
}

a[i..] == b[i..]
}
}

#[cfg(feature = "experimental")]
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
fn slices_equal_avx512(a: &[u8], b: &[u8]) -> bool {
use std::arch::x86_64::*;

if a.len() != b.len() {
return false;
}

let len = a.len();
let mut i = 0;
const CHUNK: usize = 512;
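    // 512 bytes = eight 64-byte AVX-512 vectors compared per loop iteration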

    // for slices shorter than one chunk, the SIMD setup overhead
    // outweighs the benefit, so use the scalar fallback
if a.len() < CHUNK {
return slices_equal_fallback(a, b);
}

unsafe {
while i + CHUNK <= len {
let simd_a1 = _mm512_loadu_si512(a.as_ptr().add(i) as *const _);
let simd_b1 = _mm512_loadu_si512(b.as_ptr().add(i) as *const _);

let simd_a2 = _mm512_loadu_si512(a.as_ptr().add(i + 64) as *const _);
let simd_b2 = _mm512_loadu_si512(b.as_ptr().add(i + 64) as *const _);

let simd_a3 = _mm512_loadu_si512(a.as_ptr().add(i + 128) as *const _);
let simd_b3 = _mm512_loadu_si512(b.as_ptr().add(i + 128) as *const _);

let simd_a4 = _mm512_loadu_si512(a.as_ptr().add(i + 192) as *const _);
let simd_b4 = _mm512_loadu_si512(b.as_ptr().add(i + 192) as *const _);

let simd_a5 = _mm512_loadu_si512(a.as_ptr().add(i + 256) as *const _);
let simd_b5 = _mm512_loadu_si512(b.as_ptr().add(i + 256) as *const _);

let simd_a6 = _mm512_loadu_si512(a.as_ptr().add(i + 320) as *const _);
let simd_b6 = _mm512_loadu_si512(b.as_ptr().add(i + 320) as *const _);

let simd_a7 = _mm512_loadu_si512(a.as_ptr().add(i + 384) as *const _);
let simd_b7 = _mm512_loadu_si512(b.as_ptr().add(i + 384) as *const _);

let simd_a8 = _mm512_loadu_si512(a.as_ptr().add(i + 448) as *const _);
let simd_b8 = _mm512_loadu_si512(b.as_ptr().add(i + 448) as *const _);

let cmp1 = _mm512_cmpeq_epi8_mask(simd_a1, simd_b1);
let cmp2 = _mm512_cmpeq_epi8_mask(simd_a2, simd_b2);
let cmp3 = _mm512_cmpeq_epi8_mask(simd_a3, simd_b3);
let cmp4 = _mm512_cmpeq_epi8_mask(simd_a4, simd_b4);
let cmp5 = _mm512_cmpeq_epi8_mask(simd_a5, simd_b5);
let cmp6 = _mm512_cmpeq_epi8_mask(simd_a6, simd_b6);
let cmp7 = _mm512_cmpeq_epi8_mask(simd_a7, simd_b7);
let cmp8 = _mm512_cmpeq_epi8_mask(simd_a8, simd_b8);

let cmp1_2 = cmp1 & cmp2;
let cmp3_4 = cmp3 & cmp4;
let cmp5_6 = cmp5 & cmp6;
let cmp7_8 = cmp7 & cmp8;

let cmp1_4 = cmp1_2 & cmp3_4;
let cmp5_8 = cmp5_6 & cmp7_8;

let full_cmp = cmp1_4 & cmp5_8;

if full_cmp != u64::MAX {
return false;
}

            i += CHUNK;
}

a[i..] == b[i..]
}
}

#[inline]
fn slices_equal_fallback(a: &[u8], b: &[u8]) -> bool {
a == b
}

#[inline]
fn slice_eq(a: &[u8], b: &[u8]) -> bool {
#[cfg(feature = "experimental")]
{
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
{
return slices_equal_avx512(a, b);
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
return slices_equal_avx2(a, b);
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
return slices_equal_neon(a, b);
}

#[allow(unreachable_code)]
slices_equal_fallback(a, b)
}
#[cfg(not(feature = "experimental"))]
{
slices_equal_fallback(a, b)
}
}

#[test]
fn slice_eq_test() {
let a = [1u8; 20000];
let b = [1u8; 20000];

assert!(slice_eq(&a, &b));
}
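
// Hypothetical companion test (not part of the original diff) exercising the
// unequal, scalar-remainder, and length-mismatch paths of `slice_eq`; the
// sizes and indices below are arbitrary.
#[test]
fn slice_eq_difference_and_remainder_test() {
    // 1013 is not a multiple of any chunk size used above, so every SIMD
    // variant also exercises its scalar remainder path.
    let a = [7u8; 1013];

    let mut early = a;
    early[3] = 0; // difference inside the first chunk
    assert!(!slice_eq(&a, &early));

    let mut tail = a;
    tail[1012] = 0; // difference in the scalar remainder
    assert!(!slice_eq(&a, &tail));

    assert!(!slice_eq(&a[..100], &a[..99])); // length mismatch
    assert!(slice_eq(&a, &a));
}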

pub(crate) fn memeq(
memory: &mut MemoryInstance,
result: &mut Word,
@@ -1031,7 +1297,9 @@ pub(crate) fn memeq(
c: Word,
d: Word,
) -> SimpleResult<()> {
*result = (memory.read(b, d)? == memory.read(c, d)?) as Word;
let range_a = memory.read(b, d)?;
let range_b = memory.read(c, d)?;
*result = slice_eq(range_a, range_b) as Word;
Ok(inc_pc(pc)?)
}
