Skip to content

Commit

Permalink
Merge branch 'fix/stdlib_fixes' of github.com:TaceoLabs/collaborative…
Browse files Browse the repository at this point in the history
…-circom into fix/stdlib_fixes
  • Loading branch information
0xThemis committed Jun 6, 2024
2 parents c319f39 + f0b9527 commit eda9f74
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 25 deletions.
9 changes: 7 additions & 2 deletions circom-mpc-vm/src/mpc_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ impl<P: Pairing, C: CircomWitnessExtensionProtocol<P::ScalarField>> Component<P,
op_codes::MpcOpCode::Assert => {
let assertion = self.pop_field();
if protocol.is_zero(assertion) {
panic!("assertion failed");
// TODO: Handle nicer
panic!("Assertion failed during execution");
}
}
op_codes::MpcOpCode::Add => {
Expand Down Expand Up @@ -343,7 +344,11 @@ impl<P: Pairing, C: CircomWitnessExtensionProtocol<P::ScalarField>> Component<P,
let lhs = self.pop_field();
self.push_field(protocol.vm_shift_l(lhs, rhs)?);
}
op_codes::MpcOpCode::BoolOr => todo!(),
op_codes::MpcOpCode::BoolOr => {
let rhs = self.pop_field();
let lhs = self.pop_field();
self.push_field(protocol.vm_bool_or(lhs, rhs)?);
}
op_codes::MpcOpCode::BoolAnd => {
let rhs = self.pop_field();
let lhs = self.pop_field();
Expand Down
110 changes: 87 additions & 23 deletions mpc-core/src/protocols/aby3/witness_extension_impl.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::{network::Aby3Network, Aby3PrimeFieldShare, Aby3Protocol};
use crate::{
protocols::plain::PlainDriver,
protocols::{aby3::a2b::Aby3BigUintShare, plain::PlainDriver},
traits::{CircomWitnessExtensionProtocol, PrimeFieldMpcProtocol},
};
use ark_ff::PrimeField;
use ark_ff::{One, PrimeField};
use eyre::{bail, Result};
use num_bigint::BigUint;

#[derive(Clone)]
pub enum Aby3VmType<F: PrimeField> {
Expand Down Expand Up @@ -64,7 +65,7 @@ impl<F: PrimeField> Aby3VmType<F> {
Aby3VmType::Shared(party.add_with_public(&b, &a))
}
(Aby3VmType::Shared(a), Aby3VmType::Shared(b)) => Aby3VmType::Shared(party.add(&a, &b)),
(_, _) => todo!("BitShared not yet implemented"),
(_, _) => todo!("BitShared add not yet implemented"),
}
}

Expand All @@ -81,7 +82,7 @@ impl<F: PrimeField> Aby3VmType<F> {
Aby3VmType::Shared(party.add_with_public(&-b, &a))
}
(Aby3VmType::Shared(a), Aby3VmType::Shared(b)) => Aby3VmType::Shared(party.sub(&a, &b)),
(_, _) => todo!("BitShared not yet implemented"),
(_, _) => todo!("BitShared sub not yet implemented"),
}
}

Expand All @@ -100,7 +101,7 @@ impl<F: PrimeField> Aby3VmType<F> {
(Aby3VmType::Shared(a), Aby3VmType::Shared(b)) => {
Aby3VmType::Shared(party.mul(&a, &b)?)
}
(_, _) => todo!("BitShared not yet implemented"),
(_, _) => todo!("BitShared mul not yet implemented"),
};
Ok(res)
}
Expand All @@ -112,7 +113,7 @@ impl<F: PrimeField> Aby3VmType<F> {
Aby3VmType::Public(plain.vm_neg(a))
}
Aby3VmType::Shared(a) => Aby3VmType::Shared(party.neg(&a)),
_ => todo!("BitShared not yet implemented"),
_ => todo!("BitShared neg not yet implemented"),
}
}

Expand All @@ -138,7 +139,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let b_inv = party.inv(&b)?;
Aby3VmType::Shared(party.mul(&a, &b_inv)?)
}
(_, _) => todo!("BitShared not implemented"),
(_, _) => todo!("BitShared div not implemented"),
};
Ok(res)
}
Expand All @@ -149,7 +150,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_int_div(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared int_div not implemented"),
};
Ok(res)
}
Expand All @@ -160,7 +161,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_lt(a, b))
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared LT not implemented"),
}
}

Expand All @@ -170,7 +171,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_le(a, b))
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared LE not implemented"),
}
}

Expand All @@ -180,7 +181,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_gt(a, b))
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared GT not implemented"),
}
}

Expand All @@ -190,7 +191,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_ge(a, b))
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared GE not implemented"),
}
}

Expand All @@ -200,7 +201,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_eq(a, b))
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared EQ not implemented"),
}
}

Expand All @@ -210,7 +211,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_neq(a, b))
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared NEQ not implemented"),
}
}

Expand All @@ -220,18 +221,66 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_shift_l(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared shift_left not implemented"),
};
Ok(res)
}

fn shift_r<N: Aby3Network>(_party: &mut Aby3Protocol<F, N>, a: Self, b: Self) -> Result<Self> {
fn shift_r<N: Aby3Network>(party: &mut Aby3Protocol<F, N>, a: Self, b: Self) -> Result<Self> {
// TODO: The circom handling of shifts can handle "negative" inputs, translating them to other type of shift...
let res = match (a, b) {
(Aby3VmType::Public(a), Aby3VmType::Public(b)) => {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_shift_r(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
(Aby3VmType::Public(a), Aby3VmType::Shared(b)) => {
// some special casing
if a == F::zero() {
return Ok(Aby3VmType::Public(F::zero()));
}
// TODO: check for overflows
// This case is equivalent to a*2^b
// Strategy: limit size of b to k bits
// bit-decompose b into bits b_i
let bit_shares = party.a2b(&b)?;
let individual_bit_shares = (0..8)
.map(|i| {
let bit = Aby3BigUintShare {
a: (bit_shares.a.clone() >> i) & BigUint::one(),
b: (bit_shares.b.clone() >> i) & BigUint::one(),
};
let share = party.b2a(bit);
share
})
.collect::<Result<Vec<_>, _>>()?;
// v_i = 2^2^i * <b_i> + 1 - <b_i>
let mut vs: Vec<_> = individual_bit_shares
.into_iter()
.enumerate()
.map(|(i, b_i)| {
let two = F::from(2u64);
let two_to_two_to_i = two.pow(&[2u64.pow(i as u32)]);
let v = party.mul_with_public(&two_to_two_to_i, &b_i);
let v = party.add_with_public(&F::one(), &v);
party.sub(&v, &b_i)
})
.collect();

// v = \prod v_i
// TODO: This should be done in a multiplication tree
let last = vs.pop().unwrap();
let v = vs.into_iter().try_fold(last, |a, b| party.mul(&a, &b))?;
let res = party.mul_with_public(&a, &v);
Aby3VmType::Shared(res)
}
(Aby3VmType::Shared(a), Aby3VmType::Public(b)) => {
// TODO: handle overflows
// This case is equivalent to a*2^b
// TODO: assert b < 256?
let shift = F::from(2u64).pow(&[b.into_bigint().as_mut()[0]]);
Aby3VmType::Shared(party.mul_with_public(&shift, &a))
}
(_, _) => todo!("Shared shift_right not implemented"),
};
Ok(res)
}
Expand All @@ -242,7 +291,18 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_bool_and(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared bool_and not implemented"),
};
Ok(res)
}

fn bool_or<N: Aby3Network>(_party: &mut Aby3Protocol<F, N>, a: Self, b: Self) -> Result<Self> {
let res = match (a, b) {
(Aby3VmType::Public(a), Aby3VmType::Public(b)) => {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_bool_or(a, b)?)
}
(_, _) => todo!("Shared bool_or not implemented"),
};
Ok(res)
}
Expand All @@ -253,7 +313,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_bit_and(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared bit_and not implemented"),
};
Ok(res)
}
Expand All @@ -264,7 +324,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_bit_xor(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared bit_xor not implemented"),
};
Ok(res)
}
Expand All @@ -275,7 +335,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let mut plain = PlainDriver::default();
Aby3VmType::Public(plain.vm_bit_or(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
(_, _) => todo!("Shared bit_or not implemented"),
};
Ok(res)
}
Expand All @@ -286,7 +346,7 @@ impl<F: PrimeField> Aby3VmType<F> {
let plain = PlainDriver::default();
plain.is_zero(a)
}
_ => todo!("Shared not implemented"),
_ => todo!("Shared is_zero not implemented"),
}
}

Expand All @@ -297,7 +357,7 @@ impl<F: PrimeField> Aby3VmType<F> {
plain.vm_open(a)
}
Aby3VmType::Shared(a) => Ok(party.open(&a)?),
_ => todo!("Shared not implemented"),
_ => todo!("Shared to_index not implemented"),
}
}
}
Expand Down Expand Up @@ -376,6 +436,10 @@ impl<F: PrimeField, N: Aby3Network> CircomWitnessExtensionProtocol<F> for Aby3Pr
Self::VmType::bool_and(self, a, b)
}

fn vm_bool_or(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType> {
Self::VmType::bool_or(self, a, b)
}

fn vm_bit_xor(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType> {
Self::VmType::bit_xor(self, a, b)
}
Expand Down
15 changes: 15 additions & 0 deletions mpc-core/src/protocols/gsz/witness_extension_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ impl<F: PrimeField> GSZVmType<F> {
Ok(res)
}

fn bool_or<N: GSZNetwork>(_party: &mut GSZProtocol<F, N>, a: Self, b: Self) -> Result<Self> {
let res = match (a, b) {
(GSZVmType::Public(a), GSZVmType::Public(b)) => {
let mut plain = PlainDriver::default();
GSZVmType::Public(plain.vm_bool_or(a, b)?)
}
(_, _) => todo!("Shared not implemented"),
};
Ok(res)
}

fn bit_and<N: GSZNetwork>(_party: &mut GSZProtocol<F, N>, a: Self, b: Self) -> Result<Self> {
let res = match (a, b) {
(GSZVmType::Public(a), GSZVmType::Public(b)) => {
Expand Down Expand Up @@ -361,6 +372,10 @@ impl<F: PrimeField, N: GSZNetwork> CircomWitnessExtensionProtocol<F> for GSZProt
Self::VmType::bool_and(self, a, b)
}

fn vm_bool_or(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType> {
Self::VmType::bool_or(self, a, b)
}

fn vm_bit_xor(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType> {
Self::VmType::bit_xor(self, a, b)
}
Expand Down
12 changes: 12 additions & 0 deletions mpc-core/src/protocols/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,18 @@ impl<F: PrimeField> CircomWitnessExtensionProtocol<F> for PlainDriver {
}
}

fn vm_bool_or(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType> {
let lhs = to_usize!(a);
let rhs = to_usize!(b);
debug_assert!(rhs == 0 || rhs == 1);
debug_assert!(lhs == 0 || lhs == 1);
if rhs == 1 || lhs == 1 {
Ok(F::one())
} else {
Ok(F::zero())
}
}

fn vm_bit_xor(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType> {
let lhs = to_bigint!(a);
let rhs = to_bigint!(b);
Expand Down
1 change: 1 addition & 0 deletions mpc-core/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub trait CircomWitnessExtensionProtocol<F: PrimeField>: PrimeFieldMpcProtocol<F
fn vm_shift_l(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType>;

fn vm_bool_and(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType>;
fn vm_bool_or(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType>;

fn vm_bit_xor(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType>;
fn vm_bit_or(&mut self, a: Self::VmType, b: Self::VmType) -> Result<Self::VmType>;
Expand Down

0 comments on commit eda9f74

Please sign in to comment.