From 7113108f3395c266320f51c089aba3c195855268 Mon Sep 17 00:00:00 2001 From: Franco Nieddu Date: Thu, 12 Sep 2024 11:40:39 +0200 Subject: [PATCH] plonk plain working again --- co-circom/circom-mpc-vm/src/mpc/rep3.rs | 8 +- co-circom/co-plonk/src/mpc.rs | 30 ++-- co-circom/co-plonk/src/mpc/plain.rs | 28 +++- co-circom/co-plonk/src/mpc/rep3.rs | 10 +- co-circom/co-plonk/src/round2.rs | 69 +++++---- co-circom/co-plonk/src/round3.rs | 170 ++++++++++++---------- co-circom/co-plonk/src/round4.rs | 16 +- co-circom/co-plonk/src/round5.rs | 124 +++++++++------- mpc-core/src/protocols/rep3/arithmetic.rs | 4 +- mpc-core/src/protocols/rep3/detail.rs | 1 - tests/src/rep3_network.rs | 13 +- 11 files changed, 281 insertions(+), 192 deletions(-) diff --git a/co-circom/circom-mpc-vm/src/mpc/rep3.rs b/co-circom/circom-mpc-vm/src/mpc/rep3.rs index f8e7dba8..ec75abc3 100644 --- a/co-circom/circom-mpc-vm/src/mpc/rep3.rs +++ b/co-circom/circom-mpc-vm/src/mpc/rep3.rs @@ -103,12 +103,16 @@ impl VmCircomWitnessExtension } (Rep3VmType::Public(b), Rep3VmType::Binary(a)) | (Rep3VmType::Binary(a), Rep3VmType::Public(b)) => { - let a = futures::executor::block_on(conversion::b2a(&a, &mut self.io_context))?; + let a = self + .runtime + .block_on(conversion::b2a(&a, &mut self.io_context))?; Ok(arithmetic::add_public(a, b, self.io_context.id).into()) } (Rep3VmType::Arithmetic(a), Rep3VmType::Binary(b)) | (Rep3VmType::Binary(b), Rep3VmType::Arithmetic(a)) => { - let b = futures::executor::block_on(conversion::b2a(&b, &mut self.io_context))?; + let b = self + .runtime + .block_on(conversion::b2a(&b, &mut self.io_context))?; Ok(arithmetic::add(a, b).into()) } (Rep3VmType::Binary(a), Rep3VmType::Binary(b)) => { diff --git a/co-circom/co-plonk/src/mpc.rs b/co-circom/co-plonk/src/mpc.rs index 535550b9..0eaddd97 100644 --- a/co-circom/co-plonk/src/mpc.rs +++ b/co-circom/co-plonk/src/mpc.rs @@ -1,3 +1,5 @@ +use std::{future::Future, process::Output}; + use ark_ec::{pairing::Pairing, CurveGroup}; use ark_poly::EvaluationDomain; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; @@ -18,12 +20,11 @@ pub trait CircomPlonkProver { type PartyID: Send + Sync + Copy; - fn rand(&mut self) -> Self::ArithmeticShare; + fn debug_print(a: Self::ArithmeticShare); - fn get_party_id(&self) -> Self::PartyID { - self.io_context.id - } + fn rand(&mut self) -> Self::ArithmeticShare; + fn get_party_id(&self) -> Self::PartyID; /// Subtract the share b from the share a: \[c\] = \[a\] - \[b\] fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare; @@ -46,29 +47,28 @@ pub trait CircomPlonkProver { public: P::ScalarField, ) -> Self::ArithmeticShare; - async fn mul_vec( + fn mul_vec( &mut self, a: &[Self::ArithmeticShare], b: &[Self::ArithmeticShare], - ) -> IoResult>; + ) -> impl Future>>; /// Convenience method for \[a\] + \[b\] * \[c\] - async fn add_mul_vec( + fn add_mul_vec( &mut self, a: &[Self::ArithmeticShare], b: &[Self::ArithmeticShare], c: &[Self::ArithmeticShare], - ) -> IoResult>; + ) -> impl Future>>; /// Convenience method for \[a\] + \[b\] * c fn add_mul_public( &mut self, - a: &Self::ArithmeticShare, - b: &Self::ArithmeticShare, - c: &P::ScalarField, + a: Self::ArithmeticShare, + b: Self::ArithmeticShare, + c: P::ScalarField, ) -> Self::ArithmeticShare { - let tmp = self.mul_with_public(c, b); - self.add(a, &tmp) + Self::add(a, Self::mul_with_public(b, c)) } /// This function performs a multiplication directly followed by an opening. This safes one round of communication in some MPC protocols compared to calling `mul` and `open` separately. @@ -79,7 +79,7 @@ pub trait CircomPlonkProver { ) -> IoResult>; /// Reconstructs many shared values: a = Open(\[a\]). - async fn open_vec(&mut self, a: Vec) -> IoResult>; + async fn open_vec(&mut self, a: &[Self::ArithmeticShare]) -> IoResult>; /// Computes the inverse of many shared values: \[b\] = \[a\] ^ -1. Requires network communication. async fn inv_vec( @@ -106,7 +106,7 @@ pub trait CircomPlonkProver { ) -> Vec; /// Reconstructs many shared points: A = Open(\[A\]). - async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult; + async fn open_point_g1(&mut self, a: Self::PointShareG1) -> IoResult; async fn open_point_vec_g1(&mut self, a: &[Self::PointShareG1]) -> IoResult>; // WE NEED THIS ALSO FOR GROTH16 diff --git a/co-circom/co-plonk/src/mpc/plain.rs b/co-circom/co-plonk/src/mpc/plain.rs index d3fdb17f..b80aae6a 100644 --- a/co-circom/co-plonk/src/mpc/plain.rs +++ b/co-circom/co-plonk/src/mpc/plain.rs @@ -1,10 +1,12 @@ use super::IoResult; use ark_ec::pairing::Pairing; use ark_ec::scalar_mul::variable_base::VariableBaseMSM; +use ark_ff::Field; use ark_ff::UniformRand; use ark_poly::univariate::DensePolynomial; use ark_poly::Polynomial; use itertools::izip; +use num_traits::Zero; use super::CircomPlonkProver; use rand::thread_rng; @@ -21,6 +23,10 @@ impl CircomPlonkProver

for PlainPlonkDriver { //doesn't matter type PartyID = usize; + fn debug_print(a: Self::ArithmeticShare) { + println!("{a}") + } + fn rand(&mut self) -> Self::ArithmeticShare { let mut rng = thread_rng(); Self::ArithmeticShare::rand(&mut rng) @@ -66,7 +72,7 @@ impl CircomPlonkProver

for PlainPlonkDriver { a: &[Self::ArithmeticShare], b: &[Self::ArithmeticShare], ) -> IoResult> { - Ok(izip!(a, b).map(|(a, b)| *a + *b).collect()) + Ok(izip!(a, b).map(|(a, b)| *a * *b).collect()) } async fn add_mul_vec( @@ -86,15 +92,25 @@ impl CircomPlonkProver

for PlainPlonkDriver { Ok(izip!(a, b).map(|(a, b)| *a * *b).collect()) } - async fn open_vec(&mut self, a: Vec) -> IoResult> { - Ok(a) + async fn open_vec(&mut self, a: &[Self::ArithmeticShare]) -> IoResult> { + Ok(a.to_vec()) } async fn inv_vec( &mut self, a: &[Self::ArithmeticShare], ) -> IoResult> { - Ok(a.iter().map(|a| -*a).collect()) + let mut res = Vec::with_capacity(a.len()); + for a in a { + if a.is_zero() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Cannot invert zero", + )); + } + res.push(a.inverse().unwrap()); + } + Ok(res) } fn promote_to_trivial_share( @@ -118,8 +134,8 @@ impl CircomPlonkProver

for PlainPlonkDriver { domain.ifft(data) } - async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult { - Ok(*a) + async fn open_point_g1(&mut self, a: Self::PointShareG1) -> IoResult { + Ok(a) } async fn open_point_vec_g1(&mut self, a: &[Self::PointShareG1]) -> IoResult> { diff --git a/co-circom/co-plonk/src/mpc/rep3.rs b/co-circom/co-plonk/src/mpc/rep3.rs index eec88a69..0232c620 100644 --- a/co-circom/co-plonk/src/mpc/rep3.rs +++ b/co-circom/co-plonk/src/mpc/rep3.rs @@ -26,6 +26,10 @@ impl CircomPlonkProver

for Rep3PlonkDriver { type PartyID = PartyID; + fn debug_print(a: Self::ArithmeticShare) { + todo!() + } + fn rand(&mut self) -> Self::ArithmeticShare { Self::ArithmeticShare::rand(&mut self.io_context) } @@ -91,7 +95,7 @@ impl CircomPlonkProver

for Rep3PlonkDriver { rep3::arithmetic::mul_open_vec(a, b, &mut self.io_context).await } - async fn open_vec(&mut self, a: Vec) -> IoResult> { + async fn open_vec(&mut self, a: &[Self::ArithmeticShare]) -> IoResult> { rep3::arithmetic::open_vec(a, &mut self.io_context).await } @@ -123,8 +127,8 @@ impl CircomPlonkProver

for Rep3PlonkDriver { domain.ifft(&data) } - async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult { - rep3::pointshare::open_point(a, &mut self.io_context).await + async fn open_point_g1(&mut self, a: Self::PointShareG1) -> IoResult { + rep3::pointshare::open_point(&a, &mut self.io_context).await } async fn open_point_vec_g1(&mut self, a: &[Self::PointShareG1]) -> IoResult> { diff --git a/co-circom/co-plonk/src/round2.rs b/co-circom/co-plonk/src/round2.rs index 5fe36c12..7e7ee3ea 100644 --- a/co-circom/co-plonk/src/round2.rs +++ b/co-circom/co-plonk/src/round2.rs @@ -19,19 +19,19 @@ macro_rules! array_prod_mul { // Do the multiplications of inp[i] * inp[i-1] in constant rounds let len = $inp.len(); let r = (0..=len).map(|_| $driver.rand()).collect::>(); - let r_inv = futures::executor::block_on($driver.inv_many(&r))?; + let r_inv = futures::executor::block_on($driver.inv_vec(&r))?; let r_inv0 = vec![r_inv[0].clone(); len]; let mut unblind = futures::executor::block_on($driver.mul_vec(&r_inv0, &r[1..]))?; let mul = futures::executor::block_on($driver.mul_vec(&r[..len], &$inp))?; - let mut open = $driver.mul_open_many(&mul, &r_inv[1..])?; + let mut open = futures::executor::block_on($driver.mul_open_vec(&mul, &r_inv[1..]))?; for i in 1..open.len() { open[i] = open[i] * open[i - 1]; } - for (unblind, open) in unblind.iter_mut().zip(open.iter()) { - *unblind = $driver.mul_with_public(open, unblind); + for (unblind, open) in unblind.iter_mut().zip(open.into_iter()) { + *unblind = T::mul_with_public(*unblind, open); } unblind }}; @@ -117,6 +117,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { // To reduce the number of communication rounds, we implement the array_prod_mul macro according to https://www.usenix.org/system/files/sec22-ozdemir.pdf, p11 first paragraph. fn compute_z( driver: &mut T, + runtime: &mut Runtime, zkey: &ZKey

, domains: &Domains, challenges: &Round2Challenges, @@ -157,17 +158,26 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { n3.push(n3_); // denArr := (a + beta·sigma1 + gamma)(b + beta·sigma2 + gamma)(c + beta·sigma3 + gamma) - let d1_ = - driver.add_with_public(&(challenges.beta * zkey.s1_poly.evaluations[i * 4]), a); - let d1_ = driver.add_with_public(&challenges.gamma, &d1_); - - let d2_ = - driver.add_with_public(&(challenges.beta * zkey.s2_poly.evaluations[i * 4]), b); - let d2_ = driver.add_with_public(&challenges.gamma, &d2_); - - let d3_ = - driver.add_with_public(&(challenges.beta * zkey.s3_poly.evaluations[i * 4]), c); - let d3_ = driver.add_with_public(&challenges.gamma, &d3_); + let d1_ = T::add_with_public( + party_id, + *a, + challenges.beta * zkey.s1_poly.evaluations[i * 4], + ); + let d1_ = T::add_with_public(party_id, d1_, challenges.gamma); + + let d2_ = T::add_with_public( + party_id, + *b, + challenges.beta * zkey.s2_poly.evaluations[i * 4], + ); + let d2_ = T::add_with_public(party_id, d2_, challenges.gamma); + + let d3_ = T::add_with_public( + party_id, + *c, + challenges.beta * zkey.s3_poly.evaluations[i * 4], + ); + let d3_ = T::add_with_public(party_id, d3_, challenges.gamma); d1.push(d1_); d2.push(d2_); @@ -177,10 +187,10 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { } // TODO parallelize these? With a different network structure this might not be needed though - let num = futures::executor::block_on(driver.mul_vec(&n1, &n2))?; - let num = futures::executor::block_on(driver.mul_vec(&num, &n3))?; - let den = futures::executor::block_on(driver.mul_vec(&d1, &d2))?; - let den = futures::executor::block_on(driver.mul_vec(&den, &d3))?; + let num = runtime.block_on(driver.mul_vec(&n1, &n2))?; + let num = runtime.block_on(driver.mul_vec(&num, &n3))?; + let den = runtime.block_on(driver.mul_vec(&d1, &d2))?; + let den = runtime.block_on(driver.mul_vec(&den, &d3))?; // TODO parallelize these? With a different network structure this might not be needed though // Do the multiplications of num[i] * num[i-1] and den[i] * den[i-1] in constant rounds @@ -195,12 +205,12 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { buffer_z.rotate_right(1); // Required by SNARKJs/Plonk // Compute polynomial coefficients z(X) from buffer_z - let poly_z = driver.ifft(&buffer_z, &domains.domain); + let poly_z = T::ifft(&buffer_z, &domains.domain); // Compute extended evaluations of z(X) polynomial - let eval_z = driver.fft(&poly_z, &domains.extended_domain); + let eval_z = T::fft(&poly_z, &domains.extended_domain); - let poly_z = plonk_utils::blind_coefficients::(driver, &poly_z, &challenges.b[6..9]); + let poly_z = plonk_utils::blind_coefficients::(&poly_z, &challenges.b[6..9]); if poly_z.len() > zkey.domain_size + 3 { Err(PlonkProofError::PolynomialDegreeTooLarge) @@ -217,7 +227,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { pub(super) fn round2(self) -> PlonkProofResult> { let Self { mut driver, - runtime, + mut runtime, data, proof, challenges, @@ -250,7 +260,14 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { let gamma = transcript.get_challenge(); tracing::debug!("beta: {beta}, gamma: {gamma}"); let challenges = Round2Challenges::new(challenges, beta, gamma); - let z = Self::compute_z(&mut driver, zkey, &domains, &challenges, &polys)?; + let z = Self::compute_z( + &mut driver, + &mut runtime, + zkey, + &domains, + &challenges, + &polys, + )?; // STEP 2.3 - Compute permutation [z]_1 tracing::debug!("committing to poly z (MSMs)"); @@ -280,6 +297,7 @@ pub mod tests { use circom_types::plonk::ZKey; use circom_types::Witness; use co_circom_snarks::SharedWitness; + use tokio::runtime; use crate::mpc::plain::PlainPlonkDriver; use crate::round1::Round1; @@ -313,7 +331,8 @@ pub mod tests { }; let challenges = Round1Challenges::deterministic(&mut driver); - let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap(); + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + let mut round1 = Round1::init_round(driver, runtime, &zkey, witness).unwrap(); round1.challenges = challenges; let round2 = round1.round1().unwrap(); let round3 = round2.round2().unwrap(); diff --git a/co-circom/co-plonk/src/round3.rs b/co-circom/co-plonk/src/round3.rs index 6354b16e..e2e024ea 100644 --- a/co-circom/co-plonk/src/round3.rs +++ b/co-circom/co-plonk/src/round3.rs @@ -52,19 +52,19 @@ macro_rules! mul4vec { } macro_rules! mul4vec_post { - ($driver: expr, $a: expr,$b: expr,$c: expr,$d: expr,$i: expr, $z1: expr, $z2: expr, $z3: expr) => {{ + ($party_id: expr, $a: expr,$b: expr,$c: expr,$d: expr,$i: expr, $z1: expr, $z2: expr, $z3: expr) => {{ let mod_i = $i % 4; let mut rz = $a[$i].clone(); if mod_i != 0 { let b = &$b[$i]; let c = &$c[$i]; let d = &$d[$i]; - let tmp = $driver.mul_with_public(&$z1[mod_i], &b); - rz = $driver.add(&rz, &tmp); - let tmp = $driver.mul_with_public(&$z2[mod_i], &c); - rz = $driver.add(&rz, &tmp); - let tmp = $driver.mul_with_public(&$z3[mod_i], &d); - rz = $driver.add(&rz, &tmp); + let tmp = T::mul_with_public(*b, $z1[mod_i]); + rz = T::add(tmp, rz); + let tmp = T::mul_with_public(*c, $z2[mod_i]); + rz = T::add(rz, tmp); + let tmp = T::mul_with_public(*d, $z3[mod_i]); + rz = T::add(rz, tmp); } rz }}; @@ -224,14 +224,15 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round3<'a, P, T> { let mut ap = Vec::with_capacity(zkey.domain_size * 4); let mut bp = Vec::with_capacity(zkey.domain_size * 4); let mut cp = Vec::with_capacity(zkey.domain_size * 4); + let party_id = driver.get_party_id(); let pow_root_of_unity = domains.root_of_unity_pow; let pow_plus2_root_of_unity = domains.root_of_unity_pow_2; // We do not want to have any network operation in here to reduce MPC rounds. To enforce this, we have a for_each loop here (Network operations require a result) (0..zkey.domain_size * 4).for_each(|_| { - ap.push(driver.add_mul_public(&challenges.b[1], &challenges.b[0], &w)); - bp.push(driver.add_mul_public(&challenges.b[3], &challenges.b[2], &w)); - cp.push(driver.add_mul_public(&challenges.b[5], &challenges.b[4], &w)); + ap.push(driver.add_mul_public(challenges.b[1], challenges.b[0], w)); + bp.push(driver.add_mul_public(challenges.b[3], challenges.b[2], w)); + cp.push(driver.add_mul_public(challenges.b[5], challenges.b[4], w)); w *= &pow_plus2_root_of_unity; }); @@ -280,69 +281,83 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round3<'a, P, T> { let bp = &bp[i]; let w2 = w.square(); - let zp_lhs = driver.mul_with_public(&w2, &challenges.b[6]); - let zp_rhs = driver.mul_with_public(&w, &challenges.b[7]); - let zp_ = driver.add(&zp_lhs, &zp_rhs); - let zp_ = driver.add(&challenges.b[8], &zp_); + let zp_lhs = T::mul_with_public(challenges.b[6], w2); + let zp_rhs = T::mul_with_public(challenges.b[7], w); + let zp_ = T::add(zp_lhs, zp_rhs); + let zp_ = T::add(challenges.b[8], zp_); zp.push(zp_); let w_w = w * pow_root_of_unity; let w_w2 = w_w.square(); let zw = polys.z.eval[(zkey.domain_size * 4 + 4 + i) % (zkey.domain_size * 4)].clone(); - let zwp_lhs = driver.mul_with_public(&w_w2, &challenges.b[6]); - let zwp_rhs = driver.mul_with_public(&w_w, &challenges.b[7]); - let zwp_ = driver.add(&zwp_lhs, &zwp_rhs); - let zwp_ = driver.add(&challenges.b[8], &zwp_); + let zwp_lhs = T::mul_with_public(challenges.b[6], w_w2); + let zwp_rhs = T::mul_with_public(challenges.b[7], w_w); + let zwp_ = T::add(zwp_lhs, zwp_rhs); + let zwp_ = T::add(challenges.b[8], zwp_); zwp.push(zwp_); - let mut a0 = driver.add(&a_bp, &ap_b); + let mut a0 = T::add(*a_bp, *ap_b); let mod_i = i % 4; if mod_i != 0 { let z1 = z1[mod_i]; let ap_bp = ap_bp[i].clone(); - let tmp = driver.mul_with_public(&z1, &ap_bp); - a0 = driver.add(&a0, &tmp); + let tmp = T::mul_with_public(ap_bp, z1); + a0 = T::add(a0, tmp); } let (mut e1_, mut e1z_) = (a_b.to_owned(), a0.to_owned()); - e1_ = driver.mul_with_public(&qm, &e1_); - e1z_ = driver.mul_with_public(&qm, &e1z_); + e1_ = T::mul_with_public(e1_, qm); + e1z_ = T::mul_with_public(e1z_, qm); - e1_ = driver.add_mul_public(&e1_, &a, &ql); - e1z_ = driver.add_mul_public(&e1z_, &ap, &ql); + e1_ = driver.add_mul_public(e1_, *a, ql); + e1z_ = driver.add_mul_public(e1z_, *ap, ql); - e1_ = driver.add_mul_public(&e1_, &b, &qr); - e1z_ = driver.add_mul_public(&e1z_, &bp, &qr); + e1_ = driver.add_mul_public(e1_, *b, qr); + e1z_ = driver.add_mul_public(e1z_, *bp, qr); - e1_ = driver.add_mul_public(&e1_, &c, &qo); - e1z_ = driver.add_mul_public(&e1z_, &cp[i], &qo); + e1_ = driver.add_mul_public(e1_, *c, qo); + e1z_ = driver.add_mul_public(e1z_, cp[i], qo); let mut pi = T::ArithmeticShare::default(); for (j, lagrange) in zkey.lagrange.iter().enumerate() { - let l_eval = lagrange.evaluations[i]; - let a_val = polys.buffer_a[j].clone(); - let tmp = driver.mul_with_public(&l_eval, &a_val); - pi = driver.sub(&pi, &tmp); + let tmp = T::mul_with_public(polys.buffer_a[j], lagrange.evaluations[i]); + pi = T::sub(pi, tmp); } - e1_ = driver.add(&e1_, &pi); - e1_ = driver.add_with_public(&qc, &e1_); + e1_ = T::add(e1_, pi); + e1_ = T::add_with_public(party_id, e1_, qc); e1.push(e1_); e1z.push(e1z_); let betaw = challenges.beta * w; - e2a.push(driver.add_with_public(&(betaw + challenges.gamma), &a)); - e2b.push( - driver.add_with_public(&(betaw * zkey.verifying_key.k1 + challenges.gamma), &b), - ); - e2c.push( - driver.add_with_public(&(betaw * zkey.verifying_key.k2 + challenges.gamma), &c), - ); - - e2d.push(z.clone()); - e3a.push(driver.add_with_public(&(s1 * challenges.beta + challenges.gamma), &a)); - e3b.push(driver.add_with_public(&(s2 * challenges.beta + challenges.gamma), &b)); - e3c.push(driver.add_with_public(&(s3 * challenges.beta + challenges.gamma), &c)); + e2a.push(T::add_with_public(party_id, *a, betaw + challenges.gamma)); + e2b.push(T::add_with_public( + party_id, + *b, + betaw * zkey.verifying_key.k1 + challenges.gamma, + )); + e2c.push(T::add_with_public( + party_id, + *c, + betaw * zkey.verifying_key.k2 + challenges.gamma, + )); + + e2d.push(*z); + e3a.push(T::add_with_public( + party_id, + *a, + s1 * challenges.beta + challenges.gamma, + )); + e3b.push(T::add_with_public( + party_id, + *b, + s2 * challenges.beta + challenges.gamma, + )); + e3c.push(T::add_with_public( + party_id, + *c, + s3 * challenges.beta + challenges.gamma, + )); e3d.push(zw); w *= pow_plus2_root_of_unity; }); @@ -358,53 +373,53 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round3<'a, P, T> { // We do not want to have any network operation in here to reduce MPC rounds. To enforce this, we have a for_each loop here (Network operations require a result) (0..zkey.domain_size * 4).for_each(|i| { let mut e2 = e2[i].clone(); - let mut e2z = mul4vec_post!(driver, e2z_0, e2z_1, e2z_2, e2z_3, i, z1, z2, z3); + let mut e2z = mul4vec_post!(party_id, e2z_0, e2z_1, e2z_2, e2z_3, i, z1, z2, z3); let mut e3 = e3[i].clone(); - let mut e3z = mul4vec_post!(driver, e3z_0, e3z_1, e3z_2, e3z_3, i, z1, z2, z3); + let mut e3z = mul4vec_post!(party_id, e3z_0, e3z_1, e3z_2, e3z_3, i, z1, z2, z3); let z = polys.z.eval[i].clone(); let zp = zp[i].clone(); - e2 = driver.mul_with_public(&challenges.alpha, &e2); - e2z = driver.mul_with_public(&challenges.alpha, &e2z); + e2 = T::mul_with_public(e2, challenges.alpha); + e2z = T::mul_with_public(e2z, challenges.alpha); - e3 = driver.mul_with_public(&challenges.alpha, &e3); - e3z = driver.mul_with_public(&challenges.alpha, &e3z); + e3 = T::mul_with_public(e3, challenges.alpha); + e3z = T::mul_with_public(e3z, challenges.alpha); - let mut e4 = driver.add_with_public(&-P::ScalarField::one(), &z); - e4 = driver.mul_with_public(&zkey.lagrange[0].evaluations[i], &e4); - e4 = driver.mul_with_public(&challenges.alpha2, &e4); + let mut e4 = T::add_with_public(party_id, z, -P::ScalarField::one()); + e4 = T::mul_with_public(e4, zkey.lagrange[0].evaluations[i]); + e4 = T::mul_with_public(e4, challenges.alpha2); - let mut e4z = driver.mul_with_public(&zkey.lagrange[0].evaluations[i], &zp); - e4z = driver.mul_with_public(&challenges.alpha2, &e4z); + let mut e4z = T::mul_with_public(zp, zkey.lagrange[0].evaluations[i]); + e4z = T::mul_with_public(e4z, challenges.alpha2); - let mut t = driver.add(&e1[i], &e2); - t = driver.sub(&t, &e3); - t = driver.add(&t, &e4); + let mut t = T::add(e1[i], e2); + t = T::sub(t, e3); + t = T::add(t, e4); - let mut tz = driver.add(&e1z[i], &e2z); - tz = driver.sub(&tz, &e3z); - tz = driver.add(&tz, &e4z); + let mut tz = T::add(e1z[i], e2z); + tz = T::sub(tz, e3z); + tz = T::add(tz, e4z); t_vec.push(t); tz_vec.push(tz); }); - let mut coefficients_t = driver.ifft(&t_vec, &domains.extended_domain); + let mut coefficients_t = T::ifft(&t_vec, &domains.extended_domain); driver.neg_vec_in_place(&mut coefficients_t[..zkey.domain_size]); // We do not want to have any network operation in here to reduce MPC rounds. To enforce this, we have a for_each loop here (Network operations require a result) (zkey.domain_size..zkey.domain_size * 4).for_each(|i| { let a_lhs = &coefficients_t[i - zkey.domain_size]; let a_rhs = &coefficients_t[i]; - let a = driver.sub(&a_lhs, &a_rhs); + let a = T::sub(*a_lhs, *a_rhs); coefficients_t[i] = a; // Snarkjs is checking whether the poly was divisble by Zh, but we cannot do this here }); - let coefficients_tz = driver.ifft(&tz_vec, &domains.extended_domain); + let coefficients_tz = T::ifft(&tz_vec, &domains.extended_domain); let mut t_final = izip!(coefficients_t.iter(), coefficients_tz.iter()) - .map(|(lhs, rhs)| driver.add(lhs, rhs)); + .map(|(lhs, rhs)| T::add(*lhs, *rhs)); let mut t1 = Vec::with_capacity(zkey.domain_size + 1); let mut t2 = Vec::with_capacity(zkey.domain_size + 1); for _ in 0..zkey.domain_size { @@ -416,10 +431,10 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round3<'a, P, T> { let mut t3 = t_final.take(zkey.domain_size + 6).collect::>(); t1.push(challenges.b[9].to_owned()); - t2[0] = driver.sub(&t2[0], &challenges.b[9]); + t2[0] = T::sub(t2[0], challenges.b[9]); t2.push(challenges.b[10].to_owned()); - t3[0] = driver.sub(&t3[0], &challenges.b[10]); + t3[0] = T::sub(t3[0], challenges.b[10]); tracing::debug!("computing t polynomial done!"); Ok([t1.into(), t2.into(), t3.into()]) } @@ -428,6 +443,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round3<'a, P, T> { pub(super) fn round3(self) -> PlonkProofResult> { let Self { mut driver, + runtime, domains, challenges, proof, @@ -450,17 +466,19 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round3<'a, P, T> { tracing::debug!("committing to poly t (MSMs)"); // Compute [T1]_1, [T2]_1, [T3]_1 - let commit_t1 = driver.msm_public_points(&data.zkey.p_tau[..t1.len()], &t1); - let commit_t2 = driver.msm_public_points(&data.zkey.p_tau[..t2.len()], &t2); - let commit_t3 = driver.msm_public_points(&data.zkey.p_tau[..t3.len()], &t3); + let commit_t1 = T::msm_public_points_g1(&data.zkey.p_tau[..t1.len()], &t1); + let commit_t2 = T::msm_public_points_g1(&data.zkey.p_tau[..t2.len()], &t2); + let commit_t3 = T::msm_public_points_g1(&data.zkey.p_tau[..t3.len()], &t3); - let opened = driver.open_point_many(&[commit_t1, commit_t2, commit_t3])?; + let opened = + runtime.block_on(driver.open_point_vec_g1(&[commit_t1, commit_t2, commit_t3]))?; let polys = FinalPolys::new(polys, t1, t2, t3); let proof = Round3Proof::new(proof, opened[0], opened[1], opened[2]); tracing::debug!("round3 result: {proof}"); Ok(Round4 { driver, + runtime, domains, challenges, proof, @@ -479,6 +497,7 @@ pub mod tests { use circom_types::plonk::ZKey; use circom_types::Witness; use co_circom_snarks::SharedWitness; + use tokio::runtime; use crate::{ mpc::plain::PlainPlonkDriver, @@ -513,7 +532,8 @@ pub mod tests { }; let challenges = Round1Challenges::deterministic(&mut driver); - let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap(); + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + let mut round1 = Round1::init_round(driver, runtime, &zkey, witness).unwrap(); round1.challenges = challenges; let round2 = round1.round1().unwrap(); let round3 = round2.round2().unwrap(); diff --git a/co-circom/co-plonk/src/round4.rs b/co-circom/co-plonk/src/round4.rs index a076b266..42366608 100644 --- a/co-circom/co-plonk/src/round4.rs +++ b/co-circom/co-plonk/src/round4.rs @@ -103,6 +103,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round4<'a, P, T> { pub(super) fn round4(self) -> PlonkProofResult> { let Self { mut driver, + runtime, domains, challenges, proof, @@ -121,15 +122,15 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round4<'a, P, T> { let challenges = Round4Challenges::new(challenges, xi); tracing::debug!("xi: {xi}"); tracing::debug!("evaluating poly a"); - let eval_a = driver.evaluate_poly_public(&polys.a.poly, challenges.xi); + let eval_a = T::evaluate_poly_public(&polys.a.poly, challenges.xi); tracing::debug!("evaluating poly b"); - let eval_b = driver.evaluate_poly_public(&polys.b.poly, challenges.xi); + let eval_b = T::evaluate_poly_public(&polys.b.poly, challenges.xi); tracing::debug!("evaluating poly c"); - let eval_c = driver.evaluate_poly_public(&polys.c.poly, challenges.xi); + let eval_c = T::evaluate_poly_public(&polys.c.poly, challenges.xi); tracing::debug!("evaluating poly z"); - let eval_z = driver.evaluate_poly_public(&polys.z.poly, xiw); + let eval_z = T::evaluate_poly_public(&polys.z.poly, xiw); - let opened = driver.open_many(&[eval_a, eval_b, eval_c, eval_z])?; + let opened = runtime.block_on(driver.open_vec(&[eval_a, eval_b, eval_c, eval_z]))?; let eval_a = opened[0]; let eval_b = opened[1]; @@ -143,6 +144,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round4<'a, P, T> { Ok(Round5 { driver, + runtime, domains, challenges, proof, @@ -161,6 +163,7 @@ pub mod tests { use circom_types::plonk::ZKey; use circom_types::Witness; use co_circom_snarks::SharedWitness; + use tokio::runtime; use crate::{ mpc::plain::PlainPlonkDriver, @@ -185,7 +188,8 @@ pub mod tests { }; let challenges = Round1Challenges::deterministic(&mut driver); - let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap(); + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + let mut round1 = Round1::init_round(driver, runtime, &zkey, witness).unwrap(); round1.challenges = challenges; let round2 = round1.round1().unwrap(); let round3 = round2.round2().unwrap(); diff --git a/co-circom/co-plonk/src/round5.rs b/co-circom/co-plonk/src/round5.rs index 185533d0..b71d5da9 100644 --- a/co-circom/co-plonk/src/round5.rs +++ b/co-circom/co-plonk/src/round5.rs @@ -1,3 +1,5 @@ +use core::panic; + use crate::{ mpc::CircomPlonkProver, plonk_utils, @@ -82,21 +84,17 @@ where P::BaseField: CircomArkworksPrimeFieldBridge, P::ScalarField: CircomArkworksPrimeFieldBridge, { - fn div_by_zerofier( - driver: &mut T, - inout: &mut Vec, - n: usize, - beta: P::ScalarField, - ) { + fn div_by_zerofier(inout: &mut Vec, n: usize, beta: P::ScalarField) { let inv_beta = beta.inverse().expect("Highly unlikely to be zero"); let inv_beta_neg = -inv_beta; - for el in inout.iter_mut().take(n) { - *el = driver.mul_with_public(&inv_beta_neg, el); + #[allow(unused_mut)] + for mut el in inout.iter_mut().take(n) { + *el = T::mul_with_public(*el, inv_beta_neg); } for i in n..inout.len() { - let element = driver.sub(&inout[i - n], &inout[i]); - inout[i] = driver.mul_with_public(&inv_beta, &element); + let element = T::sub(inout[i - n], inout[i]); + inout[i] = T::mul_with_public(element, inv_beta); } // We cannot check whether the polyonmial is divisible by the zerofier, but we resize accordingly inout.resize(inout.len() - n, T::ArithmeticShare::default()); @@ -107,7 +105,8 @@ where inout.resize(add_poly.len(), P::ScalarField::zero()); } - for (inout, add) in inout.iter_mut().zip(add_poly.iter()) { + #[allow(unused_mut)] + for (mut inout, add) in inout.iter_mut().zip(add_poly.iter()) { *inout += *add; } } @@ -121,14 +120,15 @@ where inout.resize(add_poly.len(), P::ScalarField::zero()); } - for (inout, add) in inout.iter_mut().zip(add_poly.iter()) { + #[allow(unused_mut)] + for (mut inout, add) in inout.iter_mut().zip(add_poly.iter()) { *inout += *add * factor; } } // The linearisation polynomial R(X) (see https://eprint.iacr.org/2019/953.pdf) fn compute_r( - driver: &mut T, + party_id: T::PartyID, domains: &Domains, proof: &Round4Proof

, challenges: &Round5Challenges

, @@ -164,7 +164,8 @@ where let e24 = e2 + e4; let mut poly_r = zkey.qm_poly.coeffs.clone(); - for coeff in poly_r.iter_mut() { + #[allow(unused_mut)] + for mut coeff in poly_r.iter_mut() { *coeff *= coef_ab; } Self::add_factor_poly(&mut poly_r.coeffs, &zkey.ql_poly.coeffs, proof.eval_a); @@ -181,47 +182,54 @@ where let mut poly_r_shared = vec![T::ArithmeticShare::default(); len]; - for (inout, add) in poly_r_shared + #[allow(unused_mut)] + for (mut inout, add) in poly_r_shared .iter_mut() .zip(polys.z.poly.clone().into_iter()) { - *inout = driver.mul_with_public(&e24, &add) + *inout = T::mul_with_public(add, e24); } - for (inout, add) in poly_r_shared.iter_mut().zip(poly_r.iter()) { - *inout = driver.add_with_public(add, inout); + #[allow(unused_mut)] + for (mut inout, add) in poly_r_shared.iter_mut().zip(poly_r.iter()) { + *inout = T::add_with_public(party_id, *inout, *add); } let mut tmp_poly = vec![T::ArithmeticShare::default(); len]; let xin2 = xin.square(); - for (inout, add) in tmp_poly.iter_mut().zip(polys.t3.clone().into_iter()) { - *inout = driver.mul_with_public(&xin2, &add); + #[allow(unused_mut)] + for (mut inout, add) in tmp_poly.iter_mut().zip(polys.t3.clone().into_iter()) { + *inout = T::mul_with_public(add, xin2); } - for (inout, add) in tmp_poly.iter_mut().zip(polys.t2.clone().into_iter()) { - let tmp = driver.mul_with_public(&xin, &add); - *inout = driver.add(&tmp, inout); + #[allow(unused_mut)] + for (mut inout, add) in tmp_poly.iter_mut().zip(polys.t2.clone().into_iter()) { + let tmp = T::mul_with_public(add, xin); + *inout = T::add(*inout, tmp); } - for (inout, add) in tmp_poly.iter_mut().zip(polys.t1.clone().into_iter()) { - *inout = driver.add(inout, &add); + #[allow(unused_mut)] + for (mut inout, add) in tmp_poly.iter_mut().zip(polys.t1.clone().into_iter()) { + *inout = T::add(*inout, add); } - for inout in tmp_poly.iter_mut() { - *inout = driver.mul_with_public(&zh, inout); + #[allow(unused_mut)] + for mut inout in tmp_poly.iter_mut() { + *inout = T::mul_with_public(*inout, zh); } - for (inout, sub) in poly_r_shared.iter_mut().zip(tmp_poly.iter()) { - *inout = driver.sub(inout, sub); + #[allow(unused_mut)] + for (mut inout, sub) in poly_r_shared.iter_mut().zip(tmp_poly.iter()) { + *inout = T::sub(*inout, *sub); } let r0 = eval_pi - (e3 * (proof.eval_c + challenges.gamma)) - e4; - poly_r_shared[0] = driver.add_with_public(&r0, &poly_r_shared[0]); + poly_r_shared[0] = T::add_with_public(party_id, poly_r_shared[0], r0); tracing::debug!("computing r polynomial done!"); - poly_r_shared.into() + poly_r_shared } // The opening proof polynomial W_xi(X) (see https://eprint.iacr.org/2019/953.pdf) fn compute_wxi( - driver: &mut T, + party_id: T::PartyID, proof: &Round4Proof

, challenges: &Round5Challenges

, data: &PlonkData, @@ -234,43 +242,43 @@ where let mut res = vec![T::ArithmeticShare::default(); data.zkey.domain_size + 6]; // R - for (inout, add) in res.iter_mut().zip(poly_r.clone().into_iter()) { + for (inout, add) in res.iter_mut().zip(poly_r.into_iter()) { *inout = add.clone(); } // A for (inout, add) in res.iter_mut().zip(polys.a.poly.clone().into_iter()) { - let tmp = driver.mul_with_public(&challenges.v[0], &add); - *inout = driver.add(&tmp, inout); + let tmp = T::mul_with_public(add, challenges.v[0]); + *inout = T::add(tmp, *inout); } // B for (inout, add) in res.iter_mut().zip(polys.b.poly.clone().into_iter()) { - let tmp = driver.mul_with_public(&challenges.v[1], &add); - *inout = driver.add(&tmp, inout); + let tmp = T::mul_with_public(add, challenges.v[1]); + *inout = T::add(tmp, *inout); } // C for (inout, add) in res.iter_mut().zip(polys.c.poly.clone().into_iter()) { - let tmp = driver.mul_with_public(&challenges.v[2], &add); - *inout = driver.add(&tmp, inout); + let tmp = T::mul_with_public(add, challenges.v[2]); + *inout = T::add(tmp, *inout); } // Sigma1 for (inout, add) in res.iter_mut().zip(s1_poly_coeffs.iter()) { - *inout = driver.add_with_public(&(challenges.v[3] * add), inout); + *inout = T::add_with_public(party_id, *inout, challenges.v[3] * add); } // Sigma2 for (inout, add) in res.iter_mut().zip(s2_poly_coeffs.iter()) { - *inout = driver.add_with_public(&(challenges.v[4] * add), inout); + *inout = T::add_with_public(party_id, *inout, challenges.v[4] * add); } - res[0] = driver.add_with_public(&-(challenges.v[0] * proof.eval_a), &res[0]); - res[0] = driver.add_with_public(&-(challenges.v[1] * proof.eval_b), &res[0]); - res[0] = driver.add_with_public(&-(challenges.v[2] * proof.eval_c), &res[0]); - res[0] = driver.add_with_public(&-(challenges.v[3] * proof.eval_s1), &res[0]); - res[0] = driver.add_with_public(&-(challenges.v[4] * proof.eval_s2), &res[0]); + res[0] = T::add_with_public(party_id, res[0], -challenges.v[0] * proof.eval_a); + res[0] = T::add_with_public(party_id, res[0], -challenges.v[1] * proof.eval_b); + res[0] = T::add_with_public(party_id, res[0], -challenges.v[2] * proof.eval_c); + res[0] = T::add_with_public(party_id, res[0], -challenges.v[3] * proof.eval_s1); + res[0] = T::add_with_public(party_id, res[0], -challenges.v[4] * proof.eval_s2); - Self::div_by_zerofier(driver, &mut res, 1, challenges.xi); + Self::div_by_zerofier(&mut res, 1, challenges.xi); tracing::debug!("computing wxi polynomial done!"); - res.into() + res } // The opening proof polynomial W_xiw(X) (see https://eprint.iacr.org/2019/953.pdf) @@ -285,8 +293,8 @@ where let xiw = challenges.xi * domains.root_of_unity_pow; let mut res = polys.z.poly.clone().into_iter().collect::>(); - res[0] = driver.add_with_public(&-proof.eval_zw, &res[0]); - Self::div_by_zerofier(driver, &mut res, 1, xiw); + res[0] = T::add_with_public(driver.get_party_id(), res[0], -proof.eval_zw); + Self::div_by_zerofier(&mut res, 1, xiw); tracing::debug!("computing wxiw polynomial done!"); res.into() @@ -296,6 +304,7 @@ where pub(super) fn round5(self) -> PlonkProofResult> { let Self { mut driver, + runtime, domains, challenges, proof, @@ -324,21 +333,22 @@ where tracing::debug!("v[3]: {}", v[3]); tracing::debug!("v[4]: {}", v[4]); let challenges = Round5Challenges::new(challenges, v); + let party_id = driver.get_party_id(); // STEP 5.2 Compute linearisation polynomial r(X) - let r = Self::compute_r(&mut driver, &domains, &proof, &challenges, &data, &polys); + let r = Self::compute_r(party_id, &domains, &proof, &challenges, &data, &polys); //STEP 5.3 Compute opening proof polynomial Wxi(X) - let wxi = Self::compute_wxi(&mut driver, &proof, &challenges, &data, &polys, &r); + let wxi = Self::compute_wxi(party_id, &proof, &challenges, &data, &polys, &r); //STEP 5.4 Compute opening proof polynomial Wxiw(X) let wxiw = Self::compute_wxiw(&mut driver, &domains, &proof, &challenges, &polys); // Fifth output of the prover is ([Wxi]_1, [Wxiw]_1) let p_tau = &data.zkey.p_tau; - let commit_wxi = driver.msm_public_points(&p_tau[..wxi.len()], &wxi); - let commit_wxiw = driver.msm_public_points(&p_tau[..wxiw.len()], &wxiw); + let commit_wxi = T::msm_public_points_g1(&p_tau[..wxi.len()], &wxi); + let commit_wxiw = T::msm_public_points_g1(&p_tau[..wxiw.len()], &wxiw); - let opened = driver.open_point_many(&[commit_wxi, commit_wxiw])?; + let opened = runtime.block_on(driver.open_point_vec_g1(&[commit_wxi, commit_wxiw]))?; let commit_wxi: P::G1 = opened[0]; let commit_wxiw: P::G1 = opened[1]; @@ -360,6 +370,7 @@ pub mod tests { use circom_types::plonk::ZKey; use circom_types::Witness; use co_circom_snarks::SharedWitness; + use tokio::runtime; use crate::{ mpc::plain::PlainPlonkDriver, @@ -393,7 +404,8 @@ pub mod tests { }; let challenges = Round1Challenges::deterministic(&mut driver); - let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap(); + let runtime = runtime::Builder::new_current_thread().build().unwrap(); + let mut round1 = Round1::init_round(driver, runtime, &zkey, witness).unwrap(); round1.challenges = challenges; let round2 = round1.round1().unwrap(); let round3 = round2.round2().unwrap(); diff --git a/mpc-core/src/protocols/rep3/arithmetic.rs b/mpc-core/src/protocols/rep3/arithmetic.rs index dbd8be89..45f10b31 100644 --- a/mpc-core/src/protocols/rep3/arithmetic.rs +++ b/mpc-core/src/protocols/rep3/arithmetic.rs @@ -1,7 +1,7 @@ use ark_ff::PrimeField; use itertools::{izip, Itertools}; use num_bigint::BigUint; -use num_traits::{PrimInt, Zero}; +use num_traits::Zero; use types::Rep3PrimeFieldShare; use crate::protocols::rep3::{detail, id::PartyID, network::Rep3Network}; @@ -210,7 +210,7 @@ pub async fn open( /// Performs the opening of a shared value and returns the equivalent public value. pub async fn open_vec( - a: Vec>, + a: &[FieldShare], io_context: &mut IoContext, ) -> IoResult> { // TODO think about something better... it is not so bad diff --git a/mpc-core/src/protocols/rep3/detail.rs b/mpc-core/src/protocols/rep3/detail.rs index 07421dd3..df5360e3 100644 --- a/mpc-core/src/protocols/rep3/detail.rs +++ b/mpc-core/src/protocols/rep3/detail.rs @@ -5,7 +5,6 @@ use ark_ff::PrimeField; use ark_ff::Zero; use num_bigint::BigUint; -use crate::protocols::rep3::id::PartyID; use crate::protocols::rep3::network::Rep3Network; use super::binary; diff --git a/tests/src/rep3_network.rs b/tests/src/rep3_network.rs index f26e0094..fa138e8c 100644 --- a/tests/src/rep3_network.rs +++ b/tests/src/rep3_network.rs @@ -125,7 +125,18 @@ impl Rep3Network for PartyTestNetwork { if next.len() != 1 || prev.len() != 1 { panic!("got more than one from next or prev"); } - Ok((next.pop().unwrap(), prev.pop().unwrap())) + Ok((prev.pop().unwrap(), next.pop().unwrap())) + } + + async fn broadcast_many( + &mut self, + data: &[F], + ) -> std::io::Result<(Vec, Vec)> { + self.send_many(self.id.next_id(), &data).await?; + self.send_many(self.id.prev_id(), &data).await?; + let prev = self.recv_many(self.id.prev_id()).await?; + let next = self.recv_many(self.id.next_id()).await?; + Ok((prev, next)) } async fn send_many(