Skip to content

Commit

Permalink
Rename FFT to NTT
Browse files Browse the repository at this point in the history
  • Loading branch information
cjpatton committed Jan 17, 2025
1 parent 9eac4c1 commit 3797e57
Show file tree
Hide file tree
Showing 15 changed files with 250 additions and 256 deletions.
8 changes: 4 additions & 4 deletions benches/speed_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ pub fn dp_noise(c: &mut Criterion) {
group.finish();
}

/// The asymptotic cost of polynomial multiplication is `O(n log n)` using FFT and `O(n^2)` using
/// The asymptotic cost of polynomial multiplication is `O(n log n)` using NTT and `O(n^2)` using
/// the naive method. This benchmark demonstrates that the latter has better concrete performance
/// for small polynomials. The result is used to pick the `FFT_THRESHOLD` constant in
/// for small polynomials. The result is used to pick the `NTT_THRESHOLD` constant in
/// `src/flp/gadgets.rs`.
fn poly_mul(c: &mut Criterion) {
let test_sizes = [1_usize, 30, 60, 90, 120, 150, 255];

let mut group = c.benchmark_group("poly_mul");
for size in test_sizes {
group.bench_with_input(BenchmarkId::new("fft", size), &size, |b, size| {
group.bench_with_input(BenchmarkId::new("ntt", size), &size, |b, size| {
let m = (size + 1).next_power_of_two();
let mut g: Mul<F> = Mul::new(*size);
let mut outp = vec![F::zero(); 2 * m];
Expand All @@ -99,7 +99,7 @@ fn poly_mul(c: &mut Criterion) {
inp.push(random_vector(m));

b.iter(|| {
benchmarked_gadget_mul_call_poly_fft(&mut g, &mut outp, &inp).unwrap();
benchmarked_gadget_mul_call_poly_ntt(&mut g, &mut outp, &inp).unwrap();
})
});

Expand Down
30 changes: 15 additions & 15 deletions src/benchmarked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,37 @@
//! This module provides wrappers around internal components of this crate that we want to
//! benchmark, but which we don't want to expose in the public API.
use crate::fft::discrete_fourier_transform;
use crate::field::FftFriendlyFieldElement;
use crate::field::NttFriendlyFieldElement;
use crate::flp::gadgets::Mul;
use crate::flp::FlpError;
use crate::polynomial::{fft_get_roots, poly_fft, PolyFFTTempMemory};
use crate::ntt::ntt;
use crate::polynomial::{ntt_get_roots, poly_ntt, PolyNttTempMemory};

/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
pub fn benchmarked_iterative_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
discrete_fourier_transform(outp, inp, inp.len()).unwrap();
/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative NTT algorithm.
pub fn benchmarked_iterative_ntt<F: NttFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
ntt(outp, inp, inp.len()).unwrap();
}

/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
pub fn benchmarked_recursive_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
let roots_2n = fft_get_roots(inp.len(), false);
let mut fft_memory = PolyFFTTempMemory::new(inp.len());
poly_fft(outp, inp, &roots_2n, inp.len(), false, &mut fft_memory)
/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive NTT algorithm.
pub fn benchmarked_recursive_ntt<F: NttFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
let roots_2n = ntt_get_roots(inp.len(), false);
let mut ntt_memory = PolyNttTempMemory::new(inp.len());
poly_ntt(outp, inp, &roots_2n, inp.len(), false, &mut ntt_memory)
}

/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
/// uses FFT for multiplication.
pub fn benchmarked_gadget_mul_call_poly_fft<F: FftFriendlyFieldElement>(
/// uses NTT for multiplication.
pub fn benchmarked_gadget_mul_call_poly_ntt<F: NttFriendlyFieldElement>(
g: &mut Mul<F>,
outp: &mut [F],
inp: &[Vec<F>],
) -> Result<(), FlpError> {
g.call_poly_fft(outp, inp)
g.call_poly_ntt(outp, inp)
}

/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
/// does the multiplication directly.
pub fn benchmarked_gadget_mul_call_poly_direct<F: FftFriendlyFieldElement>(
pub fn benchmarked_gadget_mul_call_poly_direct<F: NttFriendlyFieldElement>(
g: &mut Mul<F>,
outp: &mut [F],
inp: &[Vec<F>],
Expand Down
12 changes: 6 additions & 6 deletions src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! Finite field arithmetic.
//!
//! Basic field arithmetic is captured in the [`FieldElement`] trait. Fields used in Prio implement
//! [`FftFriendlyFieldElement`], and have an associated element called the "generator" that
//! [`NttFriendlyFieldElement`], and have an associated element called the "generator" that
//! generates a multiplicative subgroup of order `2^n` for some `n`.
use crate::{
Expand Down Expand Up @@ -406,13 +406,13 @@ impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor<F> {
/// Objects with this trait represent an element of `GF(p)`, where `p` is some prime and the
/// field's multiplicative group has a subgroup with an order that is a power of 2, and at least
/// `2^20`.
pub trait FftFriendlyFieldElement: FieldElementWithInteger {
pub trait NttFriendlyFieldElement: FieldElementWithInteger {
/// Returns the size of the multiplicative subgroup generated by
/// [`FftFriendlyFieldElement::generator`].
/// [`NttFriendlyFieldElement::generator`].
fn generator_order() -> Self::Integer;

/// Returns the generator of the multiplicative subgroup of size
/// [`FftFriendlyFieldElement::generator_order`].
/// [`NttFriendlyFieldElement::generator_order`].
fn generator() -> Self;

/// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
Expand Down Expand Up @@ -751,7 +751,7 @@ macro_rules! make_field {
}
}

impl FftFriendlyFieldElement for $elem {
impl NttFriendlyFieldElement for $elem {
fn generator() -> Self {
Self($fp::G)
}
Expand Down Expand Up @@ -1168,7 +1168,7 @@ mod tests {
assert_matches!(result, Err(FieldError::InputSizeMismatch));
}

fn field_element_test<F: FftFriendlyFieldElement + Hash>() {
fn field_element_test<F: NttFriendlyFieldElement + Hash>() {
field_element_test_common::<F>();

let mut prng: Prng<F, _> = Prng::new();
Expand Down
42 changes: 21 additions & 21 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@
#[cfg(feature = "experimental")]
use crate::dp::DifferentialPrivacyStrategy;
use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError};
use crate::field::{FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldError};
use crate::field::{FieldElement, FieldElementWithInteger, FieldError, NttFriendlyFieldElement};
use crate::fp::log2;
use crate::ntt::{ntt, ntt_inv_finish, NttError};
use crate::polynomial::poly_eval;
use std::any::Any;
use std::convert::TryFrom;
Expand Down Expand Up @@ -99,9 +99,9 @@ pub enum FlpError {
#[error("invalid paramter: {0}")]
InvalidParameter(String),

/// Returned if an FFT operation propagates an error.
#[error("FFT error: {0}")]
Fft(#[from] FftError),
/// Returned if an NTT operation propagates an error.
#[error("NTT error: {0}")]
Ntt(#[from] NttError),

/// Returned if a field operation encountered an error.
#[error("Field error: {0}")]
Expand All @@ -124,7 +124,7 @@ pub trait Type: Sized + Eq + Clone + Debug {
type AggregateResult: Clone + Debug;

/// The finite field used for this type.
type Field: FftFriendlyFieldElement;
type Field: NttFriendlyFieldElement;

/// Encodes a measurement as a vector of [`Self::input_len`] field elements.
fn encode_measurement(
Expand Down Expand Up @@ -299,7 +299,7 @@ pub trait Type: Sized + Eq + Clone + Debug {
.map(|shim| {
let gadget_poly_len = gadget_poly_len(shim.degree(), wire_poly_len(shim.calls()));

// Computing the gadget polynomial using FFT requires an amount of memory that is a
// Computing the gadget polynomial using NTT requires an amount of memory that is a
// power of 2. Thus we choose the smallest power of 2 that is at least as large as
// the gadget polynomial. The wire seeds are encoded in the proof, too, so we
// include the arity of the gadget to ensure there is always enough room at the end
Expand Down Expand Up @@ -336,8 +336,8 @@ pub trait Type: Sized + Eq + Clone + Debug {
.zip(gadget.f_vals[..gadget.arity()].iter())
.zip(proof[proof_len..proof_len + gadget.arity()].iter_mut())
{
discrete_fourier_transform(coefficients, values, m)?;
discrete_fourier_transform_inv_finish(coefficients, m, m_inv);
ntt(coefficients, values, m)?;
ntt_inv_finish(coefficients, m, m_inv);

// The first point on each wire polynomial is a random value chosen by the prover. This
// point is stored in the proof so that the verifier can reconstruct the wire
Expand Down Expand Up @@ -503,8 +503,8 @@ pub trait Type: Sized + Eq + Clone + Debug {
.inv();
let mut f = vec![Self::Field::zero(); m];
for wire in 0..gadget.arity() {
discrete_fourier_transform(&mut f, &gadget.f_vals[wire], m)?;
discrete_fourier_transform_inv_finish(&mut f, m, m_inv);
ntt(&mut f, &gadget.f_vals[wire], m)?;
ntt_inv_finish(&mut f, m, m_inv);
verifier.push(poly_eval(&f, *query_rand_val));
}

Expand Down Expand Up @@ -607,7 +607,7 @@ where
}

/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit.
pub trait Gadget<F: FftFriendlyFieldElement>: Debug {
pub trait Gadget<F: NttFriendlyFieldElement>: Debug {
/// Evaluates the gadget on input `inp` and returns the output.
fn call(&mut self, inp: &[F]) -> Result<F, FlpError>;

Expand All @@ -632,7 +632,7 @@ pub trait Gadget<F: FftFriendlyFieldElement>: Debug {
/// A "shim" gadget used during proof generation to record the input wires each time a gadget is
/// evaluated.
#[derive(Debug)]
struct ProveShimGadget<F: FftFriendlyFieldElement> {
struct ProveShimGadget<F: NttFriendlyFieldElement> {
inner: Box<dyn Gadget<F>>,

/// Points at which the wire polynomials are interpolated.
Expand All @@ -642,7 +642,7 @@ struct ProveShimGadget<F: FftFriendlyFieldElement> {
ct: usize,
}

impl<F: FftFriendlyFieldElement> ProveShimGadget<F> {
impl<F: NttFriendlyFieldElement> ProveShimGadget<F> {
fn new(inner: Box<dyn Gadget<F>>, prove_rand: &[F]) -> Result<Self, FlpError> {
let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()];

Expand All @@ -661,7 +661,7 @@ impl<F: FftFriendlyFieldElement> ProveShimGadget<F> {
}
}

impl<F: FftFriendlyFieldElement> Gadget<F> for ProveShimGadget<F> {
impl<F: NttFriendlyFieldElement> Gadget<F> for ProveShimGadget<F> {
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
wire_poly_vals[self.ct] = *inp_val;
Expand Down Expand Up @@ -694,7 +694,7 @@ impl<F: FftFriendlyFieldElement> Gadget<F> for ProveShimGadget<F> {
/// A "shim" gadget used during proof verification to record the points at which the intermediate
/// proof polynomials are evaluated.
#[derive(Debug)]
struct QueryShimGadget<F: FftFriendlyFieldElement> {
struct QueryShimGadget<F: NttFriendlyFieldElement> {
inner: Box<dyn Gadget<F>>,

/// Points at which intermediate proof polynomials are interpolated.
Expand All @@ -713,7 +713,7 @@ struct QueryShimGadget<F: FftFriendlyFieldElement> {
ct: usize,
}

impl<F: FftFriendlyFieldElement> QueryShimGadget<F> {
impl<F: NttFriendlyFieldElement> QueryShimGadget<F> {
fn new(inner: Box<dyn Gadget<F>>, r: F, proof_data: &[F]) -> Result<Self, FlpError> {
let gadget_degree = inner.degree();
let gadget_arity = inner.arity();
Expand All @@ -731,7 +731,7 @@ impl<F: FftFriendlyFieldElement> QueryShimGadget<F> {
// Evaluate the gadget polynomial at roots of unity.
let size = p.next_power_of_two();
let mut p_vals = vec![F::zero(); size];
discrete_fourier_transform(&mut p_vals, &proof_data[gadget_arity..], size)?;
ntt(&mut p_vals, &proof_data[gadget_arity..], size)?;

// The step is used to compute the element of `p_val` that will be returned by a call to
// the gadget.
Expand All @@ -751,7 +751,7 @@ impl<F: FftFriendlyFieldElement> QueryShimGadget<F> {
}
}

impl<F: FftFriendlyFieldElement> Gadget<F> for QueryShimGadget<F> {
impl<F: NttFriendlyFieldElement> Gadget<F> for QueryShimGadget<F> {
fn call(&mut self, inp: &[F]) -> Result<F, FlpError> {
for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) {
wire_poly_vals[self.ct] = *inp_val;
Expand Down Expand Up @@ -1077,7 +1077,7 @@ mod tests {
}
}

impl<F: FftFriendlyFieldElement> Type for TestType<F> {
impl<F: NttFriendlyFieldElement> Type for TestType<F> {
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
Expand Down Expand Up @@ -1215,7 +1215,7 @@ mod tests {
}
}

impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> {
impl<F: NttFriendlyFieldElement> Type for Issue254Type<F> {
type Measurement = F::Integer;
type AggregateResult = F::Integer;
type Field = F;
Expand Down
Loading

0 comments on commit 3797e57

Please sign in to comment.