Skip to content

Commit

Permalink
not working
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis committed Sep 13, 2024
1 parent 2a084a2 commit c7ce9da
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 138 deletions.
18 changes: 9 additions & 9 deletions co-circom/circom-types/src/groth16/zkey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use ark_ff::PrimeField;
use ark_relations::r1cs::ConstraintMatrices;
use ark_serialize::CanonicalDeserialize;

use std::io::Read;
use std::{io::Read, sync::Arc};

use crate::{
binfile::{BinFile, ZKeyParserError, ZKeyParserResult},
Expand Down Expand Up @@ -63,9 +63,9 @@ pub struct ZKey<P: Pairing> {
/// b_query in G2
pub b_g2_query: Vec<P::G2Affine>,
/// h_query
pub h_query: Vec<P::G1Affine>,
pub h_query: Arc<Vec<P::G1Affine>>,
/// l_query
pub l_query: Vec<P::G1Affine>,
pub l_query: Arc<Vec<P::G1Affine>>,
/// The constraint matrices A, B, and C
pub matrices: ConstraintMatrices<P::ScalarField>,
}
Expand Down Expand Up @@ -242,8 +242,8 @@ where
a_query: a_query.unwrap()?,
b_g1_query: b_g1_query.unwrap()?,
b_g2_query: b_g2_query.unwrap()?,
h_query: h_query.unwrap()?,
l_query: l_query.unwrap()?,
h_query: Arc::new(h_query.unwrap()?),
l_query: Arc::new(l_query.unwrap()?),
matrices,
vk,
})
Expand Down Expand Up @@ -412,8 +412,8 @@ mod tests {
assert_eq!(a_query, pk.a_query);
assert_eq!(b_g1_query, pk.b_g1_query);
assert_eq!(b_g2_query, pk.b_g2_query);
assert_eq!(h_query, pk.h_query);
assert_eq!(l_query, pk.l_query);
assert_eq!(h_query, Arc::into_inner(pk.h_query).unwrap());
assert_eq!(l_query, Arc::into_inner(pk.l_query).unwrap());
let vk = pk.vk;
let alpha_g1 = test_utils::to_g1_bls12_381!(
"573513743870798705896078935465463988747193691665514373553428213826028808426481266659437596949247877550493216010640",
Expand Down Expand Up @@ -528,8 +528,8 @@ mod tests {
assert_eq!(a_query, pk.a_query);
assert_eq!(b_g1_query, pk.b_g1_query);
assert_eq!(b_g2_query, pk.b_g2_query);
assert_eq!(h_query, pk.h_query);
assert_eq!(l_query, pk.l_query);
assert_eq!(h_query, Arc::into_inner(pk.h_query).unwrap());
assert_eq!(l_query, Arc::into_inner(pk.l_query).unwrap());
let vk = pk.vk;

let alpha_g1 = test_utils::to_g1_bn254!(
Expand Down
200 changes: 91 additions & 109 deletions co-circom/co-groth16/src/groth16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ use mpc_core::protocols::rep3::network::{IoContext, Rep3MpcNet, Rep3Network};
use mpc_net::config::NetworkConfig;
use num_traits::identities::One;
use num_traits::ToPrimitive;
use std::io;
use std::marker::PhantomData;
use std::time::Duration;
use std::{io, thread};
use std::sync::Arc;
use tokio::runtime::{self, Runtime};
use tokio::sync::oneshot;
use tracing::instrument;

use crate::mpc::plain::PlainGroth16Driver;
Expand Down Expand Up @@ -107,8 +108,7 @@ where
let num_inputs = matrices.num_instance_variables;
let num_constraints = matrices.num_constraints;
let public_inputs = &private_witness.public_inputs;
let private_witness = &private_witness.witness;
let party_id = self.driver.get_party_id();
let private_witness = private_witness.witness;
// TODO: actually we do not like that we have to block here
// can we somehow manage that this doesn't have to use the network?
let mut forked_driver = self.runtime.block_on(self.driver.fork())?;
Expand All @@ -118,7 +118,7 @@ where
num_constraints,
num_inputs,
public_inputs,
private_witness,
&private_witness,
)?;
tracing::debug!("getting r and s...");
//TODO: this is bad - we need something else
Expand All @@ -130,7 +130,7 @@ where
zkey,
r,
s,
&h,
h,
&public_inputs[1..],
private_witness,
)
Expand Down Expand Up @@ -164,114 +164,91 @@ where
GeneralEvaluationDomain::<P::ScalarField>::new(num_constraints + num_inputs)
.ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
let root_of_unity = root_of_unity_for_groth16(power, &mut domain);
let domain = Arc::new(domain);
let domain_size = domain.size();
let party_id = self.driver.get_party_id();
let mut a = Option::None;
let mut b = Option::None;
let eval_constraint_span = tracing::debug_span!("evaluate constraints").entered();
rayon::scope(|s| {
s.spawn(|_| {
let mut inner_a = Self::evaluate_constraint(
let (a, b) = rayon::join(
|| {
let mut result = Self::evaluate_constraint(
party_id,
domain_size,
&matrices.a,
public_inputs,
private_witness,
);
let promoted_public = T::promote_to_trivial_shares(party_id, public_inputs);
inner_a[num_constraints..num_constraints + num_inputs]
result[num_constraints..num_constraints + num_inputs]
.clone_from_slice(&promoted_public[..num_inputs]);
a = Some(inner_a);
});
s.spawn(|_| {
let inner_b = Self::evaluate_constraint(
result
},
|| {
let result = Self::evaluate_constraint(
party_id,
domain_size,
&matrices.b,
public_inputs,
private_witness,
);
b = Some(inner_b);
})
});
result
},
);

eval_constraint_span.exit();
// if we are here the scope finished therefore we have to have Some
// values
let a = a.unwrap();
let b = b.unwrap();

let mut a_dist_pow = Option::None;
let mut b_dist_pow = Option::None;
let mut c_dist_pow = Option::None;
let (a_tx, a_rx) = oneshot::channel();
let (b_tx, b_rx) = oneshot::channel();
let a_domain = Arc::clone(&domain);
let b_domain = Arc::clone(&domain);
let mut a_result = a.clone();
let mut b_result = b.clone();
rayon::spawn(move || {
T::ifft_in_place(&mut a_result, a_domain.as_ref());
T::distribute_powers_and_mul_by_const(
&mut a_result,
root_of_unity,
P::ScalarField::one(),
);
T::fft_in_place(&mut a_result, a_domain.as_ref());
a_tx.send(a_result).expect("channel not droped");
});

let ditribute_pow_span = tracing::debug_span!("ifft, distribute pows, fft").entered();
rayon::scope(|s| {
s.spawn(|_| {
match self.runtime.block_on(self.driver.mul_vec(&a, &b)) {
Ok(mut ab) => {
let mul_vec_span =
tracing::debug_span!("groth16 - mul vec in dist pows").entered();
// TODO: this is a very large multiplication - do we want to do that on the runtime
// or maybe on rayon and only sending on the runtime?
let ifft_span = tracing::debug_span!("ifft in dist pows").entered();
T::ifft_in_place(&mut ab, &domain);
ifft_span.exit();
let dist_pows_span = tracing::debug_span!("dist pows").entered();
T::distribute_powers_and_mul_by_const(
&mut ab,
root_of_unity,
P::ScalarField::one(),
);
dist_pows_span.exit();
let fft_span = tracing::debug_span!("fft in dist pows").entered();
T::fft_in_place(&mut ab, &domain);
fft_span.exit();
c_dist_pow = Some(Ok(ab));
mul_vec_span.exit();
}
Err(err) => {
c_dist_pow = Some(Err(err));
}
}
});
s.spawn(|_| {
let mut a_result = T::ifft(&a, &domain);
T::distribute_powers_and_mul_by_const(
&mut a_result,
root_of_unity,
P::ScalarField::one(),
);
T::fft_in_place(&mut a_result, &domain);
a_dist_pow = Some(a_result);
});
s.spawn(|_| {
let mut b_result = T::ifft(&b, &domain);
T::distribute_powers_and_mul_by_const(
&mut b_result,
root_of_unity,
P::ScalarField::one(),
);
T::fft_in_place(&mut b_result, &domain);
b_dist_pow = Some(b_result);
});
rayon::spawn(move || {
T::ifft_in_place(&mut b_result, b_domain.as_ref());
T::distribute_powers_and_mul_by_const(
&mut b_result,
root_of_unity,
P::ScalarField::one(),
);
T::fft_in_place(&mut b_result, b_domain.as_ref());
b_tx.send(b_result).expect("channel not droped");
});
//drop the old values!
std::mem::drop(a);
std::mem::drop(b);
//rayon finished therefore we must have some value
let a = a_dist_pow.unwrap();
let b = b_dist_pow.unwrap();
let c = c_dist_pow.unwrap()?;
ditribute_pow_span.exit();

//need to wait here...

let mut ab = self.runtime.block_on(self.driver.mul_vec(&a, &b))?;
let mul_vec_span = tracing::debug_span!("groth16 - mul vec in dist pows").entered();
let ifft_span = tracing::debug_span!("ifft in dist pows").entered();
T::ifft_in_place(&mut ab, domain.as_ref());
ifft_span.exit();
let dist_pows_span = tracing::debug_span!("dist pows").entered();
T::distribute_powers_and_mul_by_const(&mut ab, root_of_unity, P::ScalarField::one());
dist_pows_span.exit();
let fft_span = tracing::debug_span!("fft in dist pows").entered();
T::fft_in_place(&mut ab, domain.as_ref());
fft_span.exit();
let c_dist_pow = ab;
mul_vec_span.exit();

tracing::error!("DONE MUL VEC SCOPE");
let a = a_rx.blocking_recv()?;
let b = b_rx.blocking_recv()?;

let mul_vec_span = tracing::debug_span!("compute ab").entered();
//TODO we can merge the mul and sub commands but it most likely is not that
//much of a difference
let mut ab = self.runtime.block_on(self.driver.mul_vec(&a, &b))?;
T::sub_assign_vec(&mut ab, &c);
T::sub_assign_vec(&mut ab, &c_dist_pow);
mul_vec_span.exit();
Ok(ab)
Ok(a)
}

fn calculate_coeff_g1(
Expand Down Expand Up @@ -351,31 +328,34 @@ where
zkey: &ZKey<P>,
r: T::ArithmeticShare,
s: T::ArithmeticShare,
h: &[T::ArithmeticShare],
h: Vec<T::ArithmeticShare>,
input_assignment: &[P::ScalarField],
aux_assignment: &[T::ArithmeticShare],
aux_assignment: Vec<T::ArithmeticShare>,
) -> Result<Groth16Proof<P>> {
let mut h_acc = None;
let mut l_aux_acc = None;
let mut r_s_delta_g1 = None;
let delta_g1 = zkey.delta_g1.into_group();
let msm_create_proof = tracing::debug_span!("first MSMs").entered();
let party_id = forked_driver.get_party_id();
rayon::scope(|scope| {
scope.spawn(|_| h_acc = Some(T::msm_public_points_g1(&zkey.h_query, h)));
scope.spawn(|_| {
l_aux_acc = Some(T::msm_public_points_g1(&zkey.l_query, aux_assignment))
});
scope.spawn(|_| match self.runtime.block_on(self.driver.mul(r, s)) {
Ok(rs) => r_s_delta_g1 = Some(Ok(T::scalar_mul_public_point_g1(&delta_g1, rs))),
Err(err) => r_s_delta_g1 = Some(Err(err)),
});

let (h_acc_tx, h_acc_rx) = oneshot::channel();
let (l_acc_tx, l_acc_rx) = oneshot::channel();
let h_query = Arc::clone(&zkey.h_query);
let l_query = Arc::clone(&zkey.l_query);
rayon::spawn(move || {
let result = T::msm_public_points_g1(h_query.as_ref(), &h);
h_acc_tx.send(result).expect("channel not dropped");
});

rayon::spawn(move || {
let result = T::msm_public_points_g1(l_query.as_ref(), &aux_assignment);
l_acc_tx
.send((result, aux_assignment))
.expect("channel not dropped");
});
let rs = self.runtime.block_on(self.driver.mul(r, s))?;
let r_s_delta_g1 = T::scalar_mul_public_point_g1(&delta_g1, rs);
msm_create_proof.exit();

let h_acc = h_acc.unwrap();
let l_aux_acc = l_aux_acc.unwrap();
let r_s_delta_g1 = r_s_delta_g1.unwrap()?;
let h_acc = h_acc_rx.blocking_recv().unwrap();
let (l_aux_acc, aux_assignment) = l_acc_rx.blocking_recv().unwrap();

let mut g_a = None;
let mut g1_b = None;
Expand All @@ -392,7 +372,7 @@ where
&zkey.a_query,
zkey.vk.alpha_g1,
input_assignment,
aux_assignment,
&aux_assignment,
));
});
scope.spawn(|_| {
Expand All @@ -405,7 +385,7 @@ where
&zkey.b_g1_query,
zkey.beta_g1,
input_assignment,
aux_assignment,
&aux_assignment,
));
});
scope.spawn(|_| {
Expand All @@ -417,7 +397,7 @@ where
&zkey.b_g2_query,
zkey.vk.beta_g2,
input_assignment,
aux_assignment,
&aux_assignment,
));
});
});
Expand Down Expand Up @@ -487,7 +467,9 @@ where
phantom_data: PhantomData,
})
}
}

impl<P: Pairing, N: Rep3Network> Rep3CoGroth16<P, N> {
pub fn close_network(self) -> io::Result<()> {
self.runtime
.block_on(self.driver.into_network().shutdown())?;
Expand Down
15 changes: 11 additions & 4 deletions co-circom/co-groth16/src/mpc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::fmt;
use std::future::Future;
use std::{fmt::Debug, future::Future};

use ark_ec::pairing::Pairing;
use ark_poly::EvaluationDomain;
Expand All @@ -15,9 +15,16 @@ pub use rep3::Rep3Groth16Driver;
type IoResult<T> = std::io::Result<T>;

pub trait CircomGroth16Prover<P: Pairing>: Send + Sized {
type ArithmeticShare: CanonicalSerialize + CanonicalDeserialize + Copy + Clone + Default + Send;
type PointShareG1: Send;
type PointShareG2: Send;
type ArithmeticShare: CanonicalSerialize
+ CanonicalDeserialize
+ Copy
+ Clone
+ Default
+ Send
+ Debug
+ 'static;
type PointShareG1: Debug + Send + 'static;
type PointShareG2: Debug + Send + 'static;
type PartyID: Send + Sync + Copy + fmt::Display;

fn rand(&mut self) -> impl Future<Output = IoResult<Self::ArithmeticShare>>;
Expand Down
2 changes: 1 addition & 1 deletion co-circom/co-plonk/src/mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ pub trait CircomPlonkProver<P: Pairing> {
fn evaluate_poly_public(
poly: Vec<Self::ArithmeticShare>,
point: P::ScalarField,
) -> Self::ArithmeticShare;
) -> (Self::ArithmeticShare, Vec<Self::ArithmeticShare>);
}
Loading

0 comments on commit c7ce9da

Please sign in to comment.