Skip to content

Commit

Permalink
ShamirGroth16Driver
Browse files Browse the repository at this point in the history
  • Loading branch information
romanmarkusholler committed Sep 12, 2024
1 parent 7113108 commit c11db99
Show file tree
Hide file tree
Showing 11 changed files with 336 additions and 35 deletions.
2 changes: 1 addition & 1 deletion co-circom/co-groth16/src/mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub trait CircomGroth16Prover<P: Pairing>: Send {
type PointShareG2: Send;
type PartyID: Send + Sync + Copy;

fn rand(&mut self) -> Self::ArithmeticShare;
async fn rand(&mut self) -> IoResult<Self::ArithmeticShare>;

fn get_party_id(&self) -> Self::PartyID;

Expand Down
6 changes: 4 additions & 2 deletions co-circom/co-groth16/src/mpc/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use rand::thread_rng;

use super::CircomGroth16Prover;

type IoResult<T> = std::io::Result<T>;

pub struct PlainGroth16Driver;

impl<P: Pairing> CircomGroth16Prover<P> for PlainGroth16Driver {
Expand All @@ -17,9 +19,9 @@ impl<P: Pairing> CircomGroth16Prover<P> for PlainGroth16Driver {

type PartyID = usize;

fn rand(&mut self) -> Self::ArithmeticShare {
async fn rand(&mut self) -> IoResult<Self::ArithmeticShare> {
let mut rng = thread_rng();
Self::ArithmeticShare::rand(&mut rng)
Ok(Self::ArithmeticShare::rand(&mut rng))
}

fn get_party_id(&self) -> Self::PartyID {
Expand Down
40 changes: 20 additions & 20 deletions co-circom/co-groth16/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use ark_ec::pairing::Pairing;
use itertools::izip;
use mpc_core::protocols::rep3::{
self, arithmetic,
arithmetic,
id::PartyID,
network::{IoContext, Rep3Network},
Rep3PointShare, Rep3PrimeFieldShare,
pointshare, Rep3PointShare, Rep3PrimeFieldShare,
};

use super::{CircomGroth16Prover, IoResult};
Expand All @@ -26,8 +26,8 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>

type PartyID = PartyID;

fn rand(&mut self) -> Self::ArithmeticShare {
Self::ArithmeticShare::rand(&mut self.io_context)
async fn rand(&mut self) -> IoResult<Self::ArithmeticShare> {
Ok(Self::ArithmeticShare::rand(&mut self.io_context))
}

fn get_party_id(&self) -> Self::PartyID {
Expand All @@ -44,7 +44,7 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>
public_inputs: &[P::ScalarField],
private_witness: &[Self::ArithmeticShare],
) -> Self::ArithmeticShare {
let mut acc = Rep3PrimeFieldShare::default();
let mut acc = Self::ArithmeticShare::default();
for (coeff, index) in lhs {
if index < &public_inputs.len() {
let val = public_inputs[*index];
Expand All @@ -70,7 +70,7 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>

fn sub_assign_vec(a: &mut [Self::ArithmeticShare], b: &[Self::ArithmeticShare]) {
for (a, b) in izip!(a, b) {
rep3::arithmetic::sub_assign(a, *b);
arithmetic::sub_assign(a, *b);
}
}

Expand All @@ -79,15 +79,15 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>
a: Self::ArithmeticShare,
b: Self::ArithmeticShare,
) -> IoResult<Self::ArithmeticShare> {
rep3::arithmetic::mul(a, b, &mut self.io_context).await
arithmetic::mul(a, b, &mut self.io_context).await
}

async fn mul_vec(
&mut self,
lhs: &[Self::ArithmeticShare],
rhs: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>> {
rep3::arithmetic::mul_vec(lhs, rhs, &mut self.io_context).await
arithmetic::mul_vec(lhs, rhs, &mut self.io_context).await
}

fn fft_in_place<D: ark_poly::EvaluationDomain<P::ScalarField>>(
Expand Down Expand Up @@ -118,7 +118,7 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>
) {
let mut pow = c;
for share in coeffs.iter_mut() {
rep3::arithmetic::mul_assign_public(share, pow);
arithmetic::mul_assign_public(share, pow);
pow *= g;
}
}
Expand All @@ -127,55 +127,55 @@ impl<P: Pairing, N: Rep3Network> CircomGroth16Prover<P> for Rep3Groth16Driver<N>
points: &[P::G1Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG1 {
rep3::pointshare::msm_public_points(points, scalars)
pointshare::msm_public_points(points, scalars)
}

fn msm_public_points_g2(
points: &[P::G2Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG2 {
rep3::pointshare::msm_public_points(points, scalars)
pointshare::msm_public_points(points, scalars)
}

fn scalar_mul_public_point_g1(a: &P::G1, b: Self::ArithmeticShare) -> Self::PointShareG1 {
rep3::pointshare::scalar_mul_public_point(a, b)
pointshare::scalar_mul_public_point(a, b)
}

/// Add a shared point B in place to the shared point A: \[A\] += \[B\]
fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
rep3::pointshare::add_assign(a, b)
pointshare::add_assign(a, b)
}

fn add_assign_points_public_g1(id: Self::PartyID, a: &mut Self::PointShareG1, b: &P::G1) {
rep3::pointshare::add_assign_public(a, b, id)
pointshare::add_assign_public(a, b, id)
}

async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<P::G1> {
rep3::pointshare::open_point(a, &mut self.io_context).await
pointshare::open_point(a, &mut self.io_context).await
}

async fn scalar_mul_g1(
&mut self,
a: &Self::PointShareG1,
b: Self::ArithmeticShare,
) -> IoResult<Self::PointShareG1> {
rep3::pointshare::scalar_mul(a, b, &mut self.io_context).await
pointshare::scalar_mul(a, b, &mut self.io_context).await
}

fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
rep3::pointshare::sub_assign(a, b);
pointshare::sub_assign(a, b);
}

fn scalar_mul_public_point_g2(a: &P::G2, b: Self::ArithmeticShare) -> Self::PointShareG2 {
rep3::pointshare::scalar_mul_public_point(a, b)
pointshare::scalar_mul_public_point(a, b)
}

fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) {
rep3::pointshare::add_assign(a, b)
pointshare::add_assign(a, b)
}

fn add_assign_points_public_g2(id: Self::PartyID, a: &mut Self::PointShareG2, b: &P::G2) {
rep3::pointshare::add_assign_public(a, b, id)
pointshare::add_assign_public(a, b, id)
}

async fn open_two_points(
Expand Down
196 changes: 195 additions & 1 deletion co-circom/co-groth16/src/mpc/shamir.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,200 @@
use super::{CircomGroth16Prover, IoResult};
use ark_ec::pairing::Pairing;
use ark_ff::PrimeField;
use mpc_core::protocols::shamir::{network::ShamirNetwork, ShamirProtocol};
use itertools::izip;
use mpc_core::protocols::shamir::{
arithmetic, network::ShamirNetwork, pointshare, ShamirPointShare, ShamirPrimeFieldShare,
ShamirProtocol,
};

pub struct ShamirGroth16Driver<F: PrimeField, N: ShamirNetwork> {
protocol: ShamirProtocol<F, N>,
}

impl<F: PrimeField, N: ShamirNetwork> ShamirGroth16Driver<F, N> {
pub fn new(protocol: ShamirProtocol<F, N>) -> Self {
Self { protocol }
}
}

impl<P: Pairing, N: ShamirNetwork> CircomGroth16Prover<P>
for ShamirGroth16Driver<P::ScalarField, N>
{
type ArithmeticShare = ShamirPrimeFieldShare<P::ScalarField>;
type PointShareG1 = ShamirPointShare<P::G1>;
type PointShareG2 = ShamirPointShare<P::G2>;

type PartyID = usize;

async fn rand(&mut self) -> IoResult<Self::ArithmeticShare> {
self.protocol.rand().await
}

fn get_party_id(&self) -> Self::PartyID {
self.protocol.network.get_id()
}

fn fork(&mut self) -> Self {
todo!()
}

fn evaluate_constraint(
_party_id: Self::PartyID,
lhs: &[(<P as Pairing>::ScalarField, usize)],
public_inputs: &[<P as Pairing>::ScalarField],
private_witness: &[Self::ArithmeticShare],
) -> Self::ArithmeticShare {
let mut acc = Self::ArithmeticShare::default();
for (coeff, index) in lhs {
if index < &public_inputs.len() {
let val = public_inputs[*index];
let mul_result = val * coeff;
arithmetic::add_assign_public(&mut acc, mul_result);
} else {
let current_witness = private_witness[*index - public_inputs.len()];
arithmetic::add_assign(&mut acc, arithmetic::mul_public(current_witness, *coeff));
}
}
acc
}

fn promote_to_trivial_shares(
id: Self::PartyID,
public_values: &[<P as Pairing>::ScalarField],
) -> Vec<Self::ArithmeticShare> {
todo!()
}

fn sub_assign_vec(a: &mut [Self::ArithmeticShare], b: &[Self::ArithmeticShare]) {
for (a, b) in izip!(a, b) {
arithmetic::sub_assign(a, *b);
}
}

async fn mul(
&mut self,
a: Self::ArithmeticShare,
b: Self::ArithmeticShare,
) -> IoResult<Self::ArithmeticShare> {
arithmetic::mul(a, b, &mut self.protocol).await
}

async fn mul_vec(
&mut self,
a: &[Self::ArithmeticShare],
b: &[Self::ArithmeticShare],
) -> IoResult<Vec<Self::ArithmeticShare>> {
arithmetic::mul_vec(a, b, &mut self.protocol).await
}

fn fft_in_place<D: ark_poly::EvaluationDomain<<P as Pairing>::ScalarField>>(
data: &mut Vec<Self::ArithmeticShare>,
domain: &D,
) {
domain.fft_in_place(data)
}

fn ifft_in_place<D: ark_poly::EvaluationDomain<<P as Pairing>::ScalarField>>(
data: &mut Vec<Self::ArithmeticShare>,
domain: &D,
) {
domain.ifft_in_place(data)
}

fn ifft<D: ark_poly::EvaluationDomain<<P as Pairing>::ScalarField>>(
data: &[Self::ArithmeticShare],
domain: &D,
) -> Vec<Self::ArithmeticShare> {
domain.ifft(&data)
}

fn distribute_powers_and_mul_by_const(
coeffs: &mut [Self::ArithmeticShare],
g: <P as Pairing>::ScalarField,
c: <P as Pairing>::ScalarField,
) {
let mut pow = c;
for share in coeffs.iter_mut() {
arithmetic::mul_assign_public(share, pow);
pow *= g;
}
}

fn msm_public_points_g1(
points: &[<P as Pairing>::G1Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG1 {
pointshare::msm_public_points(points, scalars)
}

fn msm_public_points_g2(
points: &[<P as Pairing>::G2Affine],
scalars: &[Self::ArithmeticShare],
) -> Self::PointShareG2 {
pointshare::msm_public_points(points, scalars)
}

fn scalar_mul_public_point_g1(
a: &<P as Pairing>::G1,
b: Self::ArithmeticShare,
) -> Self::PointShareG1 {
pointshare::scalar_mul_public_point(b, a)
}

fn add_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
pointshare::add_assign(a, b)
}

fn add_assign_points_public_g1(
_id: Self::PartyID,
a: &mut Self::PointShareG1,
b: &<P as Pairing>::G1,
) {
pointshare::add_assign_public(a, b)
}

async fn open_point_g1(&mut self, a: &Self::PointShareG1) -> IoResult<<P as Pairing>::G1> {
pointshare::open_point(a, &mut self.protocol).await
}

async fn scalar_mul_g1(
&mut self,
a: &Self::PointShareG1,
b: Self::ArithmeticShare,
) -> IoResult<Self::PointShareG1> {
pointshare::scalar_mul(a, b, &mut self.protocol).await
}

fn sub_assign_points_g1(a: &mut Self::PointShareG1, b: &Self::PointShareG1) {
pointshare::sub_assign(a, b);
}

fn scalar_mul_public_point_g2(
a: &<P as Pairing>::G2,
b: Self::ArithmeticShare,
) -> Self::PointShareG2 {
pointshare::scalar_mul_public_point(b, a)
}

fn add_assign_points_g2(a: &mut Self::PointShareG2, b: &Self::PointShareG2) {
pointshare::add_assign(a, b)
}

fn add_assign_points_public_g2(
_id: Self::PartyID,
a: &mut Self::PointShareG2,
b: &<P as Pairing>::G2,
) {
pointshare::add_assign_public(a, b)
}

async fn open_two_points(
&mut self,
a: Self::PointShareG1,
b: Self::PointShareG2,
) -> std::io::Result<(<P as Pairing>::G1, <P as Pairing>::G2)> {
let a_res = pointshare::open_point(&a, &mut self.protocol).await?;
let b_res = pointshare::open_point(&b, &mut self.protocol).await?;
Ok((a_res, b_res))
}
}
4 changes: 2 additions & 2 deletions mpc-core/src/protocols/shamir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ pub struct ShamirProtocol<F: PrimeField, N: ShamirNetwork> {
pub(crate) open_lagrange_2t: Vec<F>,
mul_lagrange_2t: Vec<F>,
rng_buffer: ShamirRng<F>,
network: N,
pub network: N,
field: PhantomData<F>,
}

Expand Down Expand Up @@ -223,7 +223,7 @@ impl<F: PrimeField, N: ShamirNetwork> ShamirProtocol<F, N> {
.await
}

pub(crate) async fn rand(&mut self) -> IoResult<ShamirPrimeFieldShare<F>> {
pub async fn rand(&mut self) -> IoResult<ShamirPrimeFieldShare<F>> {
let (r, _) = self.rng_buffer.get_pair(&mut self.network).await?;
Ok(ShamirPrimeFieldShare::new(r))
}
Expand Down
Loading

0 comments on commit c11db99

Please sign in to comment.