Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis committed Sep 12, 2024
1 parent 6a24991 commit 0ee5f88
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 94 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ Cargo.lock

.vscode
co-noir/co-acvm/src/bin/
co-circom/co-plonk/src/bin/
groth16_tester.rs
5 changes: 5 additions & 0 deletions co-circom/co-plonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ mpc-core = { workspace = true }
mpc-net = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
sha3 = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
tokio = { workspace = true }
# DELETE ME
tracing-subscriber.workspace = true
serde_json = { workspace = true }
ark-bn254 = { workspace = true }

[dev-dependencies]
ark-bls12-381 = { workspace = true }
Expand Down
13 changes: 12 additions & 1 deletion co-circom/co-plonk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,18 @@ where
zkey: &ZKey<P>,
witness: SharedWitness<P::ScalarField, T::ArithmeticShare>,
) -> PlonkProofResult<PlonkProof<P>> {
tracing::debug!("starting PLONK prove..");
tracing::debug!("starting PLONK prove!");
tracing::debug!(
"we have {} constraints and {} addition constraints",
zkey.n_constraints,
zkey.n_additions
);
tracing::debug!("the domain size is {}", zkey.domain_size);
tracing::debug!(
"we have {} n_vars and {} public inputs",
zkey.n_vars,
zkey.n_public
);
let state = Round1::init_round(self.driver, self.runtime, zkey, witness)?;
tracing::debug!("init round done..");
let state = state.round1()?;
Expand Down
10 changes: 10 additions & 0 deletions co-circom/co-plonk/src/mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ pub trait CircomPlonkProver<P: Pairing> {
fn rand(&mut self) -> Self::ArithmeticShare;

fn get_party_id(&self) -> Self::PartyID;

fn fork(&mut self) -> Self;

/// Subtract the share b from the share a: \[c\] = \[a\] - \[b\]
fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare;

Expand Down Expand Up @@ -53,6 +56,13 @@ pub trait CircomPlonkProver<P: Pairing> {
b: &[Self::ArithmeticShare],
) -> impl Future<Output = IoResult<Vec<Self::ArithmeticShare>>>;

fn mul_vecs(
&mut self,
a: &[Self::ArithmeticShare],
b: &[Self::ArithmeticShare],
c: &[Self::ArithmeticShare],
) -> impl Future<Output = IoResult<Vec<Self::ArithmeticShare>>>;

/// Convenience method for \[a\] + \[b\] * \[c\]
fn add_mul_vec(
&mut self,
Expand Down
13 changes: 13 additions & 0 deletions co-circom/co-plonk/src/mpc/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ impl<P: Pairing> CircomPlonkProver<P> for PlainPlonkDriver {
0
}

fn fork(&mut self) -> Self {
PlainPlonkDriver
}

fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare {
a + b
}
Expand Down Expand Up @@ -75,6 +79,15 @@ impl<P: Pairing> CircomPlonkProver<P> for PlainPlonkDriver {
Ok(izip!(a, b).map(|(a, b)| *a * *b).collect())
}

async fn mul_vecs(
&mut self,
a: &[Self::ArithmeticShare],
b: &[Self::ArithmeticShare],
c: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>> {
Ok(izip!(a, b, c).map(|(a, b, c)| *a * *b * *c).collect())
}

async fn add_mul_vec(
&mut self,
a: &[Self::ArithmeticShare],
Expand Down
15 changes: 14 additions & 1 deletion co-circom/co-plonk/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl<P: Pairing, N: Rep3Network> CircomPlonkProver<P> for Rep3PlonkDriver<N> {

type PartyID = PartyID;

fn debug_print(a: Self::ArithmeticShare) {
fn debug_print(_: Self::ArithmeticShare) {
todo!()
}

Expand All @@ -38,6 +38,10 @@ impl<P: Pairing, N: Rep3Network> CircomPlonkProver<P> for Rep3PlonkDriver<N> {
self.io_context.id
}

fn fork(&mut self) -> Self {
todo!()
}

fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare {
rep3::arithmetic::add(a, b)
}
Expand Down Expand Up @@ -76,6 +80,15 @@ impl<P: Pairing, N: Rep3Network> CircomPlonkProver<P> for Rep3PlonkDriver<N> {
rep3::arithmetic::mul_vec(lhs, rhs, &mut self.io_context).await
}

async fn mul_vecs(
&mut self,
_a: &[Self::ArithmeticShare],
_b: &[Self::ArithmeticShare],
_c: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>> {
todo!();
}

async fn add_mul_vec(
&mut self,
a: &[Self::ArithmeticShare],
Expand Down
190 changes: 115 additions & 75 deletions co-circom/co-plonk/src/round1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use ark_ec::CurveGroup;
use circom_types::plonk::ZKey;
use co_circom_snarks::SharedWitness;
use tokio::runtime::Runtime;
use tracing::instrument;

use crate::{
mpc::CircomPlonkProver,
Expand Down Expand Up @@ -92,67 +93,89 @@ impl<P: Pairing, T: CircomPlonkProver<P>> Round1Challenges<P, T> {

// Round 1 of https://eprint.iacr.org/2019/953.pdf (page 28)
impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round1<'a, P, T> {
// Essentially the fft of the trace columns
fn compute_wire_polynomials(
driver: &mut T,
fn compute_single_wire_poly(
party_id: T::PartyID,
witness: &PlonkWitness<P, T>,
domains: &Domains<P::ScalarField>,
challenges: &Round1Challenges<P, T>,
blind_factors: &[T::ArithmeticShare],
zkey: &ZKey<P>,
witness: &PlonkWitness<P, T>,
) -> PlonkProofResult<Round1Polys<P, T>> {
let party_id = driver.get_party_id();
tracing::debug!("computing wire polynomials...");
let num_constraints = zkey.n_constraints;

let mut buffer_a = Vec::with_capacity(zkey.domain_size);
let mut buffer_b = Vec::with_capacity(zkey.domain_size);
let mut buffer_c = Vec::with_capacity(zkey.domain_size);

for i in 0..num_constraints {
buffer_a.push(plonk_utils::get_witness(
party_id,
witness,
zkey,
zkey.map_a[i],
)?);
buffer_b.push(plonk_utils::get_witness(
party_id,
witness,
zkey,
zkey.map_b[i],
)?);
buffer_c.push(plonk_utils::get_witness(
party_id,
witness,
zkey,
zkey.map_c[i],
)?);
map: &[usize],
) -> PlonkProofResult<(Vec<T::ArithmeticShare>, PolyEval<P, T>)> {
let mut buffer = Vec::with_capacity(zkey.n_constraints);
for i in 0..zkey.n_constraints {
match plonk_utils::get_witness(party_id, witness, zkey, map[i]) {
Ok(witness) => buffer.push(witness),
Err(err) => return Err(err),
}
}
buffer_a.resize(zkey.domain_size, T::ArithmeticShare::default());
buffer_b.resize(zkey.domain_size, T::ArithmeticShare::default());
buffer_c.resize(zkey.domain_size, T::ArithmeticShare::default());

//TODO MULTITHREAD ME
tracing::debug!("iffts for buffers..");
buffer.resize(zkey.domain_size, T::ArithmeticShare::default());
// Compute the coefficients of the wire polynomials a(X), b(X) and c(X) from A,B & C buffers
let poly_a = T::ifft(&buffer_a, &domains.domain);
let poly_b = T::ifft(&buffer_b, &domains.domain);
let poly_c = T::ifft(&buffer_c, &domains.domain);
let poly = T::ifft(&buffer, &domains.domain);

tracing::debug!("ffts for evals..");
// Compute extended evaluations of a(X), b(X) and c(X) polynomials
let eval_a = T::fft(&poly_a, &domains.extended_domain);
let eval_b = T::fft(&poly_b, &domains.extended_domain);
let eval_c = T::fft(&poly_c, &domains.extended_domain);
let eval = T::fft(&poly, &domains.extended_domain);

tracing::debug!("blinding coefficients");
let poly_a = plonk_utils::blind_coefficients::<P, T>(&poly_a, &challenges.b[..2]);
let poly_b = plonk_utils::blind_coefficients::<P, T>(&poly_b, &challenges.b[2..4]);
let poly_c = plonk_utils::blind_coefficients::<P, T>(&poly_c, &challenges.b[4..6]);
let poly = plonk_utils::blind_coefficients::<P, T>(&poly, blind_factors);
Ok((buffer, PolyEval { poly, eval }))
}

if poly_a.len() > zkey.domain_size + 2
|| poly_b.len() > zkey.domain_size + 2
|| poly_c.len() > zkey.domain_size + 2
// Essentially the fft of the trace columns
#[instrument(level = "debug", name = "compute wire polys", skip_all)]
fn compute_wire_polynomials(
driver: &mut T,
domains: &Domains<P::ScalarField>,
challenges: &Round1Challenges<P, T>,
zkey: &ZKey<P>,
witness: &PlonkWitness<P, T>,
) -> PlonkProofResult<Round1Polys<P, T>> {
let party_id = driver.get_party_id();

let mut wire_a = None;
let mut wire_b = None;
let mut wire_c = None;

rayon::scope(|s| {
s.spawn(|_| {
wire_a = Some(Self::compute_single_wire_poly(
party_id,
witness,
domains,
&challenges.b[..2],
zkey,
&zkey.map_a,
))
});
s.spawn(|_| {
wire_b = Some(Self::compute_single_wire_poly(
party_id,
witness,
domains,
&challenges.b[2..4],
zkey,
&zkey.map_b,
))
});
s.spawn(|_| {
wire_c = Some(Self::compute_single_wire_poly(
party_id,
witness,
domains,
&challenges.b[4..6],
zkey,
&zkey.map_c,
))
});
});
// we have some values as rayon scope finished
let (buffer_a, poly_a) = wire_a.unwrap()?;
let (buffer_b, poly_b) = wire_b.unwrap()?;
let (buffer_c, poly_c) = wire_c.unwrap()?;

if poly_a.poly.len() > zkey.domain_size + 2
|| poly_b.poly.len() > zkey.domain_size + 2
|| poly_c.poly.len() > zkey.domain_size + 2
{
return Err(PlonkProofError::PolynomialDegreeTooLarge);
}
Expand All @@ -161,31 +184,25 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round1<'a, P, T> {
buffer_a,
buffer_b,
buffer_c,
a: PolyEval {
poly: poly_a.into(),
eval: eval_a,
},
b: PolyEval {
poly: poly_b.into(),
eval: eval_b,
},
c: PolyEval {
poly: poly_c.into(),
eval: eval_c,
},
a: poly_a,
b: poly_b,
c: poly_c,
})
}

// Calculate the witnesses for the additions, since they are not part of the SharedWitness
#[instrument(level = "debug", name = "calculate additions", skip_all)]
fn calculate_additions(
driver: &mut T,
witness: SharedWitness<P::ScalarField, T::ArithmeticShare>,
zkey: &ZKey<P>,
) -> PlonkProofResult<PlonkWitness<P, T>> {
tracing::debug!("calculating addition {} constraints...", zkey.n_additions);
let party_id = driver.get_party_id();
let mut witness = PlonkWitness::new(witness, zkey.n_additions);

// This is hard to multithread as we have to add the results
// to the vec as they are needed for the later steps.
// We leave it like that as it does not take to much time (<1ms for poseidon).
// Keep an eye on the span duration, maybe we have to come back to that later.
for addition in zkey.additions.iter() {
let witness1 = plonk_utils::get_witness(
party_id,
Expand All @@ -205,30 +222,32 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round1<'a, P, T> {
let result = T::add(f1, f2);
witness.addition_witness.push(result);
}
tracing::debug!("additions done!");
Ok(witness)
}

#[instrument(level = "debug", name = "Plonk - Round Init", skip_all)]
pub(super) fn init_round(
mut driver: T,
runtime: Runtime,
zkey: &'a ZKey<P>,
private_witness: SharedWitness<P::ScalarField, T::ArithmeticShare>,
) -> PlonkProofResult<Self> {
let plonk_witness = Self::calculate_additions(&mut driver, private_witness, zkey)?;

let challenges = Round1Challenges::random(&mut driver)?;
let domains = Domains::new(zkey.domain_size)?;
Ok(Self {
challenges: Round1Challenges::random(&mut driver)?,
challenges,
driver,
runtime,
domains: Domains::new(zkey.domain_size)?,
domains,
data: PlonkDataRound1 {
witness: plonk_witness,
zkey,
},
})
}

#[instrument(level = "debug", name = "Plonk - Round 1", skip_all)]
// Round 1 of https://eprint.iacr.org/2019/953.pdf (page 28)
pub(super) fn round1(self) -> PlonkProofResult<Round2<'a, P, T>> {
let Self {
Expand All @@ -246,14 +265,35 @@ impl<'a, P: Pairing, T: CircomPlonkProver<P>> Round1<'a, P, T> {
let polys =
Self::compute_wire_polynomials(&mut driver, &domains, &challenges, zkey, witness)?;

tracing::debug!("committing to polys (MSMs)");
let mut commit_a = None;
let mut commit_b = None;
let mut commit_c = None;
let commit_span = tracing::debug_span!("committing to polys (MSMs)").entered();
// STEP 1.3 - Compute [a]_1, [b]_1, [c]_1
let commit_a = T::msm_public_points_g1(&p_tau[..polys.a.poly.len()], &polys.a.poly);
let commit_b = T::msm_public_points_g1(&p_tau[..polys.b.poly.len()], &polys.b.poly);
let commit_c = T::msm_public_points_g1(&p_tau[..polys.c.poly.len()], &polys.c.poly);

rayon::scope(|s| {
s.spawn(|_| {
let result = T::msm_public_points_g1(&p_tau[..polys.a.poly.len()], &polys.a.poly);
commit_a = Some(result);
});
s.spawn(|_| {
let result = T::msm_public_points_g1(&p_tau[..polys.b.poly.len()], &polys.b.poly);
commit_b = Some(result);
});
s.spawn(|_| {
let result = T::msm_public_points_g1(&p_tau[..polys.c.poly.len()], &polys.c.poly);
commit_c = Some(result);
});
});
// rayon scope must be done therefore some values
let commit_a = commit_a.unwrap();
let commit_b = commit_b.unwrap();
let commit_c = commit_c.unwrap();

// network round
commit_span.exit();
let opening_span = tracing::debug_span!("opening commits").entered();
let opened = runtime.block_on(driver.open_point_vec_g1(&[commit_a, commit_b, commit_c]))?;

opening_span.exit();
let proof = Round1Proof::<P> {
commit_a: opened[0],
commit_b: opened[1],
Expand Down
Loading

0 comments on commit 0ee5f88

Please sign in to comment.