Skip to content

Commit

Permalink
inner-product-proof: Compute round reduction/folding in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
joeykraut committed Aug 19, 2023
1 parent 5ffd2a2 commit 6884f1e
Showing 1 changed file with 105 additions and 51 deletions.
156 changes: 105 additions & 51 deletions src/inner_product_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,26 @@ extern crate alloc;

use alloc::borrow::Borrow;
use alloc::vec::Vec;
use itertools::Itertools;
use mpc_stark::algebra::scalar::{Scalar, SCALAR_BYTES};
use mpc_stark::algebra::stark_curve::{StarkPoint, STARK_POINT_BYTES};
use rayon::prelude::*;

Check failure on line 11 in src/inner_product_proof.rs

View workflow job for this annotation

GitHub Actions / clippy

failed to resolve: use of undeclared crate or module `rayon`

error[E0433]: failed to resolve: use of undeclared crate or module `rayon` --> src/inner_product_proof.rs:11:5 | 11 | use rayon::prelude::*; | ^^^^^ use of undeclared crate or module `rayon`
use unzip_n::unzip_n;

Check failure on line 12 in src/inner_product_proof.rs

View workflow job for this annotation

GitHub Actions / clippy

unresolved import `unzip_n`

error[E0432]: unresolved import `unzip_n` --> src/inner_product_proof.rs:12:5 | 12 | use unzip_n::unzip_n; | ^^^^^^^ use of undeclared crate or module `unzip_n`

use core::iter;
use merlin::HashChainTranscript as Transcript;

use crate::errors::ProofError;
use crate::transcript::TranscriptProtocol;

unzip_n!(4);

/// The size of the inner product proof above which we execute folding operations
/// in parallel
///
/// Copied from `mpc-stark`
const PARALLELISM_THRESHOLD: usize = 10;

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InnerProductProof {
pub L_vec: Vec<StarkPoint>,
Expand Down Expand Up @@ -45,21 +56,13 @@ impl InnerProductProof {
mut a_vec: Vec<Scalar>,
mut b_vec: Vec<Scalar>,
) -> InnerProductProof {
// Create slices G, H, a, b backed by their respective
// vectors. This lets us reslice as we compress the lengths
// of the vectors in the main loop below.
let mut G = &mut G_vec[..];
let mut H = &mut H_vec[..];
let mut a = &mut a_vec[..];
let mut b = &mut b_vec[..];

let mut n = G.len();
let mut n = G_vec.len();

// All of the input vectors must have the same length.
assert_eq!(G.len(), n);
assert_eq!(H.len(), n);
assert_eq!(a.len(), n);
assert_eq!(b.len(), n);
assert_eq!(G_vec.len(), n);
assert_eq!(H_vec.len(), n);
assert_eq!(a_vec.len(), n);
assert_eq!(b_vec.len(), n);
assert_eq!(G_factors.len(), n);
assert_eq!(H_factors.len(), n);

Expand All @@ -76,10 +79,10 @@ impl InnerProductProof {
// into multiscalar muls, for performance.
if n != 1 {
n /= 2;
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);
let (a_L, a_R) = a_vec.split_at_mut(n);
let (b_L, b_R) = b_vec.split_at_mut(n);
let (G_L, G_R) = G_vec.split_at_mut(n);
let (H_L, H_R) = H_vec.split_at_mut(n);

let c_L = inner_product(a_L, b_R);
let c_R = inner_product(a_R, b_L);
Expand Down Expand Up @@ -119,31 +122,36 @@ impl InnerProductProof {
let u = transcript.challenge_scalar(b"u");
let u_inv = u.inverse();

for i in 0..n {
a_L[i] = a_L[i] * u + u_inv * a_R[i];
b_L[i] = b_L[i] * u_inv + u * b_R[i];
G_L[i] = StarkPoint::msm(
&[u_inv * G_factors[i], u * G_factors[n + i]],
&[G_L[i], G_R[i]],
);
H_L[i] = StarkPoint::msm(
&[u * H_factors[i], u_inv * H_factors[n + i]],
&[H_L[i], H_R[i]],
);
}

a = a_L;
b = b_L;
G = G_L;
H = H_L;
let G = G_factors
.iter()
.zip(G_vec.into_iter())
.map(|(g, G_i)| g * G_i)
.collect_vec();
let H = H_factors
.iter()
.zip(H_vec.into_iter())
.map(|(h, H_i)| h * H_i)
.collect_vec();
(a_vec, b_vec, G_vec, H_vec) = Self::fold_witness(
u,
u_inv,
a_L,
a_R,
b_L,
b_R,
&G[..n],
&G[n..],
&H[..n],
&H[n..],
);
}

while n != 1 {
n /= 2;
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);
let (a_L, a_R) = a_vec.split_at_mut(n);
let (b_L, b_R) = b_vec.split_at_mut(n);
let (G_L, G_R) = G_vec.split_at_mut(n);
let (H_L, H_R) = H_vec.split_at_mut(n);

let c_L = inner_product(a_L, b_R);
let c_R = inner_product(a_R, b_L);
Expand Down Expand Up @@ -172,27 +180,73 @@ impl InnerProductProof {
let u = transcript.challenge_scalar(b"u");
let u_inv = u.inverse();

for i in 0..n {
a_L[i] = a_L[i] * u + u_inv * a_R[i];
b_L[i] = b_L[i] * u_inv + u * b_R[i];
G_L[i] = StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]);
H_L[i] = StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]);
}

a = a_L;
b = b_L;
G = G_L;
H = H_L;
(a_vec, b_vec, G_vec, H_vec) =
Self::fold_witness(u, u_inv, a_L, a_R, b_L, b_R, G_L, G_R, H_L, H_R);
}

InnerProductProof {
L_vec,
R_vec,
a: a[0],
b: b[0],
a: a_vec[0],
b: b_vec[0],
}
}

/// Reduces the inner product proof witness in half by folding the elements via
/// a linear combination with multiplicative inverses
///
/// See equation (4) of the Bulletproof paper:
/// https://eprint.iacr.org/2017/1066.pdf
///
/// Returns the new values of a, b, G, H
fn fold_witness(
u: Scalar,
u_inv: Scalar,
a_L: &[Scalar],
a_R: &[Scalar],
b_L: &[Scalar],
b_R: &[Scalar],
G_L: &[StarkPoint],
G_R: &[StarkPoint],
H_L: &[StarkPoint],
H_R: &[StarkPoint],
) -> (Vec<Scalar>, Vec<Scalar>, Vec<StarkPoint>, Vec<StarkPoint>) {
let n = a_L.len();

// For small proofs, compute serially to avoid parallelism overhead
if n < PARALLELISM_THRESHOLD {
let mut a_res = Vec::with_capacity(n / 2);
let mut b_res = Vec::with_capacity(n / 2);
let mut G_res = Vec::with_capacity(n / 2);
let mut H_res = Vec::with_capacity(n / 2);

for i in 0..n {
a_res.push(a_L[i] * u + u_inv * a_R[i]);
b_res.push(b_L[i] * u_inv + u * b_R[i]);
G_res.push(StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]));
H_res.push(StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]));
}

return (a_res, b_res, G_res, H_res);
}

// Parallel implementation
let mut res = Vec::with_capacity(n);
(0..n)
.into_par_iter()

Check failure on line 236 in src/inner_product_proof.rs

View workflow job for this annotation

GitHub Actions / clippy

no method named `into_par_iter` found for struct `std::ops::Range` in the current scope

error[E0599]: no method named `into_par_iter` found for struct `std::ops::Range` in the current scope --> src/inner_product_proof.rs:236:14 | 235 | / (0..n) 236 | | .into_par_iter() | | -^^^^^^^^^^^^^ method not found in `Range<usize>` | |_____________| |
.map(|i| {
(
a_L[i] * u + u_inv * a_R[i],
b_L[i] * u_inv + u * b_R[i],
StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]),
StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]),
)
})
.collect_into_vec(&mut res);

res.into_iter().unzip_n_vec()

Check failure on line 247 in src/inner_product_proof.rs

View workflow job for this annotation

GitHub Actions / clippy

no method named `unzip_n_vec` found for struct `std::vec::IntoIter` in the current scope

error[E0599]: no method named `unzip_n_vec` found for struct `std::vec::IntoIter` in the current scope --> src/inner_product_proof.rs:247:25 | 247 | res.into_iter().unzip_n_vec() | ^^^^^^^^^^^ method not found in `IntoIter<_>`
}

/// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and \\([s\_{i}]\\) for combined multiscalar multiplication
/// in a parent protocol. See [inner product protocol notes](index.html#verification-equation) for details.
/// The verifier must provide the input length \\(n\\) explicitly to avoid unbounded allocation within the inner product proof.
Expand Down

0 comments on commit 6884f1e

Please sign in to comment.