diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index bf96f34f..5a3a71b3 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -81,16 +81,16 @@ pub fn dp_noise(c: &mut Criterion) { group.finish(); } -/// The asymptotic cost of polynomial multiplication is `O(n log n)` using FFT and `O(n^2)` using +/// The asymptotic cost of polynomial multiplication is `O(n log n)` using NTT and `O(n^2)` using /// the naive method. This benchmark demonstrates that the latter has better concrete performance -/// for small polynomials. The result is used to pick the `FFT_THRESHOLD` constant in +/// for small polynomials. The result is used to pick the `NTT_THRESHOLD` constant in /// `src/flp/gadgets.rs`. fn poly_mul(c: &mut Criterion) { let test_sizes = [1_usize, 30, 60, 90, 120, 150, 255]; let mut group = c.benchmark_group("poly_mul"); for size in test_sizes { - group.bench_with_input(BenchmarkId::new("fft", size), &size, |b, size| { + group.bench_with_input(BenchmarkId::new("ntt", size), &size, |b, size| { let m = (size + 1).next_power_of_two(); let mut g: Mul = Mul::new(*size); let mut outp = vec![F::zero(); 2 * m]; @@ -99,7 +99,7 @@ fn poly_mul(c: &mut Criterion) { inp.push(random_vector(m)); b.iter(|| { - benchmarked_gadget_mul_call_poly_fft(&mut g, &mut outp, &inp).unwrap(); + benchmarked_gadget_mul_call_poly_ntt(&mut g, &mut outp, &inp).unwrap(); }) }); diff --git a/src/benchmarked.rs b/src/benchmarked.rs index d8023236..fdd7ecc2 100644 --- a/src/benchmarked.rs +++ b/src/benchmarked.rs @@ -5,37 +5,37 @@ //! This module provides wrappers around internal components of this crate that we want to //! benchmark, but which we don't want to expose in the public API. -use crate::fft::discrete_fourier_transform; -use crate::field::FftFriendlyFieldElement; +use crate::field::NttFriendlyFieldElement; use crate::flp::gadgets::Mul; use crate::flp::FlpError; -use crate::polynomial::{fft_get_roots, poly_fft, PolyFFTTempMemory}; +use crate::ntt::ntt; +use crate::polynomial::{ntt_get_roots, poly_ntt, PolyNttTempMemory}; -/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. -pub fn benchmarked_iterative_fft(outp: &mut [F], inp: &[F]) { - discrete_fourier_transform(outp, inp, inp.len()).unwrap(); +/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative NTT algorithm. +pub fn benchmarked_iterative_ntt(outp: &mut [F], inp: &[F]) { + ntt(outp, inp, inp.len()).unwrap(); } -/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. -pub fn benchmarked_recursive_fft(outp: &mut [F], inp: &[F]) { - let roots_2n = fft_get_roots(inp.len(), false); - let mut fft_memory = PolyFFTTempMemory::new(inp.len()); - poly_fft(outp, inp, &roots_2n, inp.len(), false, &mut fft_memory) +/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive NTT algorithm. +pub fn benchmarked_recursive_ntt(outp: &mut [F], inp: &[F]) { + let roots_2n = ntt_get_roots(inp.len(), false); + let mut ntt_memory = PolyNttTempMemory::new(inp.len()); + poly_ntt(outp, inp, &roots_2n, inp.len(), false, &mut ntt_memory) } /// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function -/// uses FFT for multiplication. -pub fn benchmarked_gadget_mul_call_poly_fft( +/// uses NTT for multiplication. +pub fn benchmarked_gadget_mul_call_poly_ntt( g: &mut Mul, outp: &mut [F], inp: &[Vec], ) -> Result<(), FlpError> { - g.call_poly_fft(outp, inp) + g.call_poly_ntt(outp, inp) } /// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function /// does the multiplication directly. -pub fn benchmarked_gadget_mul_call_poly_direct( +pub fn benchmarked_gadget_mul_call_poly_direct( g: &mut Mul, outp: &mut [F], inp: &[Vec], diff --git a/src/field.rs b/src/field.rs index 9573f6fd..0aa76d5d 100644 --- a/src/field.rs +++ b/src/field.rs @@ -4,7 +4,7 @@ //! Finite field arithmetic. //! //! Basic field arithmetic is captured in the [`FieldElement`] trait. Fields used in Prio implement -//! [`FftFriendlyFieldElement`], and have an associated element called the "generator" that +//! [`NttFriendlyFieldElement`], and have an associated element called the "generator" that //! generates a multiplicative subgroup of order `2^n` for some `n`. use crate::{ @@ -406,13 +406,13 @@ impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor { /// Objects with this trait represent an element of `GF(p)`, where `p` is some prime and the /// field's multiplicative group has a subgroup with an order that is a power of 2, and at least /// `2^20`. -pub trait FftFriendlyFieldElement: FieldElementWithInteger { +pub trait NttFriendlyFieldElement: FieldElementWithInteger { /// Returns the size of the multiplicative subgroup generated by - /// [`FftFriendlyFieldElement::generator`]. + /// [`NttFriendlyFieldElement::generator`]. fn generator_order() -> Self::Integer; /// Returns the generator of the multiplicative subgroup of size - /// [`FftFriendlyFieldElement::generator_order`]. + /// [`NttFriendlyFieldElement::generator_order`]. fn generator() -> Self; /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th @@ -751,7 +751,7 @@ macro_rules! make_field { } } - impl FftFriendlyFieldElement for $elem { + impl NttFriendlyFieldElement for $elem { fn generator() -> Self { Self($fp::G) } @@ -1168,7 +1168,7 @@ mod tests { assert_matches!(result, Err(FieldError::InputSizeMismatch)); } - fn field_element_test() { + fn field_element_test() { field_element_test_common::(); let mut prng: Prng = Prng::new(); diff --git a/src/flp.rs b/src/flp.rs index dabd41ab..57e0e208 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -48,9 +48,9 @@ #[cfg(feature = "experimental")] use crate::dp::DifferentialPrivacyStrategy; -use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError}; -use crate::field::{FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldError}; +use crate::field::{FieldElement, FieldElementWithInteger, FieldError, NttFriendlyFieldElement}; use crate::fp::log2; +use crate::ntt::{ntt, ntt_inv_finish, NttError}; use crate::polynomial::poly_eval; use std::any::Any; use std::convert::TryFrom; @@ -99,9 +99,9 @@ pub enum FlpError { #[error("invalid paramter: {0}")] InvalidParameter(String), - /// Returned if an FFT operation propagates an error. - #[error("FFT error: {0}")] - Fft(#[from] FftError), + /// Returned if an NTT operation propagates an error. + #[error("NTT error: {0}")] + Ntt(#[from] NttError), /// Returned if a field operation encountered an error. #[error("Field error: {0}")] @@ -124,7 +124,7 @@ pub trait Type: Sized + Eq + Clone + Debug { type AggregateResult: Clone + Debug; /// The finite field used for this type. - type Field: FftFriendlyFieldElement; + type Field: NttFriendlyFieldElement; /// Encodes a measurement as a vector of [`Self::input_len`] field elements. fn encode_measurement( @@ -299,7 +299,7 @@ pub trait Type: Sized + Eq + Clone + Debug { .map(|shim| { let gadget_poly_len = gadget_poly_len(shim.degree(), wire_poly_len(shim.calls())); - // Computing the gadget polynomial using FFT requires an amount of memory that is a + // Computing the gadget polynomial using NTT requires an amount of memory that is a // power of 2. Thus we choose the smallest power of 2 that is at least as large as // the gadget polynomial. The wire seeds are encoded in the proof, too, so we // include the arity of the gadget to ensure there is always enough room at the end @@ -336,8 +336,8 @@ pub trait Type: Sized + Eq + Clone + Debug { .zip(gadget.f_vals[..gadget.arity()].iter()) .zip(proof[proof_len..proof_len + gadget.arity()].iter_mut()) { - discrete_fourier_transform(coefficients, values, m)?; - discrete_fourier_transform_inv_finish(coefficients, m, m_inv); + ntt(coefficients, values, m)?; + ntt_inv_finish(coefficients, m, m_inv); // The first point on each wire polynomial is a random value chosen by the prover. This // point is stored in the proof so that the verifier can reconstruct the wire @@ -503,8 +503,8 @@ pub trait Type: Sized + Eq + Clone + Debug { .inv(); let mut f = vec![Self::Field::zero(); m]; for wire in 0..gadget.arity() { - discrete_fourier_transform(&mut f, &gadget.f_vals[wire], m)?; - discrete_fourier_transform_inv_finish(&mut f, m, m_inv); + ntt(&mut f, &gadget.f_vals[wire], m)?; + ntt_inv_finish(&mut f, m, m_inv); verifier.push(poly_eval(&f, *query_rand_val)); } @@ -607,7 +607,7 @@ where } /// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit. -pub trait Gadget: Debug { +pub trait Gadget: Debug { /// Evaluates the gadget on input `inp` and returns the output. fn call(&mut self, inp: &[F]) -> Result; @@ -632,7 +632,7 @@ pub trait Gadget: Debug { /// A "shim" gadget used during proof generation to record the input wires each time a gadget is /// evaluated. #[derive(Debug)] -struct ProveShimGadget { +struct ProveShimGadget { inner: Box>, /// Points at which the wire polynomials are interpolated. @@ -642,7 +642,7 @@ struct ProveShimGadget { ct: usize, } -impl ProveShimGadget { +impl ProveShimGadget { fn new(inner: Box>, prove_rand: &[F]) -> Result { let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()]; @@ -661,7 +661,7 @@ impl ProveShimGadget { } } -impl Gadget for ProveShimGadget { +impl Gadget for ProveShimGadget { fn call(&mut self, inp: &[F]) -> Result { for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) { wire_poly_vals[self.ct] = *inp_val; @@ -694,7 +694,7 @@ impl Gadget for ProveShimGadget { /// A "shim" gadget used during proof verification to record the points at which the intermediate /// proof polynomials are evaluated. #[derive(Debug)] -struct QueryShimGadget { +struct QueryShimGadget { inner: Box>, /// Points at which intermediate proof polynomials are interpolated. @@ -713,7 +713,7 @@ struct QueryShimGadget { ct: usize, } -impl QueryShimGadget { +impl QueryShimGadget { fn new(inner: Box>, r: F, proof_data: &[F]) -> Result { let gadget_degree = inner.degree(); let gadget_arity = inner.arity(); @@ -731,7 +731,7 @@ impl QueryShimGadget { // Evaluate the gadget polynomial at roots of unity. let size = p.next_power_of_two(); let mut p_vals = vec![F::zero(); size]; - discrete_fourier_transform(&mut p_vals, &proof_data[gadget_arity..], size)?; + ntt(&mut p_vals, &proof_data[gadget_arity..], size)?; // The step is used to compute the element of `p_val` that will be returned by a call to // the gadget. @@ -751,7 +751,7 @@ impl QueryShimGadget { } } -impl Gadget for QueryShimGadget { +impl Gadget for QueryShimGadget { fn call(&mut self, inp: &[F]) -> Result { for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) { wire_poly_vals[self.ct] = *inp_val; @@ -1077,7 +1077,7 @@ mod tests { } } - impl Type for TestType { + impl Type for TestType { type Measurement = F::Integer; type AggregateResult = F::Integer; type Field = F; @@ -1215,7 +1215,7 @@ mod tests { } } - impl Type for Issue254Type { + impl Type for Issue254Type { type Measurement = F::Integer; type AggregateResult = F::Integer; type Field = F; diff --git a/src/flp/gadgets.rs b/src/flp/gadgets.rs index 3783c7e6..e9d8a574 100644 --- a/src/flp/gadgets.rs +++ b/src/flp/gadgets.rs @@ -2,9 +2,9 @@ //! A collection of gadgets. -use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish}; -use crate::field::FftFriendlyFieldElement; +use crate::field::NttFriendlyFieldElement; use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget}; +use crate::ntt::{ntt, ntt_inv_finish}; use crate::polynomial::{poly_deg, poly_eval, poly_mul}; #[cfg(feature = "multithreaded")] @@ -15,14 +15,14 @@ use std::convert::TryFrom; use std::fmt::Debug; use std::marker::PhantomData; -/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for +/// For input polynomials larger than or equal to this threshold, gadgets will use NTT for /// polynomial multiplication. Otherwise, the gadget uses direct multiplication. -const FFT_THRESHOLD: usize = 30; +const NTT_THRESHOLD: usize = 30; /// An arity-2 gadget that multiples its inputs. #[derive(Clone, Debug, Eq, PartialEq)] -pub struct Mul { - /// Size of buffer for FFT operations. +pub struct Mul { + /// Size of buffer for NTT operations. n: usize, /// Inverse of `n` in `F`. n_inv: F, @@ -30,11 +30,11 @@ pub struct Mul { num_calls: usize, } -impl Mul { +impl Mul { /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be /// called by the validity circuit. pub fn new(num_calls: usize) -> Self { - let n = gadget_poly_fft_mem_len(2, num_calls); + let n = gadget_poly_ntt_mem_len(2, num_calls); let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); Self { n, @@ -54,25 +54,25 @@ impl Mul { Ok(()) } - /// Multiply input polynomials using FFT. - pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { + /// Multiply input polynomials using NTT. + pub(crate) fn call_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { let n = self.n; let mut buf = vec![F::zero(); n]; - discrete_fourier_transform(&mut buf, &inp[0], n)?; - discrete_fourier_transform(outp, &inp[1], n)?; + ntt(&mut buf, &inp[0], n)?; + ntt(outp, &inp[1], n)?; for i in 0..n { buf[i] *= outp[i]; } - discrete_fourier_transform(outp, &buf, n)?; - discrete_fourier_transform_inv_finish(outp, n, self.n_inv); + ntt(outp, &buf, n)?; + ntt_inv_finish(outp, n, self.n_inv); Ok(()) } } -impl Gadget for Mul { +impl Gadget for Mul { fn call(&mut self, inp: &[F]) -> Result { gadget_call_check(self, inp.len())?; Ok(inp[0] * inp[1]) @@ -80,8 +80,8 @@ impl Gadget for Mul { fn call_poly(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { gadget_call_poly_check(self, outp, inp)?; - if inp[0].len() >= FFT_THRESHOLD { - self.call_poly_fft(outp, inp) + if inp[0].len() >= NTT_THRESHOLD { + self.call_poly_ntt(outp, inp) } else { self.call_poly_direct(outp, inp) } @@ -108,9 +108,9 @@ impl Gadget for Mul { // // TODO Make `poly` an array of length determined by a const generic. #[derive(Clone, Debug, Eq, PartialEq)] -pub struct PolyEval { +pub struct PolyEval { poly: Vec, - /// Size of buffer for FFT operations. + /// Size of buffer for NTT operations. n: usize, /// Inverse of `n` in `F`. n_inv: F, @@ -118,11 +118,11 @@ pub struct PolyEval { num_calls: usize, } -impl PolyEval { +impl PolyEval { /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times /// this gadget is called by the validity circuit. pub fn new(poly: Vec, num_calls: usize) -> Self { - let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls); + let n = gadget_poly_ntt_mem_len(poly_deg(&poly), num_calls); let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); Self { poly, @@ -133,7 +133,7 @@ impl PolyEval { } } -impl PolyEval { +impl PolyEval { /// Multiply input polynomials directly. fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { outp[0] = self.poly[0]; @@ -150,13 +150,13 @@ impl PolyEval { Ok(()) } - /// Multiply input polynomials using FFT. - fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { + /// Multiply input polynomials using NTT. + fn call_poly_ntt(&mut self, outp: &mut [F], inp: &[Vec]) -> Result<(), FlpError> { let n = self.n; let inp = &inp[0]; let mut inp_vals = vec![F::zero(); n]; - discrete_fourier_transform(&mut inp_vals, inp, n)?; + ntt(&mut inp_vals, inp, n)?; let mut x_vals = inp_vals.clone(); let mut x = vec![F::zero(); n]; @@ -173,15 +173,15 @@ impl PolyEval { x_vals[j] *= inp_vals[j]; } - discrete_fourier_transform(&mut x, &x_vals, n)?; - discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv); + ntt(&mut x, &x_vals, n)?; + ntt_inv_finish(&mut x, n, self.n_inv); } } Ok(()) } } -impl Gadget for PolyEval { +impl Gadget for PolyEval { fn call(&mut self, inp: &[F]) -> Result { gadget_call_check(self, inp.len())?; Ok(poly_eval(&self.poly, inp[0])) @@ -194,8 +194,8 @@ impl Gadget for PolyEval { *item = F::zero(); } - if inp[0].len() >= FFT_THRESHOLD { - self.call_poly_fft(outp, inp) + if inp[0].len() >= NTT_THRESHOLD { + self.call_poly_ntt(outp, inp) } else { self.call_poly_direct(outp, inp) } @@ -219,7 +219,7 @@ impl Gadget for PolyEval { } /// Trait for abstracting over [`ParallelSum`]. -pub trait ParallelSumGadget: Gadget + Debug { +pub trait ParallelSumGadget: Gadget + Debug { /// Wraps `inner` into a sum gadget that calls it `chunks` many times, and adds the reuslts. fn new(inner: G, chunks: usize) -> Self; } @@ -228,13 +228,13 @@ pub trait ParallelSumGadget: Gadget + Debug { /// outputs. The arity is equal to the arity of the inner gadget times the number of times it is /// called. #[derive(Clone, Debug, Eq, PartialEq)] -pub struct ParallelSum> { +pub struct ParallelSum> { inner: G, chunks: usize, phantom: PhantomData, } -impl> ParallelSumGadget +impl> ParallelSumGadget for ParallelSum { fn new(inner: G, chunks: usize) -> Self { @@ -246,7 +246,7 @@ impl> ParallelSumGadget } } -impl> Gadget for ParallelSum { +impl> Gadget for ParallelSum { fn call(&mut self, inp: &[F]) -> Result { gadget_call_check(self, inp.len())?; let mut outp = F::zero(); @@ -298,14 +298,14 @@ impl> Gadget for ParallelS #[cfg(feature = "multithreaded")] #[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] #[derive(Clone, Debug, Eq, PartialEq)] -pub struct ParallelSumMultithreaded> { +pub struct ParallelSumMultithreaded> { serial_sum: ParallelSum, } #[cfg(feature = "multithreaded")] impl ParallelSumGadget for ParallelSumMultithreaded where - F: FftFriendlyFieldElement + Sync + Send, + F: NttFriendlyFieldElement + Sync + Send, G: 'static + Gadget + Clone + Sync + Send, { fn new(inner: G, chunks: usize) -> Self { @@ -331,7 +331,7 @@ impl ParallelSumFoldState { fn new(gadget: &G, length: usize) -> ParallelSumFoldState where G: Clone, - F: FftFriendlyFieldElement, + F: NttFriendlyFieldElement, { ParallelSumFoldState { inner: gadget.clone(), @@ -344,7 +344,7 @@ impl ParallelSumFoldState { #[cfg(feature = "multithreaded")] impl Gadget for ParallelSumMultithreaded where - F: FftFriendlyFieldElement + Sync + Send, + F: NttFriendlyFieldElement + Sync + Send, G: 'static + Gadget + Clone + Sync + Send, { fn call(&mut self, inp: &[F]) -> Result { @@ -412,7 +412,7 @@ where } /// Check that the input parameters of g.call() are well-formed. -fn gadget_call_check>( +fn gadget_call_check>( gadget: &G, in_len: usize, ) -> Result<(), FlpError> { @@ -432,7 +432,7 @@ fn gadget_call_check>( } /// Check that the input parameters of g.call_poly() are well-formed. -fn gadget_call_poly_check>( +fn gadget_call_poly_check>( gadget: &G, outp: &[F], inp: &[Vec], @@ -460,7 +460,7 @@ fn gadget_call_poly_check>( } #[inline] -fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize { +fn gadget_poly_ntt_mem_len(degree: usize, num_calls: usize) -> usize { gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two() } @@ -475,15 +475,15 @@ mod tests { #[test] fn test_mul() { - // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the + // Test the gadget with input polynomials shorter than `NTT_THRESHOLD`. This exercises the // naive multiplication code path. - let num_calls = FFT_THRESHOLD / 2; + let num_calls = NTT_THRESHOLD / 2; let mut g: Mul = Mul::new(num_calls); gadget_test(&mut g, num_calls); - // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises - // FFT-based polynomial multiplication. - let num_calls = FFT_THRESHOLD; + // Test the gadget with input polynomials longer than `NTT_THRESHOLD`. This exercises + // NTT-based polynomial multiplication. + let num_calls = NTT_THRESHOLD; let mut g: Mul = Mul::new(num_calls); gadget_test(&mut g, num_calls); } @@ -492,11 +492,11 @@ mod tests { fn test_poly_eval() { let poly: Vec = random_vector(10); - let num_calls = FFT_THRESHOLD / 2; + let num_calls = NTT_THRESHOLD / 2; let mut g: PolyEval = PolyEval::new(poly.clone(), num_calls); gadget_test(&mut g, num_calls); - let num_calls = FFT_THRESHOLD; + let num_calls = NTT_THRESHOLD; let mut g: PolyEval = PolyEval::new(poly, num_calls); gadget_test(&mut g, num_calls); } @@ -560,11 +560,11 @@ mod tests { /// Test that calling g.call_poly() and evaluating the output at a given point is equivalent /// to evaluating each of the inputs at the same point and applying g.call() on the results. - fn gadget_test>(g: &mut G, num_calls: usize) { + fn gadget_test>(g: &mut G, num_calls: usize) { let wire_poly_len = (1 + num_calls).next_power_of_two(); let mut prng = Prng::new(); let mut inp = vec![F::zero(); g.arity()]; - let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)]; + let mut gadget_poly = vec![F::zero(); gadget_poly_ntt_mem_len(g.degree(), num_calls)]; let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()]; let r = prng.get(); diff --git a/src/flp/types.rs b/src/flp/types.rs index 8f63af10..fa5300b3 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -2,7 +2,7 @@ //! A collection of [`Type`] implementations. -use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt, Integer}; +use crate::field::{FieldElementWithIntegerExt, Integer, NttFriendlyFieldElement}; use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; use crate::flp::{FlpError, Gadget, Type}; use crate::polynomial::poly_range_check; @@ -27,7 +27,7 @@ impl Debug for Count { } } -impl Count { +impl Count { /// Return a new [`Count`] type instance. pub fn new() -> Self { Self { @@ -36,13 +36,13 @@ impl Count { } } -impl Default for Count { +impl Default for Count { fn default() -> Self { Self::new() } } -impl Type for Count { +impl Type for Count { type Measurement = bool; type AggregateResult = F::Integer; type Field = F; @@ -120,7 +120,7 @@ impl Type for Count { /// /// [BBCG+19]: https://ia.cr/2019/188 #[derive(Clone, PartialEq, Eq)] -pub struct Sum { +pub struct Sum { max_measurement: F::Integer, // Computed from max_measurement @@ -130,7 +130,7 @@ pub struct Sum { bit_range_checker: Vec, } -impl Debug for Sum { +impl Debug for Sum { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Sum") .field("max_measurement", &self.max_measurement) @@ -139,7 +139,7 @@ impl Debug for Sum { } } -impl Sum { +impl Sum { /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0, /// max_measurement]` where `max_measurement > 0`. Errors if `max_measurement == 0`. pub fn new(max_measurement: F::Integer) -> Result { @@ -168,7 +168,7 @@ impl Sum { } } -impl Type for Sum { +impl Type for Sum { type Measurement = F::Integer; type AggregateResult = F::Integer; type Field = F; @@ -267,11 +267,11 @@ impl Type for Sum { // This is just a `Sum` object under the hood. The only difference is that the aggregate result is // an f64, which we get by dividing by `num_measurements` #[derive(Clone, PartialEq, Eq)] -pub struct Average { +pub struct Average { summer: Sum, } -impl Debug for Average { +impl Debug for Average { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Average") .field("max_measurement", &self.summer.max_measurement) @@ -280,7 +280,7 @@ impl Debug for Average { } } -impl Average { +impl Average { /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0, /// max_measurement]` where `max_measurement > 0`. Errors if `max_measurement == 0`. pub fn new(max_measurement: F::Integer) -> Result { @@ -289,7 +289,7 @@ impl Average { } } -impl Type for Average { +impl Type for Average { type Measurement = F::Integer; type AggregateResult = f64; type Field = F; @@ -369,7 +369,7 @@ pub struct Histogram { phantom: PhantomData<(F, S)>, } -impl Debug for Histogram { +impl Debug for Histogram { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Histogram") .field("length", &self.length) @@ -378,7 +378,7 @@ impl Debug for Histogram { } } -impl>> Histogram { +impl>> Histogram { /// Return a new [`Histogram`] type with the given number of buckets. pub fn new(length: usize, chunk_length: usize) -> Result { if length >= u32::MAX as usize { @@ -424,7 +424,7 @@ impl Clone for Histogram { impl Type for Histogram where - F: FftFriendlyFieldElement, + F: NttFriendlyFieldElement, S: ParallelSumGadget> + Eq + 'static, { type Measurement = usize; @@ -533,7 +533,7 @@ pub struct MultihotCountVec { phantom: PhantomData<(F, S)>, } -impl Debug for MultihotCountVec { +impl Debug for MultihotCountVec { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MultihotCountVec") .field("length", &self.length) @@ -543,7 +543,7 @@ impl Debug for MultihotCountVec { } } -impl>> MultihotCountVec { +impl>> MultihotCountVec { /// Return a new [`MultihotCountVec`] type with the given number of buckets. pub fn new( num_buckets: usize, @@ -610,7 +610,7 @@ impl Clone for MultihotCountVec { impl Type for MultihotCountVec where - F: FftFriendlyFieldElement, + F: NttFriendlyFieldElement, S: ParallelSumGadget> + Eq + 'static, { type Measurement = Vec; @@ -743,7 +743,7 @@ where /// /// [BBCG+19]: https://eprint.iacr.org/2019/188 #[derive(PartialEq, Eq)] -pub struct SumVec { +pub struct SumVec { len: usize, bits: usize, flattened_len: usize, @@ -753,7 +753,7 @@ pub struct SumVec { phantom: PhantomData, } -impl Debug for SumVec { +impl Debug for SumVec { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SumVec") .field("len", &self.len) @@ -763,7 +763,7 @@ impl Debug for SumVec { } } -impl>> SumVec { +impl>> SumVec { /// Returns a new [`SumVec`] with the desired bit width and vector length. /// /// # Errors @@ -823,7 +823,7 @@ impl>> SumVec { } } -impl Clone for SumVec { +impl Clone for SumVec { fn clone(&self) -> Self { Self { len: self.len, @@ -839,7 +839,7 @@ impl Clone for SumVec { impl Type for SumVec where - F: FftFriendlyFieldElement, + F: NttFriendlyFieldElement, S: ParallelSumGadget> + Eq + 'static, { type Measurement = Vec; @@ -941,7 +941,7 @@ where /// Given a vector `data` of field elements which should contain exactly one entry, return the /// integer representation of that entry. -pub(crate) fn decode_result( +pub(crate) fn decode_result( data: &[F], ) -> Result { if data.len() != 1 { @@ -952,7 +952,7 @@ pub(crate) fn decode_result( /// Given a vector `data` of field elements, return a vector containing the corresponding integer /// representations, if the number of entries matches `expected_len`. -pub(crate) fn decode_result_vec( +pub(crate) fn decode_result_vec( data: &[F], expected_len: usize, ) -> Result, FlpError> { @@ -981,7 +981,7 @@ pub(crate) fn decode_result_vec( /// /// This returns (additive shares of) zero if all inputs were zero or one, and otherwise returns a /// non-zero value with high probability. -pub(crate) fn parallel_sum_range_checks( +pub(crate) fn parallel_sum_range_checks( gadget: &mut Box>, input: &[F], joint_randomness: &[F], diff --git a/src/flp/types/dp.rs b/src/flp/types/dp.rs index 0ad1e87d..95e842c1 100644 --- a/src/flp/types/dp.rs +++ b/src/flp/types/dp.rs @@ -2,7 +2,7 @@ use crate::dp::{distributions::PureDpDiscreteLaplace, DifferentialPrivacyStrategy}; use crate::dp::{DifferentialPrivacyDistribution, DpError}; -use crate::field::{FftFriendlyFieldElement, Field128, Field64}; +use crate::field::{Field128, Field64, NttFriendlyFieldElement}; use crate::flp::gadgets::{Mul, ParallelSumGadget}; use crate::flp::types::{Histogram, SumVec}; use crate::flp::{FlpError, TypeWithNoise}; @@ -53,7 +53,7 @@ where impl SumVec where - F: FftFriendlyFieldElement, + F: NttFriendlyFieldElement, BigInt: From, F::Integer: TryFrom>, { @@ -127,7 +127,7 @@ where impl Histogram where - F: FftFriendlyFieldElement, + F: NttFriendlyFieldElement, BigInt: From, F::Integer: TryFrom>, { @@ -162,7 +162,7 @@ pub(super) fn add_iid_noise_to_field_vec( distribution: &D, ) -> Result<(), FlpError> where - F: FftFriendlyFieldElement, + F: NttFriendlyFieldElement, BigInt: From, F::Integer: TryFrom>, R: Rng, diff --git a/src/fp.rs b/src/fp.rs index 956ba63c..3033419b 100644 --- a/src/fp.rs +++ b/src/fp.rs @@ -8,7 +8,7 @@ mod ops; pub use ops::{FieldOps, FieldParameters}; /// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots -/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This +/// of unity. The largest of these is used to run the NTT algorithm on an input of size 2^20. This /// is the largest input size we would ever need for the cryptographic applications in this crate. pub(crate) const MAX_ROOTS: usize = 20; diff --git a/src/lib.rs b/src/lib.rs index ebf204e0..aa1d86c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ pub mod codec; #[cfg(feature = "experimental")] #[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] pub mod dp; -mod fft; +mod ntt; pub mod field; pub mod flp; mod fp; diff --git a/src/fft.rs b/src/ntt.rs similarity index 72% rename from src/fft.rs rename to src/ntt.rs index 5c1c9e0b..f750a6e1 100644 --- a/src/fft.rs +++ b/src/ntt.rs @@ -1,17 +1,16 @@ // SPDX-License-Identifier: MPL-2.0 -//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier -//! Transform (DFT) over a slice of field elements. +//! This module implements an iterative NTT (Number Theoretic Transform) algorithm. -use crate::field::FftFriendlyFieldElement; +use crate::field::NttFriendlyFieldElement; use crate::fp::{log2, MAX_ROOTS}; use std::convert::TryFrom; -/// An error returned by an FFT operation. +/// An error returned by an NTT operation. #[derive(Debug, PartialEq, Eq, thiserror::Error)] #[non_exhaustive] -pub enum FftError { +pub enum NttError { /// The output is too small. #[error("output slice is smaller than specified size")] OutputTooSmall, @@ -23,29 +22,29 @@ pub enum FftError { SizeInvalid, } -/// Sets `outp` to the DFT of `inp`. +/// Sets `outp` to the NTT of `inp`. /// /// Interpreting the input as the coefficients of a polynomial, the output is equal to the input /// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of /// unity. #[allow(clippy::many_single_char_names)] -pub fn discrete_fourier_transform( +pub fn ntt( outp: &mut [F], inp: &[F], size: usize, -) -> Result<(), FftError> { - let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?; +) -> Result<(), NttError> { + let d = usize::try_from(log2(size as u128)).map_err(|_| NttError::SizeTooLarge)?; if size > outp.len() { - return Err(FftError::OutputTooSmall); + return Err(NttError::OutputTooSmall); } if size > 1 << MAX_ROOTS { - return Err(FftError::SizeTooLarge); + return Err(NttError::SizeTooLarge); } if size != 1 << d { - return Err(FftError::SizeInvalid); + return Err(NttError::SizeInvalid); } if d > 0 { @@ -90,24 +89,20 @@ pub fn discrete_fourier_transform( /// Sets `outp` to the inverse of the DFT of `inp`. #[cfg(test)] -pub(crate) fn discrete_fourier_transform_inv( +pub(crate) fn ntt_inv( outp: &mut [F], inp: &[F], size: usize, -) -> Result<(), FftError> { +) -> Result<(), NttError> { let size_inv = F::from(F::Integer::try_from(size).unwrap()).inv(); - discrete_fourier_transform(outp, inp, size)?; - discrete_fourier_transform_inv_finish(outp, size, size_inv); + ntt(outp, inp, size)?; + ntt_inv_finish(outp, size, size_inv); Ok(()) } /// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to /// amortize the cost the modular inverse across multiple inverse DFT operations. -pub(crate) fn discrete_fourier_transform_inv_finish( - outp: &mut [F], - size: usize, - size_inv: F, -) { +pub(crate) fn ntt_inv_finish(outp: &mut [F], size: usize, size_inv: F) { let mut tmp: F; outp[0] *= size_inv; outp[size >> 1] *= size_inv; @@ -127,10 +122,9 @@ fn bitrev(d: usize, x: usize) -> usize { mod tests { use super::*; use crate::field::{random_vector, split_vector, Field128, Field64, FieldElement, FieldPrio2}; - use crate::polynomial::{poly_fft, TestPolyAuxMemory}; + use crate::polynomial::{poly_ntt, TestPolyAuxMemory}; - fn discrete_fourier_transform_then_inv_test() -> Result<(), FftError> - { + fn ntt_then_inv_test() -> Result<(), NttError> { let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; for size in test_sizes.iter() { @@ -138,8 +132,8 @@ mod tests { let mut got = vec![F::zero(); *size]; let want = random_vector(*size); - discrete_fourier_transform(&mut tmp, &want, want.len())?; - discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?; + ntt(&mut tmp, &want, want.len())?; + ntt_inv(&mut got, &tmp, tmp.len())?; assert_eq!(got, want); } @@ -148,21 +142,21 @@ mod tests { #[test] fn test_priov2_field32() { - discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + ntt_then_inv_test::().expect("unexpected error"); } #[test] fn test_field64() { - discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + ntt_then_inv_test::().expect("unexpected error"); } #[test] fn test_field128() { - discrete_fourier_transform_then_inv_test::().expect("unexpected error"); + ntt_then_inv_test::().expect("unexpected error"); } #[test] - fn test_recursive_fft() { + fn test_recursive_ntt() { let size = 128; let mut mem = TestPolyAuxMemory::new(size / 2); @@ -170,15 +164,15 @@ mod tests { let mut want = vec![FieldPrio2::zero(); size]; let mut got = vec![FieldPrio2::zero(); size]; - discrete_fourier_transform::(&mut want, &inp, inp.len()).unwrap(); + ntt::(&mut want, &inp, inp.len()).unwrap(); - poly_fft( + poly_ntt( &mut got, &inp, &mem.roots_2n, size, false, - &mut mem.fft_memory, + &mut mem.ntt_memory, ); assert_eq!(got, want); @@ -188,7 +182,7 @@ mod tests { // over secret shares and summing up the coefficients is equivalent to interpolating a // polynomial over the plaintext data. #[test] - fn test_fft_linearity() { + fn test_ntt_linearity() { let len = 16; let num_shares = 3; let x: Vec = random_vector(len); @@ -209,14 +203,14 @@ mod tests { let mut got = vec![Field64::zero(); len]; let mut buf = vec![Field64::zero(); len]; for share in x_shares { - discrete_fourier_transform_inv(&mut buf, &share, len).unwrap(); + ntt_inv(&mut buf, &share, len).unwrap(); for i in 0..len { got[i] += buf[i]; } } let mut want = vec![Field64::zero(); len]; - discrete_fourier_transform_inv(&mut want, &x, len).unwrap(); + ntt_inv(&mut want, &x, len).unwrap(); assert_eq!(got, want); } diff --git a/src/polynomial.rs b/src/polynomial.rs index bfe29bbe..61af6ff6 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -3,26 +3,26 @@ //! Functions for polynomial interpolation and evaluation +use crate::field::NttFriendlyFieldElement; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] -use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish}; -use crate::field::FftFriendlyFieldElement; +use crate::ntt::{ntt, ntt_inv_finish}; use std::convert::TryFrom; -/// Temporary memory used for FFT +/// Temporary memory used for NTT #[derive(Clone, Debug)] -pub struct PolyFFTTempMemory { - fft_tmp: Vec, - fft_y_sub: Vec, - fft_roots_sub: Vec, +pub struct PolyNttTempMemory { + ntt_tmp: Vec, + ntt_y_sub: Vec, + ntt_roots_sub: Vec, } -impl PolyFFTTempMemory { +impl PolyNttTempMemory { pub(crate) fn new(length: usize) -> Self { - PolyFFTTempMemory { - fft_tmp: vec![F::zero(); length], - fft_y_sub: vec![F::zero(); length], - fft_roots_sub: vec![F::zero(); length], + PolyNttTempMemory { + ntt_tmp: vec![F::zero(); length], + ntt_y_sub: vec![F::zero(); length], + ntt_roots_sub: vec![F::zero(); length], } } } @@ -32,21 +32,21 @@ impl PolyFFTTempMemory { pub(crate) struct TestPolyAuxMemory { pub roots_2n: Vec, pub roots_2n_inverted: Vec, - pub fft_memory: PolyFFTTempMemory, + pub ntt_memory: PolyNttTempMemory, } #[cfg(test)] -impl TestPolyAuxMemory { +impl TestPolyAuxMemory { pub(crate) fn new(n: usize) -> Self { Self { - roots_2n: fft_get_roots(2 * n, false), - roots_2n_inverted: fft_get_roots(2 * n, true), - fft_memory: PolyFFTTempMemory::new(2 * n), + roots_2n: ntt_get_roots(2 * n, false), + roots_2n_inverted: ntt_get_roots(2 * n, true), + ntt_memory: PolyNttTempMemory::new(2 * n), } } } -fn fft_recurse( +fn ntt_recurse( out: &mut [F], n: usize, roots: &[F], @@ -71,7 +71,7 @@ fn fft_recurse( y_sub_first[i] = ys[i] + ys[i + half_n]; roots_sub_first[i] = roots[2 * i]; } - fft_recurse( + ntt_recurse( tmp_first, half_n, roots_sub_first, @@ -89,7 +89,7 @@ fn fft_recurse( y_sub_first[i] = ys[i] - ys[i + half_n]; y_sub_first[i] *= roots[i]; } - fft_recurse( + ntt_recurse( tmp_first, half_n, roots_sub_first, @@ -104,7 +104,7 @@ fn fft_recurse( } /// Calculate `count` number of roots of unity of order `count` -pub(crate) fn fft_get_roots(count: usize, invert: bool) -> Vec { +pub(crate) fn ntt_get_roots(count: usize, invert: bool) -> Vec { let mut roots = vec![F::zero(); count]; let mut gen = F::generator(); if invert { @@ -125,22 +125,22 @@ pub(crate) fn fft_get_roots(count: usize, invert: bo roots } -fn fft_interpolate_raw( +fn ntt_interpolate_raw( out: &mut [F], ys: &[F], n_points: usize, roots: &[F], invert: bool, - mem: &mut PolyFFTTempMemory, + mem: &mut PolyNttTempMemory, ) { - fft_recurse( + ntt_recurse( out, n_points, roots, ys, - &mut mem.fft_tmp, - &mut mem.fft_y_sub, - &mut mem.fft_roots_sub, + &mut mem.ntt_tmp, + &mut mem.ntt_y_sub, + &mut mem.ntt_roots_sub, ); if invert { let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv(); @@ -150,19 +150,19 @@ fn fft_interpolate_raw( } } -pub fn poly_fft( +pub fn poly_ntt( points_out: &mut [F], points_in: &[F], scaled_roots: &[F], n_points: usize, invert: bool, - mem: &mut PolyFFTTempMemory, + mem: &mut PolyNttTempMemory, ) { - fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem) + ntt_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem) } /// Evaluate a polynomial using Horner's method. -pub fn poly_eval(poly: &[F], eval_at: F) -> F { +pub fn poly_eval(poly: &[F], eval_at: F) -> F { if poly.is_empty() { return F::zero(); } @@ -177,7 +177,7 @@ pub fn poly_eval(poly: &[F], eval_at: F) -> F { } /// Returns the degree of polynomial `p`. -pub fn poly_deg(p: &[F]) -> usize { +pub fn poly_deg(p: &[F]) -> usize { let mut d = p.len(); while d > 0 && p[d - 1] == F::zero() { d -= 1; @@ -186,7 +186,7 @@ pub fn poly_deg(p: &[F]) -> usize { } /// Multiplies polynomials `p` and `q` and returns the result. -pub fn poly_mul(p: &[F], q: &[F]) -> Vec { +pub fn poly_mul(p: &[F], q: &[F]) -> Vec { let p_size = poly_deg(p) + 1; let q_size = poly_deg(q) + 1; let mut out = vec![F::zero(); p_size + q_size]; @@ -201,20 +201,20 @@ pub fn poly_mul(p: &[F], q: &[F]) -> Vec { #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[inline] -pub fn poly_interpret_eval( +pub fn poly_interpret_eval( points: &[F], eval_at: F, tmp_coeffs: &mut [F], ) -> F { let size_inv = F::from(F::Integer::try_from(points.len()).unwrap()).inv(); - discrete_fourier_transform(tmp_coeffs, points, points.len()).unwrap(); - discrete_fourier_transform_inv_finish(tmp_coeffs, points.len(), size_inv); + ntt(tmp_coeffs, points, points.len()).unwrap(); + ntt_inv_finish(tmp_coeffs, points.len(), size_inv); poly_eval(&tmp_coeffs[..points.len()], eval_at) } /// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise, /// the output is not `0`. -pub(crate) fn poly_range_check(start: usize, end: usize) -> Vec { +pub(crate) fn poly_range_check(start: usize, end: usize) -> Vec { let mut p = vec![F::one()]; let mut q = [F::zero(), F::one()]; for i in start..end { @@ -228,10 +228,10 @@ pub(crate) fn poly_range_check(start: usize, end: us mod tests { use crate::{ field::{ - FftFriendlyFieldElement, Field64, FieldElement, FieldElementWithInteger, FieldPrio2, + Field64, FieldElement, FieldElementWithInteger, FieldPrio2, NttFriendlyFieldElement, }, polynomial::{ - fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check, + ntt_get_roots, poly_deg, poly_eval, poly_mul, poly_ntt, poly_range_check, TestPolyAuxMemory, }, }; @@ -241,8 +241,8 @@ mod tests { #[test] fn test_roots() { let count = 128; - let roots = fft_get_roots::(count, false); - let roots_inv = fft_get_roots::(count, true); + let roots = ntt_get_roots::(count, false); + let roots_inv = ntt_get_roots::(count, true); for i in 0..count { assert_eq!(roots[i] * roots_inv[i], 1); @@ -327,7 +327,7 @@ mod tests { } #[test] - fn test_fft() { + fn test_ntt() { let count = 128; let mut mem = TestPolyAuxMemory::new(count / 2); @@ -339,33 +339,33 @@ mod tests { .collect::>(); // From points to coeffs and back - poly_fft( + poly_ntt( &mut poly, &points, &mem.roots_2n, count, false, - &mut mem.fft_memory, + &mut mem.ntt_memory, ); - poly_fft( + poly_ntt( &mut points2, &poly, &mem.roots_2n_inverted, count, true, - &mut mem.fft_memory, + &mut mem.ntt_memory, ); assert_eq!(points, points2); // interpolation - poly_fft( + poly_ntt( &mut poly, &points, &mem.roots_2n, count, false, - &mut mem.fft_memory, + &mut mem.ntt_memory, ); for (poly_coeff, root) in poly[..count].iter().zip(mem.roots_2n[..count].iter()) { diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index 9757e8d8..4d580f35 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -5,7 +5,7 @@ use crate::{ codec::{CodecError, Decode, Encode, ParameterizedDecode}, field::{ - decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldPrio2, + decode_fieldvec, FieldElement, FieldElementWithInteger, FieldPrio2, NttFriendlyFieldElement, }, prng::Prng, vdaf::{ diff --git a/src/vdaf/prio2/client.rs b/src/vdaf/prio2/client.rs index 830597e3..9ebd74a2 100644 --- a/src/vdaf/prio2/client.rs +++ b/src/vdaf/prio2/client.rs @@ -5,8 +5,8 @@ use crate::{ codec::CodecError, - field::FftFriendlyFieldElement, - polynomial::{fft_get_roots, poly_fft, PolyFFTTempMemory}, + field::NttFriendlyFieldElement, + polynomial::{ntt_get_roots, poly_ntt, PolyNttTempMemory}, prng::Prng, vdaf::{xof::SeedStreamAes128, VdafError}, }; @@ -34,11 +34,11 @@ pub(crate) struct ClientMemory { evals_g: Vec, roots_2n: Vec, roots_n_inverted: Vec, - fft_memory: PolyFFTTempMemory, + ntt_memory: PolyNttTempMemory, coeffs: Vec, } -impl ClientMemory { +impl ClientMemory { pub(crate) fn new(dimension: usize) -> Result { let mut rng = thread_rng(); let n = (dimension + 1).next_power_of_two(); @@ -60,15 +60,15 @@ impl ClientMemory { points_g: vec![F::zero(); n], evals_f: vec![F::zero(); 2 * n], evals_g: vec![F::zero(); 2 * n], - roots_2n: fft_get_roots(2 * n, false), - roots_n_inverted: fft_get_roots(n, true), - fft_memory: PolyFFTTempMemory::new(2 * n), + roots_2n: ntt_get_roots(2 * n, false), + roots_n_inverted: ntt_get_roots(n, true), + ntt_memory: PolyNttTempMemory::new(2 * n), coeffs: vec![F::zero(); 2 * n], }) } } -impl ClientMemory { +impl ClientMemory { pub(crate) fn prove_with(&mut self, dimension: usize, init_function: G) -> Vec where G: FnOnce(&mut [F]), @@ -106,7 +106,7 @@ pub(crate) fn proof_length(dimension: usize) -> usize { /// Unpacked proof with subcomponents #[derive(Debug)] -pub(crate) struct UnpackedProof<'a, F: FftFriendlyFieldElement> { +pub(crate) struct UnpackedProof<'a, F: NttFriendlyFieldElement> { /// Data pub data: &'a [F], /// Zeroth coefficient of polynomial f @@ -121,7 +121,7 @@ pub(crate) struct UnpackedProof<'a, F: FftFriendlyFieldElement> { /// Unpacked proof with mutable subcomponents #[derive(Debug)] -pub(crate) struct UnpackedProofMut<'a, F: FftFriendlyFieldElement> { +pub(crate) struct UnpackedProofMut<'a, F: NttFriendlyFieldElement> { /// Data pub data: &'a mut [F], /// Zeroth coefficient of polynomial f @@ -135,7 +135,7 @@ pub(crate) struct UnpackedProofMut<'a, F: FftFriendlyFieldElement> { } /// Unpacks the proof vector into subcomponents -pub(crate) fn unpack_proof( +pub(crate) fn unpack_proof( proof: &[F], dimension: usize, ) -> Result, SerializeError> { @@ -159,7 +159,7 @@ pub(crate) fn unpack_proof( } /// Unpacks a mutable proof vector into mutable subcomponents -pub(crate) fn unpack_proof_mut( +pub(crate) fn unpack_proof_mut( proof: &mut [F], dimension: usize, ) -> Result, SerializeError> { @@ -194,28 +194,28 @@ pub(crate) fn unpack_proof_mut( /// unity. This must have length 2 * n. /// * `roots_n_inverted` - Precomputed inverses of the nth roots of unity. /// * `roots_2n` - Precomputed 2nth roots of unity. -/// * `fft_memory` - Scratch space for the FFT algorithm. +/// * `ntt_memory` - Scratch space for the NTT algorithm. /// * `coeffs` - Scratch space. This must have length 2 * n. -fn interpolate_and_evaluate_at_2n( +fn interpolate_and_evaluate_at_2n( n: usize, points_in: &[F], evals_out: &mut [F], roots_n_inverted: &[F], roots_2n: &[F], - fft_memory: &mut PolyFFTTempMemory, + ntt_memory: &mut PolyNttTempMemory, coeffs: &mut [F], ) { // interpolate through roots of unity - poly_fft(coeffs, points_in, roots_n_inverted, n, true, fft_memory); + poly_ntt(coeffs, points_in, roots_n_inverted, n, true, ntt_memory); // evaluate at 2N roots of unity - poly_fft(evals_out, coeffs, roots_2n, 2 * n, false, fft_memory); + poly_ntt(evals_out, coeffs, roots_2n, 2 * n, false, ntt_memory); } /// Proof construction /// /// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation /// This constructs the output \pi by doing the necessesary calculations -fn construct_proof( +fn construct_proof( data: &[F], dimension: usize, f0: &mut F, @@ -253,7 +253,7 @@ fn construct_proof( &mut mem.evals_f, &mem.roots_n_inverted, &mem.roots_2n, - &mut mem.fft_memory, + &mut mem.ntt_memory, &mut mem.coeffs, ); interpolate_and_evaluate_at_2n( @@ -262,7 +262,7 @@ fn construct_proof( &mut mem.evals_g, &mem.roots_n_inverted, &mem.roots_2n, - &mut mem.fft_memory, + &mut mem.ntt_memory, &mut mem.coeffs, ); diff --git a/src/vdaf/prio2/server.rs b/src/vdaf/prio2/server.rs index f54eac2d..c9fa10af 100644 --- a/src/vdaf/prio2/server.rs +++ b/src/vdaf/prio2/server.rs @@ -3,7 +3,7 @@ //! Primitives for the Prio2 server. use crate::{ - field::{FftFriendlyFieldElement, FieldError}, + field::{FieldError, NttFriendlyFieldElement}, polynomial::poly_interpret_eval, vdaf::prio2::client::{unpack_proof, SerializeError}, }; @@ -37,7 +37,7 @@ pub struct VerificationMessage { /// Given a proof and evaluation point, this constructs the verification /// message. -pub(crate) fn generate_verification_message( +pub(crate) fn generate_verification_message( dimension: usize, eval_at: F, proof: &[F], @@ -46,40 +46,40 @@ pub(crate) fn generate_verification_message( let unpacked = unpack_proof(proof, dimension)?; let n: usize = (dimension + 1).next_power_of_two(); let proof_length = 2 * n; - let mut fft_in = vec![F::zero(); proof_length]; - let mut fft_mem = vec![F::zero(); proof_length]; + let mut ntt_in = vec![F::zero(); proof_length]; + let mut ntt_mem = vec![F::zero(); proof_length]; // construct and evaluate polynomial f at the random point - fft_in[0] = *unpacked.f0; - fft_in[1..unpacked.data.len() + 1].copy_from_slice(unpacked.data); - let f_r = poly_interpret_eval(&fft_in[..n], eval_at, &mut fft_mem); + ntt_in[0] = *unpacked.f0; + ntt_in[1..unpacked.data.len() + 1].copy_from_slice(unpacked.data); + let f_r = poly_interpret_eval(&ntt_in[..n], eval_at, &mut ntt_mem); // construct and evaluate polynomial g at the random point - fft_in[0] = *unpacked.g0; + ntt_in[0] = *unpacked.g0; if is_first_server { - for x in fft_in[1..unpacked.data.len() + 1].iter_mut() { + for x in ntt_in[1..unpacked.data.len() + 1].iter_mut() { *x -= F::one(); } } - let g_r = poly_interpret_eval(&fft_in[..n], eval_at, &mut fft_mem); + let g_r = poly_interpret_eval(&ntt_in[..n], eval_at, &mut ntt_mem); // construct and evaluate polynomial h at the random point - fft_in[0] = *unpacked.h0; - fft_in[1] = unpacked.points_h_packed[0]; + ntt_in[0] = *unpacked.h0; + ntt_in[1] = unpacked.points_h_packed[0]; for (x, chunk) in unpacked.points_h_packed[1..] .iter() - .zip(fft_in[2..proof_length].chunks_exact_mut(2)) + .zip(ntt_in[2..proof_length].chunks_exact_mut(2)) { chunk[0] = F::zero(); chunk[1] = *x; } - let h_r = poly_interpret_eval(&fft_in, eval_at, &mut fft_mem); + let h_r = poly_interpret_eval(&ntt_in, eval_at, &mut ntt_mem); Ok(VerificationMessage { f_r, g_r, h_r }) } /// Decides if the distributed proof is valid -pub(crate) fn is_valid_share( +pub(crate) fn is_valid_share( v1: &VerificationMessage, v2: &VerificationMessage, ) -> bool { @@ -95,7 +95,7 @@ pub(crate) fn is_valid_share( mod test_util { use crate::{ codec::ParameterizedDecode, - field::{merge_vector, FftFriendlyFieldElement}, + field::{merge_vector, NttFriendlyFieldElement}, prng::Prng, vdaf::{ prio2::client::{proof_length, SerializeError}, @@ -113,7 +113,7 @@ mod test_util { accumulator: Vec, } - impl Server { + impl Server { /// Construct a new server instance /// /// Params: diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 94edc192..0b7e51b6 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -34,7 +34,7 @@ use crate::codec::{encode_fixlen_items, CodecError, Decode, Encode, Parameterize #[cfg(feature = "experimental")] use crate::dp::DifferentialPrivacyStrategy; use crate::field::{ - decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, + decode_fieldvec, FieldElement, FieldElementWithInteger, NttFriendlyFieldElement, }; use crate::field::{Field128, Field64}; #[cfg(feature = "multithreaded")] @@ -968,7 +968,7 @@ impl ConstantTimeEq for Prio3InputSha } } -impl Encode for Prio3InputShare { +impl Encode for Prio3InputShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { match self { Prio3InputShare::Leader { @@ -1102,7 +1102,7 @@ impl ConstantTimeEq for Prio3PrepareS } } -impl Encode +impl Encode for Prio3PrepareShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { @@ -1125,7 +1125,7 @@ impl Encode } } -impl +impl ParameterizedDecode> for Prio3PrepareShare { fn decode_with_param( @@ -1192,7 +1192,7 @@ impl Encode for Prio3PrepareMessage { } } -impl +impl ParameterizedDecode> for Prio3PrepareMessage { fn decode_with_param( @@ -1276,7 +1276,7 @@ impl Debug for Prio3PrepareState { } } -impl Encode +impl Encode for Prio3PrepareState { /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed.