diff --git a/mpc-core/src/protocols/aby3/witness_extension_impl.rs b/mpc-core/src/protocols/aby3/witness_extension_impl.rs index 33aa0dd6..f51c8221 100644 --- a/mpc-core/src/protocols/aby3/witness_extension_impl.rs +++ b/mpc-core/src/protocols/aby3/witness_extension_impl.rs @@ -1,10 +1,11 @@ use super::{network::Aby3Network, Aby3PrimeFieldShare, Aby3Protocol}; use crate::{ - protocols::plain::PlainDriver, + protocols::{aby3::a2b::Aby3BigUintShare, plain::PlainDriver}, traits::{CircomWitnessExtensionProtocol, PrimeFieldMpcProtocol}, }; -use ark_ff::PrimeField; +use ark_ff::{One, PrimeField}; use eyre::{bail, Result}; +use num_bigint::BigUint; #[derive(Clone)] pub enum Aby3VmType { @@ -225,12 +226,60 @@ impl Aby3VmType { Ok(res) } - fn shift_r(_party: &mut Aby3Protocol, a: Self, b: Self) -> Result { + fn shift_r(party: &mut Aby3Protocol, a: Self, b: Self) -> Result { + // TODO: The circom handling of shifts can handle "negative" inputs, translating them to other type of shift... let res = match (a, b) { (Aby3VmType::Public(a), Aby3VmType::Public(b)) => { let mut plain = PlainDriver::default(); Aby3VmType::Public(plain.vm_shift_r(a, b)?) } + (Aby3VmType::Public(a), Aby3VmType::Shared(b)) => { + // some special casing + if a == F::zero() { + return Ok(Aby3VmType::Public(F::zero())); + } + // TODO: check for overflows + // This case is equivalent to a*2^b + // Strategy: limit size of b to k bits + // bit-decompose b into bits b_i + let bit_shares = party.a2b(&b)?; + let individual_bit_shares = (0..8) + .map(|i| { + let bit = Aby3BigUintShare { + a: (bit_shares.a.clone() >> i) & BigUint::one(), + b: (bit_shares.b.clone() >> i) & BigUint::one(), + }; + let share = party.b2a(bit); + share + }) + .collect::, _>>()?; + // v_i = 2^2^i * + 1 - + let mut vs: Vec<_> = individual_bit_shares + .into_iter() + .enumerate() + .map(|(i, b_i)| { + let two = F::from(2u64); + let two_to_two_to_i = two.pow(&[2u64.pow(i as u32)]); + let v = party.mul_with_public(&two_to_two_to_i, &b_i); + let v = party.add_with_public(&F::one(), &v); + party.sub(&v, &b_i) + }) + .collect(); + + // v = \prod v_i + // TODO: This should be done in a multiplication tree + let last = vs.pop().unwrap(); + let v = vs.into_iter().try_fold(last, |a, b| party.mul(&a, &b))?; + let res = party.mul_with_public(&a, &v); + Aby3VmType::Shared(res) + } + (Aby3VmType::Shared(a), Aby3VmType::Public(b)) => { + // TODO: handle overflows + // This case is equivalent to a*2^b + // TODO: assert b < 256? + let shift = F::from(2u64).pow(&[b.into_bigint().as_mut()[0]]); + Aby3VmType::Shared(party.mul_with_public(&shift, &a)) + } (_, _) => todo!("Shared shift_right not implemented"), }; Ok(res)