Skip to content

Commit

Permalink
feat: implement a basic variant of shift_right for ABY3
Browse files Browse the repository at this point in the history
  • Loading branch information
dkales committed Jun 6, 2024
1 parent 384ba52 commit f0b9527
Showing 1 changed file with 52 additions and 3 deletions.
55 changes: 52 additions & 3 deletions mpc-core/src/protocols/aby3/witness_extension_impl.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::{network::Aby3Network, Aby3PrimeFieldShare, Aby3Protocol};
use crate::{
protocols::plain::PlainDriver,
protocols::{aby3::a2b::Aby3BigUintShare, plain::PlainDriver},
traits::{CircomWitnessExtensionProtocol, PrimeFieldMpcProtocol},
};
use ark_ff::PrimeField;
use ark_ff::{One, PrimeField};
use eyre::{bail, Result};
use num_bigint::BigUint;

#[derive(Clone)]
pub enum Aby3VmType<F: PrimeField> {
Expand Down Expand Up @@ -225,12 +226,60 @@ impl<F: PrimeField> Aby3VmType<F> {
Ok(res)
}

fn shift_r<N: Aby3Network>(_party: &mut Aby3Protocol<F, N>, a: Self, b: Self) -> Result<Self> {
fn shift_r<N: Aby3Network>(party: &mut Aby3Protocol<F, N>, a: Self, b: Self) -> Result<Self> {
// TODO: The circom handling of shifts can handle "negative" inputs, translating them to other type of shift...
let res = match (a, b) {
(Aby3VmType::Public(a), Aby3VmType::Public(b)) => {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_shift_r(a, b)?)
}
(Aby3VmType::Public(a), Aby3VmType::Shared(b)) => {
// some special casing
if a == F::zero() {
return Ok(Aby3VmType::Public(F::zero()));
}
// TODO: check for overflows
// This case is equivalent to a*2^b
// Strategy: limit size of b to k bits
// bit-decompose b into bits b_i
let bit_shares = party.a2b(&b)?;
let individual_bit_shares = (0..8)
.map(|i| {
let bit = Aby3BigUintShare {
a: (bit_shares.a.clone() >> i) & BigUint::one(),
b: (bit_shares.b.clone() >> i) & BigUint::one(),
};
let share = party.b2a(bit);
share
})
.collect::<Result<Vec<_>, _>>()?;
// v_i = 2^2^i * <b_i> + 1 - <b_i>
let mut vs: Vec<_> = individual_bit_shares
.into_iter()
.enumerate()
.map(|(i, b_i)| {
let two = F::from(2u64);
let two_to_two_to_i = two.pow(&[2u64.pow(i as u32)]);
let v = party.mul_with_public(&two_to_two_to_i, &b_i);
let v = party.add_with_public(&F::one(), &v);
party.sub(&v, &b_i)
})
.collect();

// v = \prod v_i
// TODO: This should be done in a multiplication tree
let last = vs.pop().unwrap();
let v = vs.into_iter().try_fold(last, |a, b| party.mul(&a, &b))?;
let res = party.mul_with_public(&a, &v);
Aby3VmType::Shared(res)
}
(Aby3VmType::Shared(a), Aby3VmType::Public(b)) => {
// TODO: handle overflows
// This case is equivalent to a*2^b
// TODO: assert b < 256?
let shift = F::from(2u64).pow(&[b.into_bigint().as_mut()[0]]);
Aby3VmType::Shared(party.mul_with_public(&shift, &a))
}
(_, _) => todo!("Shared shift_right not implemented"),
};
Ok(res)
Expand Down

0 comments on commit f0b9527

Please sign in to comment.