Skip to content

Commit

Permalink
refactor!: PLONK now takes zkey as ref for prove
Browse files Browse the repository at this point in the history
BREAKING CHANGE: to unify Groth16 and PLONK
we now take the zkey as ref in PLONK when calling prove
  • Loading branch information
0xThemis committed Aug 14, 2024
1 parent df451a1 commit 7aa5c4e
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 44 deletions.
4 changes: 2 additions & 2 deletions collaborative-circom/src/bin/co-circom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ where
let prover = CollaborativePlonk::new(protocol);

// execute prover in MPC
let proof = prover.prove(pk, witness_share)?;
let proof = prover.prove(&pk, witness_share)?;
(proof, public_input)
}
MPCProtocol::SHAMIR => {
Expand All @@ -549,7 +549,7 @@ where
let prover = CollaborativePlonk::new(protocol);

// execute prover in MPC
let proof = prover.prove(pk, witness_share)?;
let proof = prover.prove(&pk, witness_share)?;
(proof, public_input)
}
};
Expand Down
6 changes: 3 additions & 3 deletions collaborative-plonk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ where
/// Execute the PLONK prover using the internal MPC driver.
pub fn prove(
self,
zkey: ZKey<P>,
zkey: &ZKey<P>,
witness: SharedWitness<T, P>,
) -> PlonkProofResult<PlonkProof<P>> {
let state = Round1::init_round(self.driver, zkey, witness)?;
Expand Down Expand Up @@ -218,7 +218,7 @@ pub mod tests {
.unwrap();

let plonk = Plonk::<Bn254>::new(driver);
let proof = plonk.prove(zkey, witness).unwrap();
let proof = plonk.prove(&zkey, witness).unwrap();
let result = Plonk::<Bn254>::verify(&vk, &proof, &public_input.values).unwrap();
assert!(result);
Ok(())
Expand Down Expand Up @@ -250,7 +250,7 @@ pub mod tests {
.unwrap();

let plonk = Plonk::<Bn254>::new(driver);
let proof = plonk.prove(zkey, witness).unwrap();
let proof = plonk.prove(&zkey, witness).unwrap();

let mut proof_bytes = vec![];
serde_json::to_writer(&mut proof_bytes, &proof).unwrap();
Expand Down
24 changes: 12 additions & 12 deletions collaborative-plonk/src/round1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
};

// Round 1 of https://eprint.iacr.org/2019/953.pdf (page 28)
pub(super) struct Round1<T, P: Pairing>
pub(super) struct Round1<'a, T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand All @@ -26,22 +26,22 @@ where
pub(super) driver: T,
pub(super) domains: Domains<P::ScalarField>,
pub(super) challenges: Round1Challenges<T, P>,
pub(super) data: PlonkDataRound1<T, P>,
pub(super) data: PlonkDataRound1<'a, T, P>,
}

pub(super) struct PlonkDataRound1<T, P: Pairing>
pub(super) struct PlonkDataRound1<'a, T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
witness: PlonkWitness<T, P>,
zkey: ZKey<P>,
zkey: &'a ZKey<P>,
}

impl<T, P: Pairing> From<PlonkDataRound1<T, P>> for PlonkData<T, P>
impl<'a, T, P: Pairing> From<PlonkDataRound1<'a, T, P>> for PlonkData<'a, T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
fn from(mut data: PlonkDataRound1<T, P>) -> Self {
fn from(mut data: PlonkDataRound1<'a, T, P>) -> Self {
//when we are done, we remove the leading zero of the public inputs
data.witness.public_inputs = data.witness.public_inputs[1..].to_vec();
Self {
Expand Down Expand Up @@ -98,7 +98,7 @@ where
}

// Round 1 of https://eprint.iacr.org/2019/953.pdf (page 28)
impl<T, P: Pairing> Round1<T, P>
impl<'a, T, P: Pairing> Round1<'a, T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand Down Expand Up @@ -223,10 +223,10 @@ where

pub(super) fn init_round(
mut driver: T,
zkey: ZKey<P>,
zkey: &'a ZKey<P>,
private_witness: SharedWitness<T, P>,
) -> PlonkProofResult<Self> {
let plonk_witness = Self::calculate_additions(&mut driver, private_witness, &zkey)?;
let plonk_witness = Self::calculate_additions(&mut driver, private_witness, zkey)?;

Ok(Self {
challenges: Round1Challenges::random(&mut driver)?,
Expand All @@ -240,7 +240,7 @@ where
}

// Round 1 of https://eprint.iacr.org/2019/953.pdf (page 28)
pub(super) fn round1(self) -> PlonkProofResult<Round2<T, P>> {
pub(super) fn round1(self) -> PlonkProofResult<Round2<'a, T, P>> {
let Self {
mut driver,
domains,
Expand Down Expand Up @@ -329,7 +329,7 @@ pub mod tests {
witness: vec![witness.values[2], witness.values[3]],
};
let challenges = Round1Challenges::deterministic(&mut driver);
let mut round1 = Round1::init_round(driver, zkey, witness).unwrap();
let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap();
round1.challenges = challenges;
let round2 = round1.round1().unwrap();
assert_eq!(
Expand Down Expand Up @@ -372,7 +372,7 @@ pub mod tests {
};

let challenges = Round1Challenges::deterministic(&mut driver);
let mut round1 = Round1::init_round(driver, zkey, witness).unwrap();
let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap();
round1.challenges = challenges;
let round2 = round1.round1().unwrap();
assert_eq!(
Expand Down
10 changes: 5 additions & 5 deletions collaborative-plonk/src/round2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ macro_rules! array_prod_mul {
}

// Round 2 of https://eprint.iacr.org/2019/953.pdf (page 28)
pub(super) struct Round2<T, P: Pairing>
pub(super) struct Round2<'a, T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand All @@ -55,7 +55,7 @@ where
pub(super) challenges: Round1Challenges<T, P>,
pub(super) proof: Round1Proof<P>,
pub(super) polys: Round1Polys<T, P>,
pub(super) data: PlonkData<T, P>,
pub(super) data: PlonkData<'a, T, P>,
}

pub(super) struct Round2Challenges<T, P: Pairing>
Expand Down Expand Up @@ -129,7 +129,7 @@ where
}

// Round 2 of https://eprint.iacr.org/2019/953.pdf (page 28)
impl<T, P: Pairing> Round2<T, P>
impl<'a, T, P: Pairing> Round2<'a, T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand Down Expand Up @@ -236,7 +236,7 @@ where
}

// Round 2 of https://eprint.iacr.org/2019/953.pdf (page 28)
pub(super) fn round2(self) -> PlonkProofResult<Round3<T, P>> {
pub(super) fn round2(self) -> PlonkProofResult<Round3<'a, T, P>> {
let Self {
mut driver,
data,
Expand Down Expand Up @@ -331,7 +331,7 @@ pub mod tests {
};

let challenges = Round1Challenges::deterministic(&mut driver);
let mut round1 = Round1::init_round(driver, zkey, witness).unwrap();
let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap();
round1.challenges = challenges;
let round2 = round1.round1().unwrap();
let round3 = round2.round2().unwrap();
Expand Down
12 changes: 6 additions & 6 deletions collaborative-plonk/src/round3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ macro_rules! mul4vec_post {
}

// Round 3 of https://eprint.iacr.org/2019/953.pdf (page 29)
pub(super) struct Round3<T, P: Pairing>
pub(super) struct Round3<'a, T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand All @@ -85,7 +85,7 @@ where
pub(super) challenges: Round2Challenges<T, P>,
pub(super) proof: Round2Proof<P>,
pub(super) polys: Round2Polys<T, P>,
pub(super) data: PlonkData<T, P>,
pub(super) data: PlonkData<'a, T, P>,
}

pub(super) struct Round3Proof<P: Pairing> {
Expand Down Expand Up @@ -181,7 +181,7 @@ where
}

// Round 3 of https://eprint.iacr.org/2019/953.pdf (page 29)
impl<T, P: Pairing> Round3<T, P>
impl<'a, T, P: Pairing> Round3<'a, T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand Down Expand Up @@ -459,7 +459,7 @@ where
}

// Round 3 of https://eprint.iacr.org/2019/953.pdf (page 29)
pub(super) fn round3(self) -> PlonkProofResult<Round4<T, P>> {
pub(super) fn round3(self) -> PlonkProofResult<Round4<'a, T, P>> {
let Self {
mut driver,
domains,
Expand All @@ -477,7 +477,7 @@ where
let alpha = transcript.get_challenge();
let alpha2 = alpha.square();
let challenges = Round3Challenges::new(challenges, alpha, alpha2);
let [t1, t2, t3] = Self::compute_t(&mut driver, &domains, &challenges, &data.zkey, &polys)?;
let [t1, t2, t3] = Self::compute_t(&mut driver, &domains, &challenges, data.zkey, &polys)?;

// Compute [T1]_1, [T2]_1, [T3]_1
let commit_t1 = MSMProvider::<P::G1>::msm_public_points(
Expand Down Expand Up @@ -551,7 +551,7 @@ pub mod tests {
};

let challenges = Round1Challenges::deterministic(&mut driver);
let mut round1 = Round1::init_round(driver, zkey, witness).unwrap();
let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap();
round1.challenges = challenges;
let round2 = round1.round1().unwrap();
let round3 = round2.round2().unwrap();
Expand Down
10 changes: 5 additions & 5 deletions collaborative-plonk/src/round4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use mpc_core::traits::{
};

// Round 4 of https://eprint.iacr.org/2019/953.pdf (page 29)
pub(super) struct Round4<T, P: Pairing>
pub(super) struct Round4<'a, T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand All @@ -24,7 +24,7 @@ where
pub(super) challenges: Round3Challenges<T, P>,
pub(super) proof: Round3Proof<P>,
pub(super) polys: FinalPolys<T, P>,
pub(super) data: PlonkData<T, P>,
pub(super) data: PlonkData<'a, T, P>,
}
pub(super) struct Round4Challenges<P: Pairing> {
pub(super) beta: P::ScalarField,
Expand Down Expand Up @@ -91,7 +91,7 @@ impl<P: Pairing> Round4Proof<P> {
}

// Round 4 of https://eprint.iacr.org/2019/953.pdf (page 29)
impl<T, P: Pairing> Round4<T, P>
impl<'a, T, P: Pairing> Round4<'a, T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand All @@ -101,7 +101,7 @@ where
P::ScalarField: FFTPostProcessing,
{
// Round 4 of https://eprint.iacr.org/2019/953.pdf (page 29)
pub(super) fn round4(self) -> PlonkProofResult<Round5<T, P>> {
pub(super) fn round4(self) -> PlonkProofResult<Round5<'a, T, P>> {
let Self {
mut driver,
domains,
Expand Down Expand Up @@ -177,7 +177,7 @@ pub mod tests {
};

let challenges = Round1Challenges::deterministic(&mut driver);
let mut round1 = Round1::init_round(driver, zkey, witness).unwrap();
let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap();
round1.challenges = challenges;
let round2 = round1.round1().unwrap();
let round3 = round2.round2().unwrap();
Expand Down
8 changes: 4 additions & 4 deletions collaborative-plonk/src/round5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use num_traits::One;
use num_traits::Zero;

// Round 5 of https://eprint.iacr.org/2019/953.pdf (page 30)
pub(super) struct Round5<T, P: Pairing>
pub(super) struct Round5<'a, T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand All @@ -33,7 +33,7 @@ where
pub(super) challenges: Round4Challenges<P>,
pub(super) proof: Round4Proof<P>,
pub(super) polys: FinalPolys<T, P>,
pub(super) data: PlonkData<T, P>,
pub(super) data: PlonkData<'a, T, P>,
}
pub(super) struct Round5Challenges<P: Pairing> {
beta: P::ScalarField,
Expand Down Expand Up @@ -84,7 +84,7 @@ where
}

// Round 5 of https://eprint.iacr.org/2019/953.pdf (page 30)
impl<T, P: Pairing> Round5<T, P>
impl<'a, T, P: Pairing> Round5<'a, T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>
+ PairingEcMpcProtocol<P>
Expand Down Expand Up @@ -390,7 +390,7 @@ pub mod tests {
};

let challenges = Round1Challenges::deterministic(&mut driver);
let mut round1 = Round1::init_round(driver, zkey, witness).unwrap();
let mut round1 = Round1::init_round(driver, &zkey, witness).unwrap();
round1.challenges = challenges;
let round2 = round1.round1().unwrap();
let round3 = round2.round2().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions collaborative-plonk/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ where
pub(super) addition_witness: Vec<FieldShare<T, P>>,
}

pub(super) struct PlonkData<T, P: Pairing>
pub(super) struct PlonkData<'a, T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
pub(super) witness: PlonkWitness<T, P>,
pub(super) zkey: ZKey<P>,
pub(super) zkey: &'a ZKey<P>,
}

impl<F: PrimeField> Domains<F> {
Expand Down
6 changes: 3 additions & 3 deletions tests/benches/poseidon_hash2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ where

let plain = PlainDriver::default();
let prover = CollaborativePlonk::new(plain);
prover.prove(pk, witness).unwrap()
prover.prove(&pk, witness).unwrap()
}

fn rep3_witness_extension<P>(
Expand Down Expand Up @@ -336,7 +336,7 @@ fn plonk_rep3_proof<P>(
let party = tokio::task::spawn_blocking(move || {
let rep3 = Rep3Protocol::new(net).unwrap();
let prover = CollaborativePlonk::new(rep3);
prover.prove(pk, witness).unwrap()
prover.prove(&pk, witness).unwrap()
});
parties.push(party);
}
Expand Down Expand Up @@ -387,7 +387,7 @@ fn plonk_shamir_proof<P>(
let party = tokio::task::spawn_blocking(move || {
let shamir = ShamirProtocol::new(degree, net).unwrap();
let prover = CollaborativePlonk::new(shamir);
prover.prove(pk, witness).unwrap()
prover.prove(&pk, witness).unwrap()
});
parties.push(party);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/e2e_tests/plonk/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn e2e_proof_poseidon_bn254() {
CollaborativePlonk::<Rep3Protocol<ark_bn254::Fr, PartyTestNetwork>, Bn254>::new(
rep3,
);
prover.prove(pk, x).unwrap()
prover.prove(&pk, x).unwrap()
}));
}
let result3 = threads.pop().unwrap().join().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/e2e_tests/plonk/shamir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn e2e_poseidon_bn254_inner(num_parties: usize, threshold: usize) {
CollaborativePlonk::<ShamirProtocol<ark_bn254::Fr, PartyTestNetwork>, Bn254>::new(
shamir,
);
prover.prove(pk, x).unwrap()
prover.prove(&pk, x).unwrap()
}));
}
let mut results = Vec::with_capacity(num_parties);
Expand Down

0 comments on commit 7aa5c4e

Please sign in to comment.