Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis committed Sep 11, 2024
1 parent 1613be5 commit c06e0a0
Show file tree
Hide file tree
Showing 17 changed files with 410 additions and 262 deletions.
10 changes: 3 additions & 7 deletions co-circom/co-groth16/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,7 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>
rep3::pointshare::msm_public_points(points, scalars)
}

fn scalar_mul_public_point_g1(
a: &<P as Pairing>::G1,
b: Self::ArithmeticShare,
) -> Self::PointShareG1 {
fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 {
rep3::pointshare::scalar_mul_public_point(a, b)
}

Expand All @@ -153,9 +150,8 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>
rep3::pointshare::add_assign_public(a, b, id)
}

async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<<P as Pairing>::G1> {
let c = self.io_context.network.reshare(a.b).await?;
Ok(a.a + a.b + c)
async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1> {
rep3::pointshare::open_point(a, &mut self.io_context).await
}

async fn scalar_mul_g1(
Expand Down
1 change: 1 addition & 0 deletions co-circom/co-plonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ itertools = { workspace = true }
mpc-core = { workspace = true }
mpc-net = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
sha3 = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
Expand Down
13 changes: 6 additions & 7 deletions co-circom/co-plonk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ where
witness: SharedWitness<P::ScalarField, T::ArithmeticShare>,
) -> PlonkProofResult<PlonkProof<P>> {
tracing::debug!("starting PLONK prove..");
let state = Round1::init_round(self.driver, zkey, witness)?;
let state = Round1::init_round(self.driver, self.runtime, zkey, witness)?;
tracing::debug!("init round done..");
let state = state.round1()?;
tracing::debug!("round 1 done..");
Expand Down Expand Up @@ -108,15 +108,15 @@ mod plonk_utils {
use num_traits::Zero;

pub(crate) fn get_witness<P: Pairing, T: CircomPlonkProver<P>>(
driver: &mut T,
party_id: T::PartyID,
witness: &PlonkWitness<P, T>,
zkey: &ZKey<P>,
index: usize,
) -> PlonkProofResult<T::ArithmeticShare> {
tracing::trace!("get witness on {index}");
let result = if index <= zkey.n_public {
tracing::trace!("indexing public input!");
driver.promote_to_trivial_share(witness.public_inputs[index])
T::promote_to_trivial_share(party_id, witness.public_inputs[index])
} else if index < zkey.n_vars - zkey.n_additions {
tracing::trace!("indexing private input!");
witness.witness[index - zkey.n_public - 1].clone()
Expand All @@ -132,13 +132,13 @@ mod plonk_utils {

// For convenience coeff is given in reverse order
pub(crate) fn blind_coefficients<P: Pairing, T: CircomPlonkProver<P>>(
driver: &mut T,
poly: &[T::ArithmeticShare],
coeff_rev: &[T::ArithmeticShare],
) -> Vec<T::ArithmeticShare> {
let mut res = poly.to_vec();
for (p, c) in res.iter_mut().zip(coeff_rev.iter().rev()) {
*p = driver.sub(p, c);
#[allow(unused_mut)]
for (mut p, c) in res.iter_mut().zip(coeff_rev.iter().rev()) {
*p = T::sub(*p, *c);
}
// Extend
res.reserve(coeff_rev.len());
Expand Down Expand Up @@ -210,7 +210,6 @@ pub mod tests {
use co_circom_snarks::SharedWitness;
use std::{fs::File, io::BufReader};

use crate::mpc::plain::PlainPlonkDriver;
use crate::plonk::Plonk;

#[test]
Expand Down
73 changes: 29 additions & 44 deletions co-circom/co-plonk/src/mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,38 @@ pub use rep3::Rep3PlonkDriver;
type IoResult<T> = std::io::Result<T>;

pub trait CircomPlonkProver<P: Pairing> {
type ArithmeticShare: CanonicalSerialize + CanonicalDeserialize + Clone + Default;
type PointShare<C: CurveGroup>;
type ArithmeticShare: CanonicalSerialize + CanonicalDeserialize + Copy + Clone + Default + Send;
type PointShareG1: Send;
type PointShareG2: Send;

fn rand(&self) -> Self::ArithmeticShare;
type PartyID: Send + Sync + Copy;

fn rand(&mut self) -> Self::ArithmeticShare;

fn get_party_id(&self) -> Self::PartyID {
self.io_context.id
}

/// Subtract the share b from the share a: \[c\] = \[a\] - \[b\]
fn add(
&mut self,
a: &Self::ArithmeticShare,
b: &Self::ArithmeticShare,
) -> Self::ArithmeticShare;
fn add(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare;

/// Add a public value a to the share b: \[c\] = a + \[b\]
fn add_with_public(
&mut self,
a: &P::ScalarField,
b: &Self::ArithmeticShare,
party_id: Self::PartyID,
shared: Self::ArithmeticShare,
public: P::ScalarField,
) -> Self::ArithmeticShare;

/// Subtract the share b from the share a: \[c\] = \[a\] - \[b\]
fn sub(
&mut self,
a: &Self::ArithmeticShare,
b: &Self::ArithmeticShare,
) -> Self::ArithmeticShare;
fn sub(a: Self::ArithmeticShare, b: Self::ArithmeticShare) -> Self::ArithmeticShare;

/// Negates a vector of shared values: \[b\] = -\[a\] for every element in place.
fn neg_vec_in_place(&mut self, a: &mut [Self::ArithmeticShare]);

/// Multiply a share b by a public value a: c = a * \[b\].
fn mul_with_public(
&mut self,
a: &P::ScalarField,
b: &Self::ArithmeticShare,
shared: Self::ArithmeticShare,
public: P::ScalarField,
) -> Self::ArithmeticShare;

async fn mul_vec(
Expand Down Expand Up @@ -74,63 +72,50 @@ pub trait CircomPlonkProver<P: Pairing> {
}

/// This function performs a multiplication directly followed by an opening. This safes one round of communication in some MPC protocols compared to calling `mul` and `open` separately.
fn mul_open_many(
async fn mul_open_vec(
&mut self,
a: &[Self::ArithmeticShare],
b: &[Self::ArithmeticShare],
) -> IoResult<Vec<P::ScalarField>>;

/// Reconstructs many shared values: a = Open(\[a\]).
fn open_many(&mut self, a: &[Self::ArithmeticShare]) -> IoResult<Vec<P::ScalarField>>;
async fn open_vec(&mut self, a: Vec<Self::ArithmeticShare>) -> IoResult<Vec<P::ScalarField>>;

/// Computes the inverse of many shared values: \[b\] = \[a\] ^ -1. Requires network communication.
async fn inv_many(
async fn inv_vec(
&mut self,
a: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>>;

/// Transforms a public value into a shared value: \[a\] = a.
fn promote_to_trivial_share(&self, public_values: P::ScalarField) -> Self::ArithmeticShare;
fn promote_to_trivial_share(
party_id: Self::PartyID,
public_value: P::ScalarField,
) -> Self::ArithmeticShare;

/// Computes the FFT of a vector of shared field elements.
fn fft<D: EvaluationDomain<P::ScalarField>>(
&mut self,
data: &[Self::ArithmeticShare],
domain: &D,
) -> Vec<Self::ArithmeticShare>;

/// Computes the inverse FFT of a vector of shared field elements.
fn ifft<D: EvaluationDomain<P::ScalarField>>(
&mut self,
data: &[Self::ArithmeticShare],
domain: &D,
) -> Vec<Self::ArithmeticShare>;

/// Reconstructs many shared points: A = Open(\[A\]).
fn open_point<C: CurveGroup>(&mut self, a: Self::PointShare<C>) -> IoResult<C> {
let mut result = self.open_point_many(&[a])?;
if result.len() != 1 {
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"During execution of degree_reduce_vec in MPC: Invalid number of elements received",
))
} else {
Ok(result.pop().expect("we checked for len above"))
}
}

/// Reconstructs many shared points: A = Open(\[A\]).
fn open_point_many<C: CurveGroup>(&mut self, a: &[Self::PointShare<C>]) -> IoResult<Vec<C>>;
async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1>;
async fn open_point_vec_g1(&mut self, a: &[Self::PointShareG1]) -> IoResult<Vec<P::G1>>;

// WE NEED THIS ALSO FOR GROTH16
fn msm_public_points<C: CurveGroup>(
&mut self,
points: &[C::Affine],
fn msm_public_points_g1(
points: &[P::G1Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShare<C>;
) -> Self::PointShareG1;

fn evaluate_poly_public(
&mut self,
poly: &[Self::ArithmeticShare],
point: P::ScalarField,
) -> Self::ArithmeticShare;
Expand Down
Loading

0 comments on commit c06e0a0

Please sign in to comment.