Skip to content

Commit

Permalink
Fix double negation translation bug (#1293)
Browse files Browse the repository at this point in the history
* add tests for invalid double negation of compare instrs

* fix invalid double negation bug
  • Loading branch information
Robbepop authored Nov 6, 2024
1 parent 1f8ff6d commit 0500fda
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 19 deletions.
22 changes: 14 additions & 8 deletions crates/wasmi/src/engine/translator/comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@ use crate::{
};

pub trait NegateCmpInstr: Sized {
fn negate_cmp_instr(&self) -> Option<Self>;
/// Negates the compare (`cmp`) [`Instruction`].
///
/// If the user of the fused comparison [`Instruction`] is going to be a
/// conditional branch [`Instruction`] the `is_branch` parameter is set to
/// `true`. This allows for more optimizations since the result does not need
/// to be bit-accurate.
fn negate_cmp_instr(&self, is_branch: bool) -> Option<Self>;
}

impl NegateCmpInstr for Instruction {
fn negate_cmp_instr(&self) -> Option<Self> {
fn negate_cmp_instr(&self, is_branch: bool) -> Option<Self> {
use Instruction as I;
#[rustfmt::skip]
let negated = match *self {
Expand All @@ -34,15 +40,15 @@ impl NegateCmpInstr for Instruction {
I::I32And { result, lhs, rhs } => I::i32_and_eqz(result, lhs, rhs),
I::I32Or { result, lhs, rhs } => I::i32_or_eqz(result, lhs, rhs),
I::I32Xor { result, lhs, rhs } => I::i32_xor_eqz(result, lhs, rhs),
I::I32AndEqz { result, lhs, rhs } => I::i32_and(result, lhs, rhs),
I::I32OrEqz { result, lhs, rhs } => I::i32_or(result, lhs, rhs),
I::I32XorEqz { result, lhs, rhs } => I::i32_xor(result, lhs, rhs),
I::I32AndEqz { result, lhs, rhs } if is_branch => I::i32_and(result, lhs, rhs),
I::I32OrEqz { result, lhs, rhs } if is_branch => I::i32_or(result, lhs, rhs),
I::I32XorEqz { result, lhs, rhs } if is_branch => I::i32_xor(result, lhs, rhs),
I::I32AndImm16 { result, lhs, rhs } => I::i32_and_eqz_imm16(result, lhs, rhs),
I::I32OrImm16 { result, lhs, rhs } => I::i32_or_eqz_imm16(result, lhs, rhs),
I::I32XorImm16 { result, lhs, rhs } => I::i32_xor_eqz_imm16(result, lhs, rhs),
I::I32AndEqzImm16 { result, lhs, rhs } => I::i32_and_imm16(result, lhs, rhs),
I::I32OrEqzImm16 { result, lhs, rhs } => I::i32_or_imm16(result, lhs, rhs),
I::I32XorEqzImm16 { result, lhs, rhs } => I::i32_xor_imm16(result, lhs, rhs),
I::I32AndEqzImm16 { result, lhs, rhs } if is_branch => I::i32_and_imm16(result, lhs, rhs),
I::I32OrEqzImm16 { result, lhs, rhs } if is_branch => I::i32_or_imm16(result, lhs, rhs),
I::I32XorEqzImm16 { result, lhs, rhs } if is_branch => I::i32_xor_imm16(result, lhs, rhs),
// i64
I::I64Eq { result, lhs, rhs } => I::i64_ne(result, lhs, rhs),
I::I64Ne { result, lhs, rhs } => I::i64_eq(result, lhs, rhs),
Expand Down
4 changes: 2 additions & 2 deletions crates/wasmi/src/engine/translator/instr_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ impl InstrEncoder {
// thus indicating that we cannot fuse the instructions.
return false;
}
let Some(negated) = last_instruction.negate_cmp_instr() else {
let Some(negated) = last_instruction.negate_cmp_instr(false) else {
// Last instruction is unable to be negated.
return false;
};
Expand Down Expand Up @@ -1085,7 +1085,7 @@ impl InstrEncoder {
return Ok(None);
}
let last_instruction = match negate {
true => match last_instruction.negate_cmp_instr() {
true => match last_instruction.negate_cmp_instr(true) {
Some(negated) => negated,
None => return Ok(None),
},
Expand Down
134 changes: 125 additions & 9 deletions crates/wasmi/src/engine/translator/tests/op/i32_eqz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,6 @@ fn binop_i32_eqz_double() {
test_for!(
("i32", "eq", Instruction::i32_eq),
("i32", "ne", Instruction::i32_ne),
("i32", "and", Instruction::i32_and),
("i32", "or", Instruction::i32_or),
("i32", "xor", Instruction::i32_xor),
("i32", "lt_s", Instruction::i32_lt_s),
("i32", "lt_u", Instruction::i32_lt_u),
("i32", "le_s", Instruction::i32_le_s),
Expand Down Expand Up @@ -297,9 +294,6 @@ fn binop_imm_i32_eqz_rhs_double() {
test_for_imm!(
(i32, "eq", Instruction::i32_eq_imm16),
(i32, "ne", Instruction::i32_ne_imm16),
(i32, "and", Instruction::i32_and_imm16),
(i32, "or", Instruction::i32_or_imm16),
(i32, "xor", Instruction::i32_xor_imm16),
(i32, "lt_s", Instruction::i32_lt_s_imm16_rhs),
(u32, "lt_u", Instruction::i32_lt_u_imm16_rhs),
(i32, "le_s", Instruction::i32_le_s_imm16_rhs),
Expand Down Expand Up @@ -361,9 +355,6 @@ fn binop_imm_i32_eqz_lhs_double() {
test_for_imm!(
(i32, "eq", Instruction::i32_eq_imm16),
(i32, "ne", Instruction::i32_ne_imm16),
(i32, "and", Instruction::i32_and_imm16),
(i32, "or", Instruction::i32_or_imm16),
(i32, "xor", Instruction::i32_xor_imm16),
(i32, "lt_s", swap_ops!(Instruction::i32_lt_s_imm16_lhs)),
(u32, "lt_u", swap_ops!(Instruction::i32_lt_u_imm16_lhs)),
(i32, "le_s", swap_ops!(Instruction::i32_le_s_imm16_lhs)),
Expand All @@ -382,3 +373,128 @@ fn binop_imm_i32_eqz_lhs_double() {
(u64, "ge_u", Instruction::i64_le_u_imm16_rhs),
);
}

#[test]
#[cfg_attr(miri, ignore)]
fn binop_i32_eqz_double_invalid() {
fn test_for(
input_ty: &str,
op: &str,
expect_instr: fn(result: Reg, lhs: Reg, rhs: Reg) -> Instruction,
) {
let wasm = &format!(
r"
(module
(func (param {input_ty} {input_ty}) (result i32)
(local.get 0)
(local.get 1)
({input_ty}.{op})
(i32.eqz)
(i32.eqz)
)
)",
);
TranslationTest::from_wat(wasm)
.expect_func_instrs([
expect_instr(Reg::from(2), Reg::from(0), Reg::from(1)),
Instruction::i32_eq_imm16(Reg::from(2), Reg::from(2), 0),
Instruction::return_reg(2),
])
.run()
}
test_for!(
("i32", "and", Instruction::i32_and_eqz),
("i32", "or", Instruction::i32_or_eqz),
("i32", "xor", Instruction::i32_xor_eqz),
);
}

#[test]
#[cfg_attr(miri, ignore)]
fn binop_imm_i32_eqz_rhs_double_invalid() {
fn test_for<T>(
op: &str,
value: T,
expect_instr: fn(result: Reg, lhs: Reg, rhs: Const16<T>) -> Instruction,
) where
T: Display + WasmTy,
Const16<T>: TryFrom<T>,
DisplayWasm<T>: Display,
{
let input_ty = T::NAME;
let display_value = DisplayWasm::from(value);
let wasm = &format!(
r"
(module
(func (param {input_ty} {input_ty}) (result i32)
(local.get 0)
({input_ty}.const {display_value})
({input_ty}.{op})
(i32.eqz)
(i32.eqz)
)
)",
);
TranslationTest::from_wat(wasm)
.expect_func_instrs([
expect_instr(
Reg::from(2),
Reg::from(0),
Const16::try_from(value).ok().unwrap(),
),
Instruction::i32_eq_imm16(Reg::from(2), Reg::from(2), 0),
Instruction::return_reg(2),
])
.run()
}
test_for_imm!(
(i32, "and", Instruction::i32_and_eqz_imm16),
(i32, "or", Instruction::i32_or_eqz_imm16),
(i32, "xor", Instruction::i32_xor_eqz_imm16),
);
}

#[test]
#[cfg_attr(miri, ignore)]
fn binop_imm_i32_eqz_lhs_double_invalid() {
fn test_for<T>(
op: &str,
value: T,
expect_instr: fn(result: Reg, lhs: Reg, rhs: Const16<T>) -> Instruction,
) where
T: Display + WasmTy,
Const16<T>: TryFrom<T>,
DisplayWasm<T>: Display,
{
let input_ty = T::NAME;
let display_value = DisplayWasm::from(value);
let wasm = &format!(
r"
(module
(func (param {input_ty} {input_ty}) (result i32)
({input_ty}.const {display_value})
(local.get 0)
({input_ty}.{op})
(i32.eqz)
(i32.eqz)
)
)",
);
TranslationTest::from_wat(wasm)
.expect_func_instrs([
expect_instr(
Reg::from(2),
Reg::from(0),
Const16::try_from(value).ok().unwrap(),
),
Instruction::i32_eq_imm16(Reg::from(2), Reg::from(2), 0),
Instruction::return_reg(2),
])
.run()
}
test_for_imm!(
(i32, "and", Instruction::i32_and_eqz_imm16),
(i32, "or", Instruction::i32_or_eqz_imm16),
(i32, "xor", Instruction::i32_xor_eqz_imm16),
);
}

0 comments on commit 0500fda

Please sign in to comment.