diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 155e5a7baba..058f91adacb 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::{borrow::Cow, rc::Rc}; use acvm::FieldElement; use noirc_errors::Location; @@ -16,7 +16,8 @@ use super::{ basic_block::BasicBlock, dfg::{CallStack, InsertInstructionResult}, function::RuntimeType, - instruction::{InstructionId, Intrinsic}, + instruction::{Endian, InstructionId, Intrinsic}, + types::NumericType, }, ssa_gen::Ssa, }; @@ -259,9 +260,126 @@ impl FunctionBuilder { arguments: Vec, result_types: Vec, ) -> Cow<[ValueId]> { + if let Value::Intrinsic(intrinsic) = &self.current_function.dfg[func] { + if intrinsic == &Intrinsic::WrappingShiftLeft { + let result_type = self.current_function.dfg.type_of_value(arguments[0]); + let bit_size = match result_type { + Type::Numeric(NumericType::Signed { bit_size }) + | Type::Numeric(NumericType::Unsigned { bit_size }) => bit_size, + _ => { + unreachable!("ICE: Truncation attempted on non-integer"); + } + }; + return self + .insert_wrapping_shift_left(arguments[0], arguments[1], bit_size) + .results(); + } + } + self.insert_instruction(Instruction::Call { func, arguments }, Some(result_types)).results() } + /// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs + pub(crate) fn insert_shift_left(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let base = self.field_constant(FieldElement::from(2_u128)); + let pow = self.pow(base, rhs); + let typ = self.current_function.dfg.type_of_value(lhs); + let pow = self.insert_cast(pow, typ); + self.insert_binary(lhs, BinaryOp::Mul, pow) + } + + /// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs + /// and truncate the result to bit_size + fn insert_wrapping_shift_left( + &mut self, + lhs: ValueId, + rhs: ValueId, + bit_size: u32, + ) -> InsertInstructionResult { + let base = self.field_constant(FieldElement::from(2_u128)); + let typ = self.current_function.dfg.type_of_value(lhs); + let (max_bit, pow) = if let Some(rhs_constant) = + self.current_function.dfg.get_numeric_constant(rhs) + { + // Happy case is that we know precisely by how many bits the the integer will + // increase: lhs_bit_size + rhs + let (rhs_bit_size_pow_2, overflows) = + 2_u32.overflowing_pow(rhs_constant.to_u128() as u32); + if overflows { + let zero = self.numeric_constant(FieldElement::zero(), typ); + return InsertInstructionResult::SimplifiedTo(zero); + } + let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2 as u128), typ); + (bit_size + (rhs_constant.to_u128() as u32), pow) + } else { + // we use a predicate to nullify the result in case of overflow + let bit_size_var = + self.numeric_constant(FieldElement::from(bit_size as u128), typ.clone()); + let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var); + let one = self.numeric_constant(FieldElement::one(), Type::unsigned(1)); + let predicate = self.insert_binary(overflow, BinaryOp::Eq, one); + let predicate = self.insert_cast(predicate, typ.clone()); + + let pow = self.pow(base, rhs); + let pow = self.insert_cast(pow, typ); + (FieldElement::max_num_bits(), self.insert_binary(predicate, BinaryOp::Mul, pow)) + }; + + let instruction = Instruction::Binary(Binary { lhs, rhs: pow, operator: BinaryOp::Mul }); + if max_bit <= bit_size { + self.insert_instruction(instruction, None) + } else { + let result = self.insert_instruction(instruction, None).first(); + self.insert_instruction( + Instruction::Truncate { value: result, bit_size, max_bit_size: max_bit }, + None, + ) + } + } + + /// Insert ssa instructions which computes lhs >> rhs by doing lhs/2^rhs + pub(crate) fn insert_shift_right(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let base = self.field_constant(FieldElement::from(2_u128)); + let pow = self.pow(base, rhs); + self.insert_binary(lhs, BinaryOp::Div, pow) + } + + /// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs + /// Pseudo-code of the computation: + /// let mut r = 1; + /// let rhs_bits = to_bits(rhs); + /// for i in 1 .. bit_size + 1 { + /// let r_squared = r * r; + /// let b = rhs_bits[bit_size - i]; + /// r = (r_squared * lhs * b) + (1 - b) * r_squared; + /// } + pub(crate) fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let typ = self.current_function.dfg.type_of_value(rhs); + if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ { + let to_bits = self.import_intrinsic_id(Intrinsic::ToBits(Endian::Little)); + let length = self.field_constant(FieldElement::from(bit_size as i128)); + let result_types = + vec![Type::field(), Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)]; + let rhs_bits = self.insert_call(to_bits, vec![rhs, length], result_types); + let rhs_bits = rhs_bits[1]; + let one = self.field_constant(FieldElement::one()); + let mut r = one; + for i in 1..bit_size + 1 { + let r_squared = self.insert_binary(r, BinaryOp::Mul, r); + let a = self.insert_binary(r_squared, BinaryOp::Mul, lhs); + let idx = self.field_constant(FieldElement::from((bit_size - i) as i128)); + let b = self.insert_array_get(rhs_bits, idx, Type::field()); + let r1 = self.insert_binary(a, BinaryOp::Mul, b); + let c = self.insert_binary(one, BinaryOp::Sub, b); + let r2 = self.insert_binary(c, BinaryOp::Mul, r_squared); + r = self.insert_binary(r1, BinaryOp::Add, r2); + } + r + } else { + unreachable!("Value must be unsigned in power operation"); + } + } + /// Insert an instruction to extract an element from an array pub(crate) fn insert_array_get( &mut self, diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index badc1e82d50..a3d6396b668 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -46,6 +46,7 @@ pub(crate) enum Intrinsic { BlackBox(BlackBoxFunc), FromField, AsField, + WrappingShiftLeft, } impl std::fmt::Display for Intrinsic { @@ -68,6 +69,7 @@ impl std::fmt::Display for Intrinsic { Intrinsic::BlackBox(function) => write!(f, "{function}"), Intrinsic::FromField => write!(f, "from_field"), Intrinsic::AsField => write!(f, "as_field"), + Intrinsic::WrappingShiftLeft => write!(f, "wrapping_shift_left"), } } } @@ -92,7 +94,8 @@ impl Intrinsic { | Intrinsic::ToBits(_) | Intrinsic::ToRadix(_) | Intrinsic::FromField - | Intrinsic::AsField => false, + | Intrinsic::AsField + | Intrinsic::WrappingShiftLeft => false, // Some black box functions have side-effects Intrinsic::BlackBox(func) => matches!(func, BlackBoxFunc::RecursiveAggregation), @@ -119,6 +122,7 @@ impl Intrinsic { "to_be_bits" => Some(Intrinsic::ToBits(Endian::Big)), "from_field" => Some(Intrinsic::FromField), "as_field" => Some(Intrinsic::AsField), + "wrapping_shift_left" => Some(Intrinsic::WrappingShiftLeft), other => BlackBoxFunc::lookup(other).map(Intrinsic::BlackBox), } } diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index b07e2df7bd3..da5544d7dc6 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -245,6 +245,9 @@ pub(super) fn simplify_call( let instruction = Instruction::Cast(arguments[0], ctrl_typevars.unwrap().remove(0)); SimplifyResult::SimplifiedToInstruction(instruction) } + Intrinsic::WrappingShiftLeft => { + unreachable!("ICE - wrapping shift left should have been proccessed before") + } } } diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index 26c818250dc..25534c739e2 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -13,7 +13,7 @@ use crate::ssa::function_builder::FunctionBuilder; use crate::ssa::ir::dfg::DataFlowGraph; use crate::ssa::ir::function::FunctionId as IrFunctionId; use crate::ssa::ir::function::{Function, RuntimeType}; -use crate::ssa::ir::instruction::{BinaryOp, Endian, Intrinsic}; +use crate::ssa::ir::instruction::BinaryOp; use crate::ssa::ir::map::AtomicCounter; use crate::ssa::ir::types::{NumericType, Type}; use crate::ssa::ir::value::ValueId; @@ -265,50 +265,6 @@ impl<'a> FunctionContext<'a> { Ok(self.builder.numeric_constant(value, typ)) } - /// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs - fn insert_shift_left(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { - let base = self.builder.field_constant(FieldElement::from(2_u128)); - let pow = self.pow(base, rhs); - let typ = self.builder.current_function.dfg.type_of_value(lhs); - let pow = self.builder.insert_cast(pow, typ); - self.builder.insert_binary(lhs, BinaryOp::Mul, pow) - } - - /// Insert ssa instructions which computes lhs >> rhs by doing lhs/2^rhs - fn insert_shift_right(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { - let base = self.builder.field_constant(FieldElement::from(2_u128)); - let pow = self.pow(base, rhs); - self.builder.insert_binary(lhs, BinaryOp::Div, pow) - } - - /// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs - fn pow(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { - let typ = self.builder.current_function.dfg.type_of_value(rhs); - if let Type::Numeric(NumericType::Unsigned { bit_size }) = typ { - let to_bits = self.builder.import_intrinsic_id(Intrinsic::ToBits(Endian::Little)); - let length = self.builder.field_constant(FieldElement::from(bit_size as i128)); - let result_types = - vec![Type::field(), Type::Array(Rc::new(vec![Type::bool()]), bit_size as usize)]; - let rhs_bits = self.builder.insert_call(to_bits, vec![rhs, length], result_types); - let rhs_bits = rhs_bits[1]; - let one = self.builder.field_constant(FieldElement::one()); - let mut r = one; - for i in 1..bit_size + 1 { - let r1 = self.builder.insert_binary(r, BinaryOp::Mul, r); - let a = self.builder.insert_binary(r1, BinaryOp::Mul, lhs); - let idx = self.builder.field_constant(FieldElement::from((bit_size - i) as i128)); - let b = self.builder.insert_array_get(rhs_bits, idx, Type::field()); - let r2 = self.builder.insert_binary(a, BinaryOp::Mul, b); - let c = self.builder.insert_binary(one, BinaryOp::Sub, b); - let r3 = self.builder.insert_binary(c, BinaryOp::Mul, r1); - r = self.builder.insert_binary(r2, BinaryOp::Add, r3); - } - r - } else { - unreachable!("Value must be unsigned in power operation"); - } - } - /// Insert a binary instruction at the end of the current block. /// Converts the form of the binary instruction as necessary /// (e.g. swapping arguments, inserting a not) to represent it in the IR. @@ -321,8 +277,8 @@ impl<'a> FunctionContext<'a> { location: Location, ) -> Values { let mut result = match operator { - BinaryOpKind::ShiftLeft => self.insert_shift_left(lhs, rhs), - BinaryOpKind::ShiftRight => self.insert_shift_right(lhs, rhs), + BinaryOpKind::ShiftLeft => self.builder.insert_shift_left(lhs, rhs), + BinaryOpKind::ShiftRight => self.builder.insert_shift_right(lhs, rhs), BinaryOpKind::Equal | BinaryOpKind::NotEqual if matches!(self.builder.type_of_value(lhs), Type::Array(..)) => { diff --git a/noir_stdlib/src/lib.nr b/noir_stdlib/src/lib.nr index 26cf7a225ee..2e34c017db6 100644 --- a/noir_stdlib/src/lib.nr +++ b/noir_stdlib/src/lib.nr @@ -60,6 +60,7 @@ pub fn wrapping_mul(x : T, y: T) -> T { crate::from_field(crate::as_field(x) * crate::as_field(y)) } -pub fn wrapping_shift_left(x : T, y: T) -> T { - crate::from_field(crate::as_field(x) * 2.pow_32(crate::as_field(y))) -} +/// Shift-left x by y bits +/// If the result overflow the bitsize; it does not fail and returns 0 instead +#[builtin(wrapping_shift_left)] +pub fn wrapping_shift_left(x : T, y: T) -> T {} diff --git a/noir_stdlib/src/sha256.nr b/noir_stdlib/src/sha256.nr index d2afd21db8a..358b647a078 100644 --- a/noir_stdlib/src/sha256.nr +++ b/noir_stdlib/src/sha256.nr @@ -7,7 +7,7 @@ fn rotr32(a: u32, b: u32) -> u32 // 32-bit right rotation { // None of the bits overlap between `(a >> b)` and `(a << (32 - b))` // Addition is then equivalent to OR, with fewer constraints. - (a >> b) + (a << (32 - b)) + (a >> b) + (crate::wrapping_shift_left(a, 32 - b)) } fn ch(x: u32, y: u32, z: u32) -> u32 diff --git a/noir_stdlib/src/sha512.nr b/noir_stdlib/src/sha512.nr index c565b16c098..7d3412f517b 100644 --- a/noir_stdlib/src/sha512.nr +++ b/noir_stdlib/src/sha512.nr @@ -7,7 +7,7 @@ fn rotr64(a: u64, b: u64) -> u64 // 64-bit right rotation { // None of the bits overlap between `(a >> b)` and `(a << (64 - b))` // Addition is then equivalent to OR, with fewer constraints. - (a >> b) + (a << (64 - b)) + (a >> b) + (crate::wrapping_shift_left(a, 64 - b)) } fn sha_ch(x: u64, y: u64, z: u64) -> u64