diff --git a/.gitignore b/.gitignore index fb04d373..5d3170ae 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ Cargo.lock .vscode co-noir/co-acvm/src/bin/ +co-circom/co-plonk/src/bin/ groth16_tester.rs diff --git a/co-circom/co-plonk/Cargo.toml b/co-circom/co-plonk/Cargo.toml index ae0eb1c2..7fc06e26 100644 --- a/co-circom/co-plonk/Cargo.toml +++ b/co-circom/co-plonk/Cargo.toml @@ -25,10 +25,15 @@ mpc-core = { workspace = true } mpc-net = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } +rayon = { workspace = true } sha3 = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } tokio = { workspace = true } +# DELETE ME +tracing-subscriber.workspace = true +serde_json = { workspace = true } +ark-bn254 = { workspace = true } [dev-dependencies] ark-bls12-381 = { workspace = true } diff --git a/co-circom/co-plonk/src/lib.rs b/co-circom/co-plonk/src/lib.rs index 8e2db90d..93680af2 100644 --- a/co-circom/co-plonk/src/lib.rs +++ b/co-circom/co-plonk/src/lib.rs @@ -79,7 +79,18 @@ where zkey: &ZKey

, witness: SharedWitness, ) -> PlonkProofResult> { - tracing::debug!("starting PLONK prove.."); + tracing::debug!("starting PLONK prove!"); + tracing::debug!( + "we have {} constraints and {} addition constraints", + zkey.n_constraints, + zkey.n_additions + ); + tracing::debug!("the domain size is {}", zkey.domain_size); + tracing::debug!( + "we have {} n_vars and {} public inputs", + zkey.n_vars, + zkey.n_public + ); let state = Round1::init_round(self.driver, self.runtime, zkey, witness)?; tracing::debug!("init round done.."); let state = state.round1()?; diff --git a/co-circom/co-plonk/src/mpc.rs b/co-circom/co-plonk/src/mpc.rs index 0eaddd97..69e6d07a 100644 --- a/co-circom/co-plonk/src/mpc.rs +++ b/co-circom/co-plonk/src/mpc.rs @@ -25,6 +25,9 @@ pub trait CircomPlonkProver { fn rand(&mut self) -> Self::ArithmeticShare; fn get_party_id(&self) -> Self::PartyID; + + fn fork(&mut self) -> Self; + /// Subtract the share b from the share a: \[c\] = \[a\] - \[b\] fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare; @@ -53,6 +56,13 @@ pub trait CircomPlonkProver { b: &[Self::ArithmeticShare], ) -> impl Future>>; + fn mul_vecs( + &mut self, + a: &[Self::ArithmeticShare], + b: &[Self::ArithmeticShare], + c: &[Self::ArithmeticShare], + ) -> impl Future>>; + /// Convenience method for \[a\] + \[b\] * \[c\] fn add_mul_vec( &mut self, diff --git a/co-circom/co-plonk/src/mpc/plain.rs b/co-circom/co-plonk/src/mpc/plain.rs index b80aae6a..3399944f 100644 --- a/co-circom/co-plonk/src/mpc/plain.rs +++ b/co-circom/co-plonk/src/mpc/plain.rs @@ -37,6 +37,10 @@ impl CircomPlonkProver

for PlainPlonkDriver { 0 } + fn fork(&mut self) -> Self { + PlainPlonkDriver + } + fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare { a + b } @@ -75,6 +79,15 @@ impl CircomPlonkProver

for PlainPlonkDriver { Ok(izip!(a, b).map(|(a, b)| *a * *b).collect()) } + async fn mul_vecs( + &mut self, + a: &[Self::ArithmeticShare], + b: &[Self::ArithmeticShare], + c: &[Self::ArithmeticShare], + ) -> IoResult> { + Ok(izip!(a, b, c).map(|(a, b, c)| *a * *b * *c).collect()) + } + async fn add_mul_vec( &mut self, a: &[Self::ArithmeticShare], diff --git a/co-circom/co-plonk/src/mpc/rep3.rs b/co-circom/co-plonk/src/mpc/rep3.rs index 0232c620..e17f2dc6 100644 --- a/co-circom/co-plonk/src/mpc/rep3.rs +++ b/co-circom/co-plonk/src/mpc/rep3.rs @@ -26,7 +26,7 @@ impl CircomPlonkProver

for Rep3PlonkDriver { type PartyID = PartyID; - fn debug_print(a: Self::ArithmeticShare) { + fn debug_print(_: Self::ArithmeticShare) { todo!() } @@ -38,6 +38,10 @@ impl CircomPlonkProver

for Rep3PlonkDriver { self.io_context.id } + fn fork(&mut self) -> Self { + todo!() + } + fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare { rep3::arithmetic::add(a, b) } @@ -76,6 +80,15 @@ impl CircomPlonkProver

for Rep3PlonkDriver { rep3::arithmetic::mul_vec(lhs, rhs, &mut self.io_context).await } + async fn mul_vecs( + &mut self, + _a: &[Self::ArithmeticShare], + _b: &[Self::ArithmeticShare], + _c: &[Self::ArithmeticShare], + ) -> IoResult> { + todo!(); + } + async fn add_mul_vec( &mut self, a: &[Self::ArithmeticShare], diff --git a/co-circom/co-plonk/src/round1.rs b/co-circom/co-plonk/src/round1.rs index 9d6a3669..25a86fb1 100644 --- a/co-circom/co-plonk/src/round1.rs +++ b/co-circom/co-plonk/src/round1.rs @@ -3,6 +3,7 @@ use ark_ec::CurveGroup; use circom_types::plonk::ZKey; use co_circom_snarks::SharedWitness; use tokio::runtime::Runtime; +use tracing::instrument; use crate::{ mpc::CircomPlonkProver, @@ -92,67 +93,89 @@ impl> Round1Challenges { // Round 1 of https://eprint.iacr.org/2019/953.pdf (page 28) impl<'a, P: Pairing, T: CircomPlonkProver

> Round1<'a, P, T> { - // Essentially the fft of the trace columns - fn compute_wire_polynomials( - driver: &mut T, + fn compute_single_wire_poly( + party_id: T::PartyID, + witness: &PlonkWitness, domains: &Domains, - challenges: &Round1Challenges, + blind_factors: &[T::ArithmeticShare], zkey: &ZKey

, - witness: &PlonkWitness, - ) -> PlonkProofResult> { - let party_id = driver.get_party_id(); - tracing::debug!("computing wire polynomials..."); - let num_constraints = zkey.n_constraints; - - let mut buffer_a = Vec::with_capacity(zkey.domain_size); - let mut buffer_b = Vec::with_capacity(zkey.domain_size); - let mut buffer_c = Vec::with_capacity(zkey.domain_size); - - for i in 0..num_constraints { - buffer_a.push(plonk_utils::get_witness( - party_id, - witness, - zkey, - zkey.map_a[i], - )?); - buffer_b.push(plonk_utils::get_witness( - party_id, - witness, - zkey, - zkey.map_b[i], - )?); - buffer_c.push(plonk_utils::get_witness( - party_id, - witness, - zkey, - zkey.map_c[i], - )?); + map: &[usize], + ) -> PlonkProofResult<(Vec, PolyEval)> { + let mut buffer = Vec::with_capacity(zkey.n_constraints); + for i in 0..zkey.n_constraints { + match plonk_utils::get_witness(party_id, witness, zkey, map[i]) { + Ok(witness) => buffer.push(witness), + Err(err) => return Err(err), + } } - buffer_a.resize(zkey.domain_size, T::ArithmeticShare::default()); - buffer_b.resize(zkey.domain_size, T::ArithmeticShare::default()); - buffer_c.resize(zkey.domain_size, T::ArithmeticShare::default()); - - //TODO MULTITHREAD ME - tracing::debug!("iffts for buffers.."); + buffer.resize(zkey.domain_size, T::ArithmeticShare::default()); // Compute the coefficients of the wire polynomials a(X), b(X) and c(X) from A,B & C buffers - let poly_a = T::ifft(&buffer_a, &domains.domain); - let poly_b = T::ifft(&buffer_b, &domains.domain); - let poly_c = T::ifft(&buffer_c, &domains.domain); + let poly = T::ifft(&buffer, &domains.domain); tracing::debug!("ffts for evals.."); // Compute extended evaluations of a(X), b(X) and c(X) polynomials - let eval_a = T::fft(&poly_a, &domains.extended_domain); - let eval_b = T::fft(&poly_b, &domains.extended_domain); - let eval_c = T::fft(&poly_c, &domains.extended_domain); + let eval = T::fft(&poly, &domains.extended_domain); tracing::debug!("blinding 
coefficients"); - let poly_a = plonk_utils::blind_coefficients::(&poly_a, &challenges.b[..2]); - let poly_b = plonk_utils::blind_coefficients::(&poly_b, &challenges.b[2..4]); - let poly_c = plonk_utils::blind_coefficients::(&poly_c, &challenges.b[4..6]); + let poly = plonk_utils::blind_coefficients::(&poly, blind_factors); + Ok((buffer, PolyEval { poly, eval })) + } - if poly_a.len() > zkey.domain_size + 2 - || poly_b.len() > zkey.domain_size + 2 - || poly_c.len() > zkey.domain_size + 2 + // Essentially the fft of the trace columns + #[instrument(level = "debug", name = "compute wire polys", skip_all)] + fn compute_wire_polynomials( + driver: &mut T, + domains: &Domains, + challenges: &Round1Challenges, + zkey: &ZKey

, + witness: &PlonkWitness, + ) -> PlonkProofResult> { + let party_id = driver.get_party_id(); + + let mut wire_a = None; + let mut wire_b = None; + let mut wire_c = None; + + rayon::scope(|s| { + s.spawn(|_| { + wire_a = Some(Self::compute_single_wire_poly( + party_id, + witness, + domains, + &challenges.b[..2], + zkey, + &zkey.map_a, + )) + }); + s.spawn(|_| { + wire_b = Some(Self::compute_single_wire_poly( + party_id, + witness, + domains, + &challenges.b[2..4], + zkey, + &zkey.map_b, + )) + }); + s.spawn(|_| { + wire_c = Some(Self::compute_single_wire_poly( + party_id, + witness, + domains, + &challenges.b[4..6], + zkey, + &zkey.map_c, + )) + }); + }); + // the rayon scope has finished, so every wire result is Some + let (buffer_a, poly_a) = wire_a.unwrap()?; + let (buffer_b, poly_b) = wire_b.unwrap()?; + let (buffer_c, poly_c) = wire_c.unwrap()?; + + if poly_a.poly.len() > zkey.domain_size + 2 + || poly_b.poly.len() > zkey.domain_size + 2 + || poly_c.poly.len() > zkey.domain_size + 2 { return Err(PlonkProofError::PolynomialDegreeTooLarge); } @@ -161,31 +184,25 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round1<'a, P, T> { buffer_a, buffer_b, buffer_c, - a: PolyEval { - poly: poly_a.into(), - eval: eval_a, - }, - b: PolyEval { - poly: poly_b.into(), - eval: eval_b, - }, - c: PolyEval { - poly: poly_c.into(), - eval: eval_c, - }, + a: poly_a, + b: poly_b, + c: poly_c, }) } // Calculate the witnesses for the additions, since they are not part of the SharedWitness + #[instrument(level = "debug", name = "calculate additions", skip_all)] fn calculate_additions( driver: &mut T, witness: SharedWitness, zkey: &ZKey

, ) -> PlonkProofResult> { - tracing::debug!("calculating addition {} constraints...", zkey.n_additions); let party_id = driver.get_party_id(); let mut witness = PlonkWitness::new(witness, zkey.n_additions); - + // This is hard to multithread as we have to add the results + // to the vec as they are needed for the later steps. + // We leave it like that as it does not take too much time (<1ms for poseidon). + // Keep an eye on the span duration, maybe we have to come back to that later. for addition in zkey.additions.iter() { let witness1 = plonk_utils::get_witness( party_id, @@ -205,10 +222,10 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round1<'a, P, T> { let result = T::add(f1, f2); witness.addition_witness.push(result); } - tracing::debug!("additions done!"); Ok(witness) } + #[instrument(level = "debug", name = "Plonk - Round Init", skip_all)] pub(super) fn init_round( mut driver: T, runtime: Runtime, @@ -216,12 +233,13 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round1<'a, P, T> { private_witness: SharedWitness, ) -> PlonkProofResult { let plonk_witness = Self::calculate_additions(&mut driver, private_witness, zkey)?; - + let challenges = Round1Challenges::random(&mut driver)?; + let domains = Domains::new(zkey.domain_size)?; Ok(Self { - challenges: Round1Challenges::random(&mut driver)?, + challenges, driver, runtime, - domains: Domains::new(zkey.domain_size)?, + domains, data: PlonkDataRound1 { witness: plonk_witness, zkey, @@ -229,6 +247,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round1<'a, P, T> { }) } + #[instrument(level = "debug", name = "Plonk - Round 1", skip_all)] // Round 1 of https://eprint.iacr.org/2019/953.pdf (page 28) pub(super) fn round1(self) -> PlonkProofResult> { let Self { @@ -246,14 +265,35 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round1<'a, P, T> { let polys = Self::compute_wire_polynomials(&mut driver, &domains, &challenges, zkey, witness)?; - tracing::debug!("committing to polys (MSMs)"); + let mut commit_a = None; + let mut commit_b = None; + let mut commit_c = None; + let commit_span = tracing::debug_span!("committing to polys (MSMs)").entered(); // STEP 1.3 - Compute [a]_1, [b]_1, [c]_1 - let commit_a = T::msm_public_points_g1(&p_tau[..polys.a.poly.len()], &polys.a.poly); - let commit_b = T::msm_public_points_g1(&p_tau[..polys.b.poly.len()], &polys.b.poly); - let commit_c = T::msm_public_points_g1(&p_tau[..polys.c.poly.len()], &polys.c.poly); - + rayon::scope(|s| { + s.spawn(|_| { + let result = T::msm_public_points_g1(&p_tau[..polys.a.poly.len()], &polys.a.poly); + commit_a = Some(result); + }); + s.spawn(|_| { + let result = T::msm_public_points_g1(&p_tau[..polys.b.poly.len()], &polys.b.poly); + commit_b = Some(result); + }); + s.spawn(|_| { + let result = T::msm_public_points_g1(&p_tau[..polys.c.poly.len()], &polys.c.poly); + commit_c = Some(result); + }); + }); + // the rayon scope has finished, so every commitment is Some + let commit_a = commit_a.unwrap(); + let commit_b = commit_b.unwrap(); + let commit_c = commit_c.unwrap(); + + // network round + commit_span.exit(); + let opening_span = tracing::debug_span!("opening commits").entered(); let opened = runtime.block_on(driver.open_point_vec_g1(&[commit_a, commit_b, commit_c]))?; - + opening_span.exit(); let proof = Round1Proof::

{ commit_a: opened[0], commit_b: opened[1], diff --git a/co-circom/co-plonk/src/round2.rs b/co-circom/co-plonk/src/round2.rs index 7e7ee3ea..ef64cea3 100644 --- a/co-circom/co-plonk/src/round2.rs +++ b/co-circom/co-plonk/src/round2.rs @@ -9,8 +9,10 @@ use crate::{ use ark_ec::pairing::Pairing; use ark_ec::CurveGroup; use circom_types::plonk::ZKey; +use futures::executor::block_on; use num_traits::One; use tokio::runtime::Runtime; +use tracing::instrument; // To reduce the number of communication rounds, we implement the array_prod_mul macro according to https://www.usenix.org/system/files/sec22-ozdemir.pdf, p11 first paragraph. // TODO parallelize these? With a different network structure this might not be needed though @@ -113,8 +115,42 @@ impl> Round2Polys { // Round 2 of https://eprint.iacr.org/2019/953.pdf (page 28) impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { + async fn array_prod_mul( + driver: &mut T, + inv: bool, + arr1: &[T::ArithmeticShare], + arr2: &[T::ArithmeticShare], + arr3: &[T::ArithmeticShare], + ) -> PlonkProofResult> { + let arr = driver.mul_vecs(arr1, arr2, arr3).await?; + // Do the multiplications of inp[i] * inp[i-1] in constant rounds + let len = arr.len(); + let r = (0..=len).map(|_| driver.rand()).collect::>(); + let r_inv = driver.inv_vec(&r).await?; + let r_inv0 = vec![r_inv[0].clone(); len]; + let mut unblind = driver.mul_vec(&r_inv0, &r[1..]).await?; + + let mul = driver.mul_vec(&r[..len], &arr).await?; + let mut open = driver.mul_open_vec(&mul, &r_inv[1..]).await?; + + for i in 1..open.len() { + open[i] = open[i] * open[i - 1]; + } + + #[allow(unused_mut)] + for (mut unblind, open) in unblind.iter_mut().zip(open.into_iter()) { + *unblind = T::mul_with_public(*unblind, open); + } + if inv { + Ok(driver.inv_vec(&unblind).await?) + } else { + Ok(unblind) + } + } + // Computes the permutation polynomial z(X) (see https://eprint.iacr.org/2019/953.pdf) // To reduce the number of communication rounds, we implement the array_prod_mul macro according to https://www.usenix.org/system/files/sec22-ozdemir.pdf, p11 first paragraph. + #[instrument(level = "info", name = "compute z", skip_all)] fn compute_z( driver: &mut T, runtime: &mut Runtime, @@ -123,9 +159,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { challenges: &Round2Challenges, polys: &Round1Polys, ) -> PlonkProofResult> { - tracing::debug!("computing z polynomial..."); let pow_root_of_unity = domains.root_of_unity_pow; - let mut w = P::ScalarField::one(); let mut n1 = Vec::with_capacity(zkey.domain_size); let mut n2 = Vec::with_capacity(zkey.domain_size); let mut n3 = Vec::with_capacity(zkey.domain_size); @@ -133,6 +167,10 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { let mut d2 = Vec::with_capacity(zkey.domain_size); let mut d3 = Vec::with_capacity(zkey.domain_size); let party_id = driver.get_party_id(); + let mut w = P::ScalarField::one(); + // TODO: multithread me - this is not so easy as other + // parts as we go through the roots of unity but it is doable + let num_den_span = tracing::info_span!("compute num/den").entered(); for i in 0..zkey.domain_size { let a = &polys.buffer_a[i]; let b = &polys.buffer_b[i]; @@ -186,23 +224,26 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { w *= &pow_root_of_unity; } - // TODO parallelize these? With a different network structure this might not be needed though - let num = runtime.block_on(driver.mul_vec(&n1, &n2))?; - let num = runtime.block_on(driver.mul_vec(&num, &n3))?; - let den = runtime.block_on(driver.mul_vec(&d1, &d2))?; - let den = runtime.block_on(driver.mul_vec(&den, &d3))?; - - // TODO parallelize these? With a different network structure this might not be needed though - // Do the multiplications of num[i] * num[i-1] and den[i] * den[i-1] in constant rounds - let num = array_prod_mul!(driver, num); - let den = array_prod_mul!(driver, den); + num_den_span.exit(); + let batched_mul_span = tracing::info_span!("buffer z network round").entered(); + let mut forked0 = driver.fork(); + let mut forked1 = driver.fork(); + // TODO: This is super bad atm. There is potentially some heavy + // work involved from the multiplications, but also a lot of networking. + // Maybe we need a better mul implementation for that! + let (num, den) = runtime.block_on(async { + tokio::join!( + Self::array_prod_mul(&mut forked0, false, &n1, &n2, &n3), + Self::array_prod_mul(&mut forked1, true, &d1, &d2, &d3), + ) + }); + let num = num?; + let den = den?; - // Compute the inverse of denArr to compute in the next command the - // division numArr/denArr by multiplying num · 1/denArr - let den = futures::executor::block_on(driver.inv_vec(&den))?; - let mut buffer_z = futures::executor::block_on(driver.mul_vec(&num, &den))?; + let mut buffer_z = runtime.block_on(driver.mul_vec(&num, &den))?; buffer_z.rotate_right(1); // Required by SNARKJs/Plonk + batched_mul_span.exit(); // Compute polynomial coefficients z(X) from buffer_z let poly_z = T::ifft(&buffer_z, &domains.domain); @@ -224,6 +265,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { } // Round 2 of https://eprint.iacr.org/2019/953.pdf (page 28) + #[instrument(level = "info", name = "Plonk - Round 2", skip_all)] pub(super) fn round2(self) -> PlonkProofResult> { let Self { mut driver, @@ -275,7 +317,6 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round2<'a, P, T> { let commit_z = runtime.block_on(driver.open_point_g1(commit_z))?; let proof = Round2Proof::new(proof, commit_z); tracing::debug!("round2 result: {proof}"); - Ok(Round3 { driver, runtime, diff --git a/co-circom/co-plonk/src/round3.rs b/co-circom/co-plonk/src/round3.rs index e2e024ea..5b2ffba7 100644 --- a/co-circom/co-plonk/src/round3.rs +++ b/co-circom/co-plonk/src/round3.rs @@ -217,6 +217,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver

> Round3<'a, P, T> { polys: &Round2Polys, ) -> PlonkProofResult<[Vec; 3]> { tracing::debug!("computing t polynomial..."); + tracing::info!("lzul"); let z1 = Self::get_z1(domains); let z2 = Self::get_z2(domains); let z3 = Self::get_z3(domains);