Skip to content

Commit

Permalink
plonk plain working again
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis committed Sep 12, 2024
1 parent c06e0a0 commit 7113108
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 192 deletions.
8 changes: 6 additions & 2 deletions co-circom/circom-mpc-vm/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,16 @@ impl<F: PrimeField, N: Rep3Network> VmCircomWitnessExtension<F>
}
(Rep3VmType::Public(b), Rep3VmType::Binary(a))
| (Rep3VmType::Binary(a), Rep3VmType::Public(b)) => {
let a = futures::executor::block_on(conversion::b2a(&a, &mut self.io_context))?;
let a = self
.runtime
.block_on(conversion::b2a(&a, &mut self.io_context))?;
Ok(arithmetic::add_public(a, b, self.io_context.id).into())
}
(Rep3VmType::Arithmetic(a), Rep3VmType::Binary(b))
| (Rep3VmType::Binary(b), Rep3VmType::Arithmetic(a)) => {
let b = futures::executor::block_on(conversion::b2a(&b, &mut self.io_context))?;
let b = self
.runtime
.block_on(conversion::b2a(&b, &mut self.io_context))?;
Ok(arithmetic::add(a, b).into())
}
(Rep3VmType::Binary(a), Rep3VmType::Binary(b)) => {
Expand Down
30 changes: 15 additions & 15 deletions co-circom/co-plonk/src/mpc.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{future::Future, process::Output};

use ark_ec::{pairing::Pairing, CurveGroup};
use ark_poly::EvaluationDomain;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
Expand All @@ -18,12 +20,11 @@ pub trait CircomPlonkProver<P: Pairing> {

type PartyID: Send + Sync + Copy;

fn rand(&mut self) -> Self::ArithmeticShare;
fn debug_print(a: Self::ArithmeticShare);

fn get_party_id(&self) -> Self::PartyID {
self.io_context.id
}
fn rand(&mut self) -> Self::ArithmeticShare;

fn get_party_id(&self) -> Self::PartyID;
/// Subtract the share b from the share a: \[c\] = \[a\] - \[b\]
fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare;

Expand All @@ -46,29 +47,28 @@ pub trait CircomPlonkProver<P: Pairing> {
public: P::ScalarField,
) -> Self::ArithmeticShare;

async fn mul_vec(
fn mul_vec(
&mut self,
a: &[Self::ArithmeticShare],
b: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>>;
) -> impl Future<Output = IoResult<Vec<Self::ArithmeticShare>>>;

/// Convenience method for \[a\] + \[b\] * \[c\]
async fn add_mul_vec(
fn add_mul_vec(
&mut self,
a: &[Self::ArithmeticShare],
b: &[Self::ArithmeticShare],
c: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>>;
) -> impl Future<Output = IoResult<Vec<Self::ArithmeticShare>>>;

/// Convenience method for \[a\] + \[b\] * c
fn add_mul_public(
&mut self,
a: &Self::ArithmeticShare,
b: &Self::ArithmeticShare,
c: &P::ScalarField,
a: Self::ArithmeticShare,
b: Self::ArithmeticShare,
c: P::ScalarField,
) -> Self::ArithmeticShare {
let tmp = self.mul_with_public(c, b);
self.add(a, &tmp)
Self::add(a, Self::mul_with_public(b, c))
}

/// This function performs a multiplication directly followed by an opening. This safes one round of communication in some MPC protocols compared to calling `mul` and `open` separately.
Expand All @@ -79,7 +79,7 @@ pub trait CircomPlonkProver<P: Pairing> {
) -> IoResult<Vec<P::ScalarField>>;

/// Reconstructs many shared values: a = Open(\[a\]).
async fn open_vec(&mut self, a: Vec<Self::ArithmeticShare>) -> IoResult<Vec<P::ScalarField>>;
async fn open_vec(&mut self, a: &[Self::ArithmeticShare]) -> IoResult<Vec<P::ScalarField>>;

/// Computes the inverse of many shared values: \[b\] = \[a\] ^ -1. Requires network communication.
async fn inv_vec(
Expand All @@ -106,7 +106,7 @@ pub trait CircomPlonkProver<P: Pairing> {
) -> Vec<Self::ArithmeticShare>;

/// Reconstructs many shared points: A = Open(\[A\]).
async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1>;
async fn open_point_g1(&mut self, a: Self::PointShareG1) -> IoResult<P::G1>;
async fn open_point_vec_g1(&mut self, a: &[Self::PointShareG1]) -> IoResult<Vec<P::G1>>;

// WE NEED THIS ALSO FOR GROTH16
Expand Down
28 changes: 22 additions & 6 deletions co-circom/co-plonk/src/mpc/plain.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use super::IoResult;
use ark_ec::pairing::Pairing;
use ark_ec::scalar_mul::variable_base::VariableBaseMSM;
use ark_ff::Field;
use ark_ff::UniformRand;
use ark_poly::univariate::DensePolynomial;
use ark_poly::Polynomial;
use itertools::izip;
use num_traits::Zero;

use super::CircomPlonkProver;
use rand::thread_rng;
Expand All @@ -21,6 +23,10 @@ impl<P: Pairing> CircomPlonkProver<P> for PlainPlonkDriver {
//doesn't matter
type PartyID = usize;

fn debug_print(a: Self::ArithmeticShare) {
println!("{a}")
}

fn rand(&mut self) -> Self::ArithmeticShare {
let mut rng = thread_rng();
Self::ArithmeticShare::rand(&mut rng)
Expand Down Expand Up @@ -66,7 +72,7 @@ impl<P: Pairing> CircomPlonkProver<P> for PlainPlonkDriver {
a: &[Self::ArithmeticShare],
b: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>> {
Ok(izip!(a, b).map(|(a, b)| *a + *b).collect())
Ok(izip!(a, b).map(|(a, b)| *a * *b).collect())
}

async fn add_mul_vec(
Expand All @@ -86,15 +92,25 @@ impl<P: Pairing> CircomPlonkProver<P> for PlainPlonkDriver {
Ok(izip!(a, b).map(|(a, b)| *a * *b).collect())
}

async fn open_vec(&mut self, a: Vec<Self::ArithmeticShare>) -> IoResult<Vec<P::ScalarField>> {
Ok(a)
async fn open_vec(&mut self, a: &[Self::ArithmeticShare]) -> IoResult<Vec<P::ScalarField>> {
Ok(a.to_vec())
}

async fn inv_vec(
&mut self,
a: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>> {
Ok(a.iter().map(|a| -*a).collect())
let mut res = Vec::with_capacity(a.len());
for a in a {
if a.is_zero() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Cannot invert zero",
));
}
res.push(a.inverse().unwrap());
}
Ok(res)
}

fn promote_to_trivial_share(
Expand All @@ -118,8 +134,8 @@ impl<P: Pairing> CircomPlonkProver<P> for PlainPlonkDriver {
domain.ifft(data)
}

async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1> {
Ok(*a)
async fn open_point_g1(&mut self, a: Self::PointShareG1) -> IoResult<P::G1> {
Ok(a)
}

async fn open_point_vec_g1(&mut self, a: &[Self::PointShareG1]) -> IoResult<Vec<P::G1>> {
Expand Down
10 changes: 7 additions & 3 deletions co-circom/co-plonk/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ impl<P: Pairing, N: Rep3Network> CircomPlonkProver<P> for Rep3PlonkDriver<N> {

type PartyID = PartyID;

fn debug_print(a: Self::ArithmeticShare) {
todo!()
}

fn rand(&mut self) -> Self::ArithmeticShare {
Self::ArithmeticShare::rand(&mut self.io_context)
}
Expand Down Expand Up @@ -91,7 +95,7 @@ impl<P: Pairing, N: Rep3Network> CircomPlonkProver<P> for Rep3PlonkDriver<N> {
rep3::arithmetic::mul_open_vec(a, b, &mut self.io_context).await
}

async fn open_vec(&mut self, a: Vec<Self::ArithmeticShare>) -> IoResult<Vec<P::ScalarField>> {
async fn open_vec(&mut self, a: &[Self::ArithmeticShare]) -> IoResult<Vec<P::ScalarField>> {
rep3::arithmetic::open_vec(a, &mut self.io_context).await
}

Expand Down Expand Up @@ -123,8 +127,8 @@ impl<P: Pairing, N: Rep3Network> CircomPlonkProver<P> for Rep3PlonkDriver<N> {
domain.ifft(&data)
}

async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1> {
rep3::pointshare::open_point(a, &mut self.io_context).await
async fn open_point_g1(&mut self, a: Self::PointShareG1) -> IoResult<P::G1> {
rep3::pointshare::open_point(&a, &mut self.io_context).await
}

async fn open_point_vec_g1(&mut self, a: &[Self::PointShareG1]) -> IoResult<Vec<P::G1>> {
Expand Down
69 changes: 44 additions & 25 deletions co-circom/co-plonk/src/round2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ macro_rules! array_prod_mul {
// Do the multiplications of inp[i] * inp[i-1] in constant rounds
let len = $inp.len();
let r = (0..=len).map(|_| $driver.rand()).collect::<Vec<_>>();
let r_inv = futures::executor::block_on($driver.inv_many(&r))?;
let r_inv = futures::executor::block_on($driver.inv_vec(&r))?;
let r_inv0 = vec![r_inv[0].clone(); len];
let mut unblind = futures::executor::block_on($driver.mul_vec(&r_inv0, &r[1..]))?;

let mul = futures::executor::block_on($driver.mul_vec(&r[..len], &$inp))?;
let mut open = $driver.mul_open_many(&mul, &r_inv[1..])?;
let mut open = futures::executor::block_on($driver.mul_open_vec(&mul, &r_inv[1..]))?;

for i in 1..open.len() {
open[i] = open[i] * open[i - 1];
}

for (unblind, open) in unblind.iter_mut().zip(open.iter()) {
*unblind = $driver.mul_with_public(open, unblind);
for (unblind, open) in unblind.iter_mut().zip(open.into_iter()) {
*unblind = T::mul_with_public(*unblind, open);
}
unblind
}};
Expand Down Expand Up @@ -117,6 +117,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round2<'a, P, T> {
// To reduce the number of communication rounds, we implement the array_prod_mul macro according to https://www.usenix.org/system/files/sec22-ozdemir.pdf, p11 first paragraph.
fn compute_z(
driver: &mut T,
runtime: &mut Runtime,
zkey: &ZKey<P>,
domains: &Domains<P::ScalarField>,
challenges: &Round2Challenges<P, T>,
Expand Down Expand Up @@ -157,17 +158,26 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round2<'a, P, T> {
n3.push(n3_);

// denArr := (a + beta·sigma1 + gamma)(b + beta·sigma2 + gamma)(c + beta·sigma3 + gamma)
let d1_ =
driver.add_with_public(&(challenges.beta * zkey.s1_poly.evaluations[i * 4]), a);
let d1_ = driver.add_with_public(&challenges.gamma, &d1_);

let d2_ =
driver.add_with_public(&(challenges.beta * zkey.s2_poly.evaluations[i * 4]), b);
let d2_ = driver.add_with_public(&challenges.gamma, &d2_);

let d3_ =
driver.add_with_public(&(challenges.beta * zkey.s3_poly.evaluations[i * 4]), c);
let d3_ = driver.add_with_public(&challenges.gamma, &d3_);
let d1_ = T::add_with_public(
party_id,
*a,
challenges.beta * zkey.s1_poly.evaluations[i * 4],
);
let d1_ = T::add_with_public(party_id, d1_, challenges.gamma);

let d2_ = T::add_with_public(
party_id,
*b,
challenges.beta * zkey.s2_poly.evaluations[i * 4],
);
let d2_ = T::add_with_public(party_id, d2_, challenges.gamma);

let d3_ = T::add_with_public(
party_id,
*c,
challenges.beta * zkey.s3_poly.evaluations[i * 4],
);
let d3_ = T::add_with_public(party_id, d3_, challenges.gamma);

d1.push(d1_);
d2.push(d2_);
Expand All @@ -177,10 +187,10 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round2<'a, P, T> {
}

// TODO parallelize these? With a different network structure this might not be needed though
let num = futures::executor::block_on(driver.mul_vec(&n1, &n2))?;
let num = futures::executor::block_on(driver.mul_vec(&num, &n3))?;
let den = futures::executor::block_on(driver.mul_vec(&d1, &d2))?;
let den = futures::executor::block_on(driver.mul_vec(&den, &d3))?;
let num = runtime.block_on(driver.mul_vec(&n1, &n2))?;
let num = runtime.block_on(driver.mul_vec(&num, &n3))?;
let den = runtime.block_on(driver.mul_vec(&d1, &d2))?;
let den = runtime.block_on(driver.mul_vec(&den, &d3))?;

// TODO parallelize these? With a different network structure this might not be needed though
// Do the multiplications of num[i] * num[i-1] and den[i] * den[i-1] in constant rounds
Expand All @@ -195,12 +205,12 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round2<'a, P, T> {
buffer_z.rotate_right(1); // Required by SNARKJs/Plonk

// Compute polynomial coefficients z(X) from buffer_z
let poly_z = driver.ifft(&buffer_z, &domains.domain);
let poly_z = T::ifft(&buffer_z, &domains.domain);

// Compute extended evaluations of z(X) polynomial
let eval_z = driver.fft(&poly_z, &domains.extended_domain);
let eval_z = T::fft(&poly_z, &domains.extended_domain);

let poly_z = plonk_utils::blind_coefficients::<P, T>(driver, &poly_z, &challenges.b[6..9]);
let poly_z = plonk_utils::blind_coefficients::<P, T>(&poly_z, &challenges.b[6..9]);

if poly_z.len() > zkey.domain_size + 3 {
Err(PlonkProofError::PolynomialDegreeTooLarge)
Expand All @@ -217,7 +227,7 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round2<'a, P, T> {
pub(super) fn round2(self) -> PlonkProofResult<Round3<'a, P, T>> {
let Self {
mut driver,
runtime,
mut runtime,
data,
proof,
challenges,
Expand Down Expand Up @@ -250,7 +260,14 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round2<'a, P, T> {
let gamma = transcript.get_challenge();
tracing::debug!("beta: {beta}, gamma: {gamma}");
let challenges = Round2Challenges::new(challenges, beta, gamma);
let z = Self::compute_z(&mut driver, zkey, &domains, &challenges, &polys)?;
let z = Self::compute_z(
&mut driver,
&mut runtime,
zkey,
&domains,
&challenges,
&polys,
)?;
// STEP 2.3 - Compute permutation [z]_1

tracing::debug!("committing to poly z (MSMs)");
Expand Down Expand Up @@ -280,6 +297,7 @@ pub mod tests {
use circom_types::plonk::ZKey;
use circom_types::Witness;
use co_circom_snarks::SharedWitness;
use tokio::runtime;

use crate::mpc::plain::PlainPlonkDriver;
use crate::round1::Round1;
Expand Down Expand Up @@ -313,7 +331,8 @@ pub mod tests {
};

let challenges = Round1Challenges::deterministic(&mut driver);
let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap();
let runtime = runtime::Builder::new_current_thread().build().unwrap();
let mut round1 = Round1::init_round(driver, runtime, &zkey, witness).unwrap();
round1.challenges = challenges;
let round2 = round1.round1().unwrap();
let round3 = round2.round2().unwrap();
Expand Down
Loading

0 comments on commit 7113108

Please sign in to comment.