Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel fft #819

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions crates/prover/src/core/backend/simd/blake2s.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::vcs::blake2_hash::Blake2sHash;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::parallel_iter;

const IV: [u32; 8] = [
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
Expand Down Expand Up @@ -51,13 +52,7 @@ impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
columns: &[&Col<Self, BaseField>],
) -> Vec<Blake2sHash> {
if log_size < LOG_N_LANES {
#[cfg(not(feature = "parallel"))]
let iter = 0..1 << log_size;

#[cfg(feature = "parallel")]
let iter = (0..1 << log_size).into_par_iter();

return iter
return parallel_iter!(0..1 << log_size)
.map(|i| {
Blake2sMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
Expand Down
17 changes: 13 additions & 4 deletions crates/prover/src/core/backend/simd/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::UnsafeMutI32;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::utils::bit_reverse;
use crate::parallel_iter;

/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values.
///
Expand All @@ -29,6 +33,7 @@ use crate::core::utils::bit_reverse;
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);

let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
Expand Down Expand Up @@ -81,7 +86,9 @@ pub unsafe fn ifft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

for index_h in 0..1 << (log_size - fft_layers) {
let values = UnsafeMutI32(values);
parallel_iter!(0..1 << (log_size - fft_layers)).for_each(|index_h| {
let values = values.get();
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
match fft_layers - layer {
Expand All @@ -102,7 +109,7 @@ pub unsafe fn ifft_lower_with_vecwise(
}
}
}
}
});
}

/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
Expand Down Expand Up @@ -131,7 +138,9 @@ pub unsafe fn ifft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let values = UnsafeMutI32(values);
parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)).for_each(|index_h| {
let values = values.get();
for layer in (0..fft_layers).step_by(3) {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
Expand All @@ -152,7 +161,7 @@ pub unsafe fn ifft_lower_without_vecwise(
}
}
}
}
});
}

/// Runs the first 5 ifft layers across the entire array.
Expand Down
34 changes: 31 additions & 3 deletions crates/prover/src/core/backend/simd/fft/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::simd::{simd_swizzle, u32x16, u32x8};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::m31::PackedBaseField;
use crate::core::fields::m31::P;
use crate::parallel_iter;

pub mod ifft;
pub mod rfft;
Expand All @@ -10,6 +14,27 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16;

pub const MIN_FFT_LOG_SIZE: u32 = 5;

// TODO(andrew): Examine usage of unsafe in SIMD FFT.
pub struct UnsafeMutI32(pub *mut u32);
impl UnsafeMutI32 {
pub fn get(&self) -> *mut u32 {
self.0
}
}

unsafe impl Send for UnsafeMutI32 {}
unsafe impl Sync for UnsafeMutI32 {}

pub struct UnsafeConstI32(pub *const u32);
impl UnsafeConstI32 {
pub fn get(&self) -> *const u32 {
self.0
}
}

unsafe impl Send for UnsafeConstI32 {}
unsafe impl Sync for UnsafeConstI32 {}

// TODO(andrew): FFTs return a redundant representation, that can get the value P. need to deal with
// it. Either: reduce before commitment or regenerate proof with new seed if redundant value
// decommitted.
Expand All @@ -30,8 +55,11 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5;
/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
let half = log_n_vecs / 2;
for b in 0..1 << (log_n_vecs & 1) {
for a in 0..1 << half {

let values = UnsafeMutI32(values);
parallel_iter!(0..1 << half).for_each(|a| {
let values = values.get();
for b in 0..1 << (log_n_vecs & 1) {
for c in 0..1 << half {
let i = (a << (log_n_vecs - half)) | (b << half) | c;
let j = (c << (log_n_vecs - half)) | (b << half) | a;
Expand All @@ -44,7 +72,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
store(values.add(j << 4), val0);
}
}
}
});
}

/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
Expand Down
22 changes: 16 additions & 6 deletions crates/prover/src/core/backend/simd/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ use std::array;
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8};

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::utils::bit_reverse;
use crate::parallel_iter;

/// Performs a Circle Fast Fourier Transform (CFFT) on the given values.
///
Expand Down Expand Up @@ -86,8 +90,11 @@ pub unsafe fn fft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

for index_h in 0..1 << (log_size - fft_layers) {
let mut src = src;
let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
parallel_iter!(0..1 << (log_size - fft_layers)).for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() {
match fft_layers - layer {
1 => {
Expand Down Expand Up @@ -116,7 +123,7 @@ pub unsafe fn fft_lower_with_vecwise(
fft_layers - VECWISE_FFT_BITS,
index_h,
);
}
});
}

/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
Expand Down Expand Up @@ -147,8 +154,11 @@ pub unsafe fn fft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let mut src = src;
let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)).for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
for layer in (0..fft_layers).step_by(3).rev() {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
Expand All @@ -171,7 +181,7 @@ pub unsafe fn fft_lower_without_vecwise(
}
src = dst;
}
}
});
}

/// Runs the last 5 fft layers across the entire array.
Expand Down
13 changes: 13 additions & 0 deletions crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,16 @@ impl<T> DerefMut for ComponentVec<T> {
&mut self.0
}
}

#[macro_export]
macro_rules! parallel_iter {
($i: expr) => {{
#[cfg(not(feature = "parallel"))]
let iter = $i.into_iter();

#[cfg(feature = "parallel")]
let iter = $i.into_par_iter();

iter
}};
}
Loading