Skip to content

Commit

Permalink
inner-product-proof: Compute round reduction/folding in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Aug 19, 2023
1 parent 08f8ef4 commit 58c0f28
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 69 deletions.
15 changes: 11 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ description = "A pure-Rust implementation of collaboratively proved Bulletproofs
edition = "2021"

[dependencies]
ark-ec = "0.4"
ark-ff = "0.4"
ark-serialize = "0.4"
futures = "0.3"
Expand All @@ -24,6 +23,7 @@ sha3 = { version = "0.8", default-features = false }
digest = { version = "0.8", default-features = false }
rand_core = { version = "0.5", default-features = false, features = ["alloc"] }
rand = { version = "0.8", default-features = false, optional = true }
rayon = "1"
byteorder = { version = "1", default-features = false }
num-bigint = "0.4"
itertools = "0.10"
Expand All @@ -32,6 +32,7 @@ serde_derive = { version = "1", default-features = false }
thiserror = { version = "1", optional = true }
tokio = { version = "1.12", features = ["macros", "rt-multi-thread"] }
merlin = { git = "https://github.com/renegade-fi/merlin" }
unzip-n = "0.1"

[dev-dependencies]
async-std = "1.12"
Expand All @@ -56,19 +57,25 @@ integration_test = []

[[test]]
name = "r1cs"
required_features = ["multiprover"]

[[test]]
name = "integration"
path = "integration/main.rs"
harness = false
required_features = ["integration_test", "multiprover"]

[[bench]]
name = "generators"
harness = false

[[bench]]
name = "r1cs"
name = "shuffle"
harness = false
required-features = ["multiprover"]

[[bench]]
name = "r1cs"
harness = false

[[bench]]
name = "inner_product"
harness = false
32 changes: 18 additions & 14 deletions benches/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use std::time::{Duration, Instant};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use merlin::HashChainTranscript;
use mpc_bulletproof::{
r1cs::{ConstraintSystem, Prover, R1CSProof, Verifier},
r1cs::{Prover, R1CSProof, RandomizableConstraintSystem, Verifier},
BulletproofGens, PedersenGens,
};
use mpc_stark::{algebra::scalar::Scalar, random_point};
use mpc_stark::algebra::scalar::Scalar;
use rand::thread_rng;

/// The max number of constraints to benchmark
Expand All @@ -18,6 +18,20 @@ const MAX_CONSTRAINTS_LN: usize = 10; // 2^10 = 1024
// | Helpers |
// -----------

struct DummyCircuit;
impl DummyCircuit {
/// Apply dummy constraints to a given proof system
pub fn apply_constraints<CS: RandomizableConstraintSystem>(n_constraints: usize, cs: &mut CS) {
let mut rng = thread_rng();
let val = Scalar::random(&mut rng);
let mut var = cs.commit_public(val);

for _ in 0..n_constraints {
(_, _, var) = cs.multiply(var.into(), var.into());
}
}
}

/// Benchmark a prover with a given number of constraints
fn bench_prover_with_size(n_constraints: usize, c: &mut Criterion) {
assert!(n_constraints.is_power_of_two());
Expand Down Expand Up @@ -62,11 +76,7 @@ fn bench_verifier_with_size(n_constraints: usize, c: &mut Criterion) {
let bp_gens = BulletproofGens::new(n_constraints, 1 /* party_capacity */);

// Apply the constraints
let mut var = verifier.commit(random_point());
for _ in 0..n_constraints {
let (_, _, new_var) = verifier.multiply(var.into(), var.into());
var = new_var;
}
DummyCircuit::apply_constraints(n_constraints, &mut verifier);

// Verify the proof
let start_time = Instant::now();
Expand All @@ -89,13 +99,7 @@ fn prove_sized_statement_with_timer(n_constraints: usize) -> (R1CSProof, Duratio
let bp_gens = BulletproofGens::new(n_constraints, 1 /* party_capacity */);

// Allocate `n_constraints` constraints
let mut rng = thread_rng();
let val = Scalar::random(&mut rng);
let (_, mut var) = prover.commit(val, Scalar::random(&mut rng));

for _ in 0..n_constraints {
(_, _, var) = prover.multiply(var.into(), var.into());
}
DummyCircuit::apply_constraints(n_constraints, &mut prover);

// Only time proof generation
let start_time = Instant::now();
Expand Down
156 changes: 105 additions & 51 deletions src/inner_product_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,26 @@ extern crate alloc;

use alloc::borrow::Borrow;
use alloc::vec::Vec;
use itertools::Itertools;
use mpc_stark::algebra::scalar::{Scalar, SCALAR_BYTES};
use mpc_stark::algebra::stark_curve::{StarkPoint, STARK_POINT_BYTES};
use rayon::prelude::*;
use unzip_n::unzip_n;

use core::iter;
use merlin::HashChainTranscript as Transcript;

use crate::errors::ProofError;
use crate::transcript::TranscriptProtocol;

unzip_n!(4);

/// The size of the inner product proof above which we execute folding operations
/// in parallel
///
/// Copied from `mpc-stark`
const PARALLELISM_THRESHOLD: usize = 10;

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InnerProductProof {
pub L_vec: Vec<StarkPoint>,
Expand Down Expand Up @@ -45,21 +56,13 @@ impl InnerProductProof {
mut a_vec: Vec<Scalar>,
mut b_vec: Vec<Scalar>,
) -> InnerProductProof {
// Create slices G, H, a, b backed by their respective
// vectors. This lets us reslice as we compress the lengths
// of the vectors in the main loop below.
let mut G = &mut G_vec[..];
let mut H = &mut H_vec[..];
let mut a = &mut a_vec[..];
let mut b = &mut b_vec[..];

let mut n = G.len();
let mut n = G_vec.len();

// All of the input vectors must have the same length.
assert_eq!(G.len(), n);
assert_eq!(H.len(), n);
assert_eq!(a.len(), n);
assert_eq!(b.len(), n);
assert_eq!(G_vec.len(), n);
assert_eq!(H_vec.len(), n);
assert_eq!(a_vec.len(), n);
assert_eq!(b_vec.len(), n);
assert_eq!(G_factors.len(), n);
assert_eq!(H_factors.len(), n);

Expand All @@ -76,10 +79,10 @@ impl InnerProductProof {
// into multiscalar muls, for performance.
if n != 1 {
n /= 2;
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);
let (a_L, a_R) = a_vec.split_at_mut(n);
let (b_L, b_R) = b_vec.split_at_mut(n);
let (G_L, G_R) = G_vec.split_at_mut(n);
let (H_L, H_R) = H_vec.split_at_mut(n);

let c_L = inner_product(a_L, b_R);
let c_R = inner_product(a_R, b_L);
Expand Down Expand Up @@ -119,31 +122,36 @@ impl InnerProductProof {
let u = transcript.challenge_scalar(b"u");
let u_inv = u.inverse();

for i in 0..n {
a_L[i] = a_L[i] * u + u_inv * a_R[i];
b_L[i] = b_L[i] * u_inv + u * b_R[i];
G_L[i] = StarkPoint::msm(
&[u_inv * G_factors[i], u * G_factors[n + i]],
&[G_L[i], G_R[i]],
);
H_L[i] = StarkPoint::msm(
&[u * H_factors[i], u_inv * H_factors[n + i]],
&[H_L[i], H_R[i]],
);
}

a = a_L;
b = b_L;
G = G_L;
H = H_L;
let G = G_factors
.iter()
.zip(G_vec.into_iter())
.map(|(g, G_i)| g * G_i)
.collect_vec();
let H = H_factors
.iter()
.zip(H_vec.into_iter())
.map(|(h, H_i)| h * H_i)
.collect_vec();
(a_vec, b_vec, G_vec, H_vec) = Self::fold_witness(
u,
u_inv,
a_L,
a_R,
b_L,
b_R,
&G[..n],
&G[n..],
&H[..n],
&H[n..],
);
}

while n != 1 {
n /= 2;
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);
let (a_L, a_R) = a_vec.split_at_mut(n);
let (b_L, b_R) = b_vec.split_at_mut(n);
let (G_L, G_R) = G_vec.split_at_mut(n);
let (H_L, H_R) = H_vec.split_at_mut(n);

let c_L = inner_product(a_L, b_R);
let c_R = inner_product(a_R, b_L);
Expand Down Expand Up @@ -172,27 +180,73 @@ impl InnerProductProof {
let u = transcript.challenge_scalar(b"u");
let u_inv = u.inverse();

for i in 0..n {
a_L[i] = a_L[i] * u + u_inv * a_R[i];
b_L[i] = b_L[i] * u_inv + u * b_R[i];
G_L[i] = StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]);
H_L[i] = StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]);
}

a = a_L;
b = b_L;
G = G_L;
H = H_L;
(a_vec, b_vec, G_vec, H_vec) =
Self::fold_witness(u, u_inv, a_L, a_R, b_L, b_R, G_L, G_R, H_L, H_R);
}

InnerProductProof {
L_vec,
R_vec,
a: a[0],
b: b[0],
a: a_vec[0],
b: b_vec[0],
}
}

/// Reduces the inner product proof witness in half by folding the elements via
/// a linear combination with multiplicative inverses
///
/// See equation (4) of the Bulletproof paper:
/// https://eprint.iacr.org/2017/1066.pdf
///
/// Returns the new values of a, b, G, H
fn fold_witness(
u: Scalar,
u_inv: Scalar,
a_L: &[Scalar],
a_R: &[Scalar],
b_L: &[Scalar],
b_R: &[Scalar],
G_L: &[StarkPoint],
G_R: &[StarkPoint],
H_L: &[StarkPoint],
H_R: &[StarkPoint],
) -> (Vec<Scalar>, Vec<Scalar>, Vec<StarkPoint>, Vec<StarkPoint>) {
let n = a_L.len();

// For small proofs, compute serially to avoid parallelism overhead
if n < PARALLELISM_THRESHOLD {
let mut a_res = Vec::with_capacity(n / 2);
let mut b_res = Vec::with_capacity(n / 2);
let mut G_res = Vec::with_capacity(n / 2);
let mut H_res = Vec::with_capacity(n / 2);

for i in 0..n {
a_res.push(a_L[i] * u + u_inv * a_R[i]);
b_res.push(b_L[i] * u_inv + u * b_R[i]);
G_res.push(StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]));
H_res.push(StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]));
}

return (a_res, b_res, G_res, H_res);
}

// Parallel implementation
let mut res = Vec::with_capacity(n);
(0..n)
.into_par_iter()
.map(|i| {
(
a_L[i] * u + u_inv * a_R[i],
b_L[i] * u_inv + u * b_R[i],
StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]),
StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]),
)
})
.collect_into_vec(&mut res);

res.into_iter().unzip_n_vec()
}

/// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and \\([s\_{i}]\\) for combined multiscalar multiplication
/// in a parent protocol. See [inner product protocol notes](index.html#verification-equation) for details.
/// The verifier must provide the input length \\(n\\) explicitly to avoid unbounded allocation within the inner product proof.
Expand Down

0 comments on commit 58c0f28

Please sign in to comment.