Skip to content

Commit

Permalink
Add methods for adding/subtracting vectors
Browse files Browse the repository at this point in the history
Add a method `add_assign_vector` that replaces the following zip
pattern:

```
for (x, y) in a.iter_mut().zip(y) {
    *x += *y;
}
```

The method also panics if `b` is shorter than `a`.

Replace the zip pattern with this method wherever we add up vectors.
This helps us ensure that we're not accidentally adding a vector that's
shorter than the other.

Likewise for subtracting vectors.
  • Loading branch information
cjpatton committed Jan 18, 2025
1 parent 9eac4c1 commit fbb7f32
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 99 deletions.
31 changes: 25 additions & 6 deletions src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,9 +867,7 @@ pub(crate) fn merge_vector<F: FieldElement>(
if accumulator.len() != other_vector.len() {
return Err(FieldError::InputSizeMismatch);
}
for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) {
*a += *o;
}
add_assign_vector(accumulator, other_vector.iter().copied());

Ok(())
}
Expand All @@ -886,15 +884,36 @@ pub(crate) fn split_vector<F: FieldElement>(inp: &[F], num_shares: usize) -> Vec

for _ in 1..num_shares {
let share: Vec<F> = random_vector(inp.len());
for (x, y) in outp[0].iter_mut().zip(&share) {
*x -= *y;
}
sub_assign_vector(&mut outp[0], share.iter().copied());
outp.push(share);
}

outp
}

pub(crate) fn sub_assign_vector<F: FieldElement>(a: &mut [F], b: impl IntoIterator<Item = F>) {
let mut count = 0;
for (x, y) in a.iter_mut().zip(b) {
*x -= y;
count += 1;
}
assert_eq!(a.len(), count);
}

pub(crate) fn add_assign_vector<F: FieldElement>(a: &mut [F], b: impl IntoIterator<Item = F>) {
let mut count = 0;
for (x, y) in a.iter_mut().zip(b) {
*x += y;
count += 1;
}
assert_eq!(a.len(), count);
}

pub(crate) fn add_vector<F: FieldElement>(mut a: Vec<F>, b: Vec<F>) -> Vec<F> {
add_assign_vector(&mut a, b.iter().copied());
a
}

/// Generate a vector of uniformly distributed random field elements.
pub fn random_vector<F: FieldElement>(len: usize) -> Vec<F> {
Prng::new().take(len).collect()
Expand Down
24 changes: 7 additions & 17 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,9 @@ pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usi
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_utils {
use super::*;
use crate::field::{random_vector, FieldElement, FieldElementWithInteger};
use crate::field::{
add_vector, random_vector, sub_assign_vector, FieldElement, FieldElementWithInteger,
};

/// Various tests for an FLP.
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
Expand Down Expand Up @@ -945,12 +947,7 @@ pub mod test_utils {
)
.unwrap()
})
.reduce(|mut left, right| {
for (x, y) in left.iter_mut().zip(right.iter()) {
*x += *y;
}
left
})
.reduce(add_vector)
.unwrap();

let res = self.flp.decide(&verifier).unwrap();
Expand Down Expand Up @@ -1000,9 +997,7 @@ pub mod test_utils {

for _ in 1..SHARES {
let share: Vec<F> = random_vector(inp.len());
for (x, y) in outp[0].iter_mut().zip(&share) {
*x -= *y;
}
sub_assign_vector(&mut outp[0], share.iter().copied());
outp.push(share);
}

Expand All @@ -1013,7 +1008,7 @@ pub mod test_utils {
#[cfg(test)]
mod tests {
use super::*;
use crate::field::{random_vector, split_vector, Field128};
use crate::field::{add_vector, random_vector, split_vector, Field128};
use crate::flp::gadgets::{Mul, PolyEval};
use crate::polynomial::poly_range_check;

Expand Down Expand Up @@ -1054,12 +1049,7 @@ mod tests {
)
.unwrap()
})
.reduce(|mut left, right| {
for (x, y) in left.iter_mut().zip(right.iter()) {
*x += *y;
}
left
})
.reduce(add_vector)
.unwrap();
assert_eq!(verifier.len(), typ.verifier_len());

Expand Down
12 changes: 2 additions & 10 deletions src/flp/gadgets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! A collection of gadgets.
use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish};
use crate::field::FftFriendlyFieldElement;
use crate::field::{add_vector, FftFriendlyFieldElement};
use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget};
use crate::polynomial::{poly_deg, poly_eval, poly_mul};

Expand Down Expand Up @@ -380,15 +380,7 @@ where
},
)
.map(|state| state.partial_sum)
.reduce(
|| vec![F::zero(); outp.len()],
|mut x, y| {
for (xi, yi) in x.iter_mut().zip(y.iter()) {
*xi += *yi;
}
x
},
);
.reduce(|| vec![F::zero(); outp.len()], add_vector);

outp.copy_from_slice(&res[..]);
Ok(())
Expand Down
74 changes: 25 additions & 49 deletions src/vdaf/mastic/szk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
use crate::{
codec::{CodecError, Decode, Encode, ParameterizedDecode},
field::{decode_fieldvec, encode_fieldvec, FieldElement},
field::{add_assign_vector, decode_fieldvec, encode_fieldvec, sub_assign_vector, FieldElement},
flp::{FlpError, Type},
vdaf::{
mastic::{self, NONCE_SIZE, SEED_SIZE, USAGE_PROOF_SHARE},
Expand Down Expand Up @@ -453,12 +453,10 @@ impl<T: Type> Szk<T> {
.prove(encoded_measurement, &prove_rand, &joint_rand)?;

// Generate the proof shares.
for (x, y) in leader_proof_share
.iter_mut()
.zip(self.derive_helper_proof_share(&helper_seed, ctx))
{
*x -= y;
}
sub_assign_vector(
&mut leader_proof_share,
self.derive_helper_proof_share(&helper_seed, ctx),
);

// Construct the output messages.
let leader_proof_share = SzkProofShare::Leader {
Expand Down Expand Up @@ -575,13 +573,10 @@ impl<T: Type> Szk<T> {
mut leader_share: SzkQueryShare<T::Field>,
helper_share: SzkQueryShare<T::Field>,
) -> Result<SzkJointShare, SzkError> {
for (x, y) in leader_share
.flp_verifier
.iter_mut()
.zip(helper_share.flp_verifier)
{
*x += y;
}
add_assign_vector(
&mut leader_share.flp_verifier,
helper_share.flp_verifier.iter().copied(),
);
if self.typ.decide(&leader_share.flp_verifier)? {
match (
leader_share.joint_rand_part_opt,
Expand Down Expand Up @@ -662,11 +657,12 @@ where
mod tests {
use super::*;
use crate::{
field::Field128,
field::{random_vector, FieldElementWithInteger},
flp::gadgets::{Mul, ParallelSum},
flp::types::{Count, Sum, SumVec},
flp::Type,
field::{random_vector, sub_assign_vector, Field128, FieldElementWithInteger},
flp::{
gadgets::{Mul, ParallelSum},
types::{Count, Sum, SumVec},
Type,
},
};
use rand::{thread_rng, Rng};

Expand All @@ -683,9 +679,7 @@ mod tests {
let leader_seed_opt = szk_typ.requires_joint_rand().then(|| rng.gen());
let helper_input_share: Vec<T::Field> = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let proof_shares = szk_typ.prove(
ctx,
Expand Down Expand Up @@ -849,9 +843,7 @@ mod tests {
let leader_seed_opt = Some(rng.gen());
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [leader_proof_share, _] = szk_typ
.prove(
Expand Down Expand Up @@ -885,9 +877,7 @@ mod tests {
let leader_seed_opt = Some(rng.gen());
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [l_proof_share, _] = szk_typ
.prove(
Expand Down Expand Up @@ -920,9 +910,7 @@ mod tests {
let leader_seed_opt = Some(rng.gen());
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [l_proof_share, _] = szk_typ
.prove(
Expand Down Expand Up @@ -956,9 +944,7 @@ mod tests {
let leader_seed_opt = None;
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [l_proof_share, _] = szk_typ
.prove(
Expand Down Expand Up @@ -998,9 +984,7 @@ mod tests {
let leader_seed_opt = None;
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [_, h_proof_share] = szk_typ
.prove(
Expand Down Expand Up @@ -1039,9 +1023,7 @@ mod tests {
let leader_seed_opt = None;
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [l_proof_share, _] = szk_typ
.prove(
Expand Down Expand Up @@ -1080,9 +1062,7 @@ mod tests {
let leader_seed_opt = None;
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [_, h_proof_share] = szk_typ
.prove(
Expand Down Expand Up @@ -1122,9 +1102,7 @@ mod tests {
let leader_seed_opt = Some(rng.gen());
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [l_proof_share, _] = szk_typ
.prove(
Expand Down Expand Up @@ -1164,9 +1142,7 @@ mod tests {
let leader_seed_opt = Some(rng.gen());
let helper_input_share = random_vector(szk_typ.typ.input_len());
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
*x -= *y;
}
sub_assign_vector(&mut leader_input_share, helper_input_share.iter().copied());

let [_, h_proof_share] = szk_typ
.prove(
Expand Down
25 changes: 8 additions & 17 deletions src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ use crate::codec::{encode_fixlen_items, CodecError, Decode, Encode, Parameterize
#[cfg(feature = "experimental")]
use crate::dp::DifferentialPrivacyStrategy;
use crate::field::{
decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger,
add_assign_vector, decode_fieldvec, sub_assign_vector, FftFriendlyFieldElement, FieldElement,
FieldElementWithInteger,
};
use crate::field::{Field128, Field64};
#[cfg(feature = "multithreaded")]
Expand Down Expand Up @@ -650,12 +651,7 @@ where

Some(joint_rand_blind)
} else {
for (x, y) in leader_measurement_share
.iter_mut()
.zip(measurement_share_prng)
{
*x -= y;
}
sub_assign_vector(&mut leader_measurement_share, measurement_share_prng);
None
};
shares_out.push(Prio3InputShare::Helper {
Expand Down Expand Up @@ -737,13 +733,10 @@ where
u8::try_from(j).unwrap() + 1,
);

for (x, y) in leader_proofs_share
.iter_mut()
.zip(prng)
.take(self.typ.proof_len() * self.num_proofs())
{
*x -= y;
}
sub_assign_vector(
&mut leader_proofs_share,
prng.take(self.typ.proof_len() * self.num_proofs()),
);
}

// Overwrite the placeholder first element with the leader share
Expand Down Expand Up @@ -1506,9 +1499,7 @@ where
joint_rand_parts.push(joint_rand_seed_part);
}

for (x, y) in verifiers.iter_mut().zip(share.verifiers) {
*x += y;
}
add_assign_vector(&mut verifiers, share.verifiers.iter().copied());
}

if count != self.num_aggregators {
Expand Down

0 comments on commit fbb7f32

Please sign in to comment.