diff --git a/src/inner_product_proof.rs b/src/inner_product_proof.rs index 6674ef50..3efd07dd 100644 --- a/src/inner_product_proof.rs +++ b/src/inner_product_proof.rs @@ -5,8 +5,11 @@ extern crate alloc; use alloc::borrow::Borrow; use alloc::vec::Vec; +use itertools::Itertools; use mpc_stark::algebra::scalar::{Scalar, SCALAR_BYTES}; use mpc_stark::algebra::stark_curve::{StarkPoint, STARK_POINT_BYTES}; +use rayon::prelude::*; +use unzip_n::unzip_n; use core::iter; use merlin::HashChainTranscript as Transcript; @@ -14,6 +17,14 @@ use merlin::HashChainTranscript as Transcript; use crate::errors::ProofError; use crate::transcript::TranscriptProtocol; +unzip_n!(4); + +/// The size of the inner product proof above which we execute folding operations +/// in parallel +/// +/// Copied from `mpc-stark` +const PARALLELISM_THRESHOLD: usize = 10; + #[derive(Clone, Debug, PartialEq, Eq)] pub struct InnerProductProof { pub L_vec: Vec, @@ -45,21 +56,13 @@ impl InnerProductProof { mut a_vec: Vec, mut b_vec: Vec, ) -> InnerProductProof { - // Create slices G, H, a, b backed by their respective - // vectors. This lets us reslice as we compress the lengths - // of the vectors in the main loop below. - let mut G = &mut G_vec[..]; - let mut H = &mut H_vec[..]; - let mut a = &mut a_vec[..]; - let mut b = &mut b_vec[..]; - - let mut n = G.len(); + let mut n = G_vec.len(); // All of the input vectors must have the same length. - assert_eq!(G.len(), n); - assert_eq!(H.len(), n); - assert_eq!(a.len(), n); - assert_eq!(b.len(), n); + assert_eq!(G_vec.len(), n); + assert_eq!(H_vec.len(), n); + assert_eq!(a_vec.len(), n); + assert_eq!(b_vec.len(), n); assert_eq!(G_factors.len(), n); assert_eq!(H_factors.len(), n); @@ -76,10 +79,10 @@ impl InnerProductProof { // into multiscalar muls, for performance. 
if n != 1 { n /= 2; - let (a_L, a_R) = a.split_at_mut(n); - let (b_L, b_R) = b.split_at_mut(n); - let (G_L, G_R) = G.split_at_mut(n); - let (H_L, H_R) = H.split_at_mut(n); + let (a_L, a_R) = a_vec.split_at_mut(n); + let (b_L, b_R) = b_vec.split_at_mut(n); + let (G_L, G_R) = G_vec.split_at_mut(n); + let (H_L, H_R) = H_vec.split_at_mut(n); let c_L = inner_product(a_L, b_R); let c_R = inner_product(a_R, b_L); @@ -119,31 +122,36 @@ impl InnerProductProof { let u = transcript.challenge_scalar(b"u"); let u_inv = u.inverse(); - for i in 0..n { - a_L[i] = a_L[i] * u + u_inv * a_R[i]; - b_L[i] = b_L[i] * u_inv + u * b_R[i]; - G_L[i] = StarkPoint::msm( - &[u_inv * G_factors[i], u * G_factors[n + i]], - &[G_L[i], G_R[i]], - ); - H_L[i] = StarkPoint::msm( - &[u * H_factors[i], u_inv * H_factors[n + i]], - &[H_L[i], H_R[i]], - ); - } - - a = a_L; - b = b_L; - G = G_L; - H = H_L; + let G = G_factors + .iter() + .zip(G_vec.into_iter()) + .map(|(g, G_i)| g * G_i) + .collect_vec(); + let H = H_factors + .iter() + .zip(H_vec.into_iter()) + .map(|(h, H_i)| h * H_i) + .collect_vec(); + (a_vec, b_vec, G_vec, H_vec) = Self::fold_witness( + u, + u_inv, + a_L, + a_R, + b_L, + b_R, + &G[..n], + &G[n..], + &H[..n], + &H[n..], + ); } while n != 1 { n /= 2; - let (a_L, a_R) = a.split_at_mut(n); - let (b_L, b_R) = b.split_at_mut(n); - let (G_L, G_R) = G.split_at_mut(n); - let (H_L, H_R) = H.split_at_mut(n); + let (a_L, a_R) = a_vec.split_at_mut(n); + let (b_L, b_R) = b_vec.split_at_mut(n); + let (G_L, G_R) = G_vec.split_at_mut(n); + let (H_L, H_R) = H_vec.split_at_mut(n); let c_L = inner_product(a_L, b_R); let c_R = inner_product(a_R, b_L); @@ -172,27 +180,73 @@ impl InnerProductProof { let u = transcript.challenge_scalar(b"u"); let u_inv = u.inverse(); - for i in 0..n { - a_L[i] = a_L[i] * u + u_inv * a_R[i]; - b_L[i] = b_L[i] * u_inv + u * b_R[i]; - G_L[i] = StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]); - H_L[i] = StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]); - } - - a = a_L; - b = 
b_L;
-            G = G_L;
-            H = H_L;
+            (a_vec, b_vec, G_vec, H_vec) =
+                Self::fold_witness(u, u_inv, a_L, a_R, b_L, b_R, G_L, G_R, H_L, H_R);
         }
 
         InnerProductProof {
             L_vec,
             R_vec,
-            a: a[0],
-            b: b[0],
+            a: a_vec[0],
+            b: b_vec[0],
         }
     }
 
+    /// Folds the inner product witness in half: output `i` combines entry `i` of each
+    /// left/right pair, weighted by `u` and `u_inv` (expected to be the inverse of `u`)
+    ///
+    /// See equation (4) of the Bulletproof paper:
+    /// https://eprint.iacr.org/2017/1066.pdf
+    ///
+    /// Returns the new a, b, G, H, each of length `a_L.len()`; panics if a slice is shorter
+    fn fold_witness(
+        u: Scalar,
+        u_inv: Scalar,
+        a_L: &[Scalar],
+        a_R: &[Scalar],
+        b_L: &[Scalar],
+        b_R: &[Scalar],
+        G_L: &[StarkPoint],
+        G_R: &[StarkPoint],
+        H_L: &[StarkPoint],
+        H_R: &[StarkPoint],
+    ) -> (Vec<Scalar>, Vec<Scalar>, Vec<StarkPoint>, Vec<StarkPoint>) {
+        let n = a_L.len();
+
+        // For small proofs, compute serially to avoid parallelism overhead
+        if n < PARALLELISM_THRESHOLD {
+            // Each result vector receives exactly `n` entries, so reserve `n` up front
+            let mut a_res = Vec::with_capacity(n);
+            let mut b_res = Vec::with_capacity(n);
+            let mut G_res = Vec::with_capacity(n);
+            let mut H_res = Vec::with_capacity(n);
+            for i in 0..n {
+                a_res.push(a_L[i] * u + u_inv * a_R[i]);
+                b_res.push(b_L[i] * u_inv + u * b_R[i]);
+                G_res.push(StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]));
+                H_res.push(StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]));
+            }
+
+            return (a_res, b_res, G_res, H_res);
+        }
+
+        // Parallel implementation: fold every index in parallel, then unzip the tuples
+        let mut res = Vec::with_capacity(n);
+        (0..n)
+            .into_par_iter()
+            .map(|i| {
+                (
+                    a_L[i] * u + u_inv * a_R[i],
+                    b_L[i] * u_inv + u * b_R[i],
+                    StarkPoint::msm(&[u_inv, u], &[G_L[i], G_R[i]]),
+                    StarkPoint::msm(&[u, u_inv], &[H_L[i], H_R[i]]),
+                )
+            })
+            .collect_into_vec(&mut res);
+
+        res.into_iter().unzip_n_vec()
+    }
+
     /// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and \\([s\_{i}]\\) for combined multiscalar multiplication
     /// in a parent protocol. See [inner product protocol notes](index.html#verification-equation) for details.
/// The verifier must provide the input length \\(n\\) explicitly to avoid unbounded allocation within the inner product proof.