From 0500fda4834ef2d771ee9ccd7995037aeb0bdb97 Mon Sep 17 00:00:00 2001 From: Robin Freyler Date: Wed, 6 Nov 2024 09:59:37 +0100 Subject: [PATCH] Fix double negation translation bug (#1293) * add tests for invalid double negation of compare instrs * fix invalid double negation bug --- .../wasmi/src/engine/translator/comparator.rs | 22 +-- .../src/engine/translator/instr_encoder.rs | 4 +- .../src/engine/translator/tests/op/i32_eqz.rs | 134 ++++++++++++++++-- 3 files changed, 141 insertions(+), 19 deletions(-) diff --git a/crates/wasmi/src/engine/translator/comparator.rs b/crates/wasmi/src/engine/translator/comparator.rs index d9b4d588c6..4fb289333c 100644 --- a/crates/wasmi/src/engine/translator/comparator.rs +++ b/crates/wasmi/src/engine/translator/comparator.rs @@ -5,11 +5,17 @@ use crate::{ }; pub trait NegateCmpInstr: Sized { - fn negate_cmp_instr(&self) -> Option; + /// Negates the compare (`cmp`) [`Instruction`]. + /// + /// If the user of the fused comparison [`Instruction`] is going to be a + /// conditional branch [`Instruction`] the `is_branch` parameter is set to + /// `true`. This allows for more optimizations since the result does not need + /// to be bit-accurate. + fn negate_cmp_instr(&self, is_branch: bool) -> Option; } impl NegateCmpInstr for Instruction { - fn negate_cmp_instr(&self) -> Option { + fn negate_cmp_instr(&self, is_branch: bool) -> Option { use Instruction as I; #[rustfmt::skip] let negated = match *self { @@ -34,15 +40,15 @@ impl NegateCmpInstr for Instruction { I::I32And { result, lhs, rhs } => I::i32_and_eqz(result, lhs, rhs), I::I32Or { result, lhs, rhs } => I::i32_or_eqz(result, lhs, rhs), I::I32Xor { result, lhs, rhs } => I::i32_xor_eqz(result, lhs, rhs), - I::I32AndEqz { result, lhs, rhs } => I::i32_and(result, lhs, rhs), - I::I32OrEqz { result, lhs, rhs } => I::i32_or(result, lhs, rhs), - I::I32XorEqz { result, lhs, rhs } => I::i32_xor(result, lhs, rhs), + I::I32AndEqz { result, lhs, rhs } if is_branch => I::i32_and(result, lhs, rhs), + I::I32OrEqz { result, lhs, rhs } if is_branch => I::i32_or(result, lhs, rhs), + I::I32XorEqz { result, lhs, rhs } if is_branch => I::i32_xor(result, lhs, rhs), I::I32AndImm16 { result, lhs, rhs } => I::i32_and_eqz_imm16(result, lhs, rhs), I::I32OrImm16 { result, lhs, rhs } => I::i32_or_eqz_imm16(result, lhs, rhs), I::I32XorImm16 { result, lhs, rhs } => I::i32_xor_eqz_imm16(result, lhs, rhs), - I::I32AndEqzImm16 { result, lhs, rhs } => I::i32_and_imm16(result, lhs, rhs), - I::I32OrEqzImm16 { result, lhs, rhs } => I::i32_or_imm16(result, lhs, rhs), - I::I32XorEqzImm16 { result, lhs, rhs } => I::i32_xor_imm16(result, lhs, rhs), + I::I32AndEqzImm16 { result, lhs, rhs } if is_branch => I::i32_and_imm16(result, lhs, rhs), + I::I32OrEqzImm16 { result, lhs, rhs } if is_branch => I::i32_or_imm16(result, lhs, rhs), + I::I32XorEqzImm16 { result, lhs, rhs } if is_branch => I::i32_xor_imm16(result, lhs, rhs), // i64 I::I64Eq { result, lhs, rhs } => I::i64_ne(result, lhs, rhs), I::I64Ne { result, lhs, rhs } => I::i64_eq(result, lhs, rhs), diff --git a/crates/wasmi/src/engine/translator/instr_encoder.rs b/crates/wasmi/src/engine/translator/instr_encoder.rs index 6aea272193..313574024b 100644 --- a/crates/wasmi/src/engine/translator/instr_encoder.rs +++ b/crates/wasmi/src/engine/translator/instr_encoder.rs @@ -945,7 +945,7 @@ impl InstrEncoder { // thus indicating that we cannot fuse the instructions. return false; } - let Some(negated) = last_instruction.negate_cmp_instr() else { + let Some(negated) = last_instruction.negate_cmp_instr(false) else { // Last instruction is unable to be negated. return false; }; @@ -1085,7 +1085,7 @@ impl InstrEncoder { return Ok(None); } let last_instruction = match negate { - true => match last_instruction.negate_cmp_instr() { + true => match last_instruction.negate_cmp_instr(true) { Some(negated) => negated, None => return Ok(None), }, diff --git a/crates/wasmi/src/engine/translator/tests/op/i32_eqz.rs b/crates/wasmi/src/engine/translator/tests/op/i32_eqz.rs index be3877d7fe..e1afa7a9cd 100644 --- a/crates/wasmi/src/engine/translator/tests/op/i32_eqz.rs +++ b/crates/wasmi/src/engine/translator/tests/op/i32_eqz.rs @@ -229,9 +229,6 @@ fn binop_i32_eqz_double() { test_for!( ("i32", "eq", Instruction::i32_eq), ("i32", "ne", Instruction::i32_ne), - ("i32", "and", Instruction::i32_and), - ("i32", "or", Instruction::i32_or), - ("i32", "xor", Instruction::i32_xor), ("i32", "lt_s", Instruction::i32_lt_s), ("i32", "lt_u", Instruction::i32_lt_u), ("i32", "le_s", Instruction::i32_le_s), @@ -297,9 +294,6 @@ fn binop_imm_i32_eqz_rhs_double() { test_for_imm!( (i32, "eq", Instruction::i32_eq_imm16), (i32, "ne", Instruction::i32_ne_imm16), - (i32, "and", Instruction::i32_and_imm16), - (i32, "or", Instruction::i32_or_imm16), - (i32, "xor", Instruction::i32_xor_imm16), (i32, "lt_s", Instruction::i32_lt_s_imm16_rhs), (u32, "lt_u", Instruction::i32_lt_u_imm16_rhs), (i32, "le_s", Instruction::i32_le_s_imm16_rhs), @@ -361,9 +355,6 @@ fn binop_imm_i32_eqz_lhs_double() { test_for_imm!( (i32, "eq", Instruction::i32_eq_imm16), (i32, "ne", Instruction::i32_ne_imm16), - (i32, "and", Instruction::i32_and_imm16), - (i32, "or", Instruction::i32_or_imm16), - (i32, "xor", Instruction::i32_xor_imm16), (i32, "lt_s", swap_ops!(Instruction::i32_lt_s_imm16_lhs)), (u32, "lt_u", swap_ops!(Instruction::i32_lt_u_imm16_lhs)), (i32, "le_s", swap_ops!(Instruction::i32_le_s_imm16_lhs)), @@ -382,3 +373,128 @@ fn binop_imm_i32_eqz_lhs_double() { (u64, "ge_u", Instruction::i64_le_u_imm16_rhs), ); } + +#[test] +#[cfg_attr(miri, ignore)] +fn binop_i32_eqz_double_invalid() { + fn test_for( + input_ty: &str, + op: &str, + expect_instr: fn(result: Reg, lhs: Reg, rhs: Reg) -> Instruction, + ) { + let wasm = &format!( + r" + (module + (func (param {input_ty} {input_ty}) (result i32) + (local.get 0) + (local.get 1) + ({input_ty}.{op}) + (i32.eqz) + (i32.eqz) + ) + )", + ); + TranslationTest::from_wat(wasm) + .expect_func_instrs([ + expect_instr(Reg::from(2), Reg::from(0), Reg::from(1)), + Instruction::i32_eq_imm16(Reg::from(2), Reg::from(2), 0), + Instruction::return_reg(2), + ]) + .run() + } + test_for!( + ("i32", "and", Instruction::i32_and_eqz), + ("i32", "or", Instruction::i32_or_eqz), + ("i32", "xor", Instruction::i32_xor_eqz), + ); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn binop_imm_i32_eqz_rhs_double_invalid() { + fn test_for( + op: &str, + value: T, + expect_instr: fn(result: Reg, lhs: Reg, rhs: Const16) -> Instruction, + ) where + T: Display + WasmTy, + Const16: TryFrom, + DisplayWasm: Display, + { + let input_ty = T::NAME; + let display_value = DisplayWasm::from(value); + let wasm = &format!( + r" + (module + (func (param {input_ty} {input_ty}) (result i32) + (local.get 0) + ({input_ty}.const {display_value}) + ({input_ty}.{op}) + (i32.eqz) + (i32.eqz) + ) + )", + ); + TranslationTest::from_wat(wasm) + .expect_func_instrs([ + expect_instr( + Reg::from(2), + Reg::from(0), + Const16::try_from(value).ok().unwrap(), + ), + Instruction::i32_eq_imm16(Reg::from(2), Reg::from(2), 0), + Instruction::return_reg(2), + ]) + .run() + } + test_for_imm!( + (i32, "and", Instruction::i32_and_eqz_imm16), + (i32, "or", Instruction::i32_or_eqz_imm16), + (i32, "xor", Instruction::i32_xor_eqz_imm16), + ); +} + +#[test] +#[cfg_attr(miri, ignore)] +fn binop_imm_i32_eqz_lhs_double_invalid() { + fn test_for( + op: &str, + value: T, + expect_instr: fn(result: Reg, lhs: Reg, rhs: Const16) -> Instruction, + ) where + T: Display + WasmTy, + Const16: TryFrom, + DisplayWasm: Display, + { + let input_ty = T::NAME; + let display_value = DisplayWasm::from(value); + let wasm = &format!( + r" + (module + (func (param {input_ty} {input_ty}) (result i32) + ({input_ty}.const {display_value}) + (local.get 0) + ({input_ty}.{op}) + (i32.eqz) + (i32.eqz) + ) + )", + ); + TranslationTest::from_wat(wasm) + .expect_func_instrs([ + expect_instr( + Reg::from(2), + Reg::from(0), + Const16::try_from(value).ok().unwrap(), + ), + Instruction::i32_eq_imm16(Reg::from(2), Reg::from(2), 0), + Instruction::return_reg(2), + ]) + .run() + } + test_for_imm!( + (i32, "and", Instruction::i32_and_eqz_imm16), + (i32, "or", Instruction::i32_or_eqz_imm16), + (i32, "xor", Instruction::i32_xor_eqz_imm16), + ); +}