Skip to content

Commit

Permalink
Parallel fft
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored and alonh5 committed Sep 18, 2024
1 parent a51f630 commit b3b9ee0
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 18 deletions.
7 changes: 2 additions & 5 deletions crates/prover/src/core/backend/simd/blake2s.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::vcs::blake2_hash::Blake2sHash;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::parallel_iter;

const IV: [u32; 8] = [
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
Expand Down Expand Up @@ -51,11 +52,7 @@ impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
columns: &[&Col<Self, BaseField>],
) -> Vec<Blake2sHash> {
if log_size < LOG_N_LANES {
#[cfg(not(feature = "parallel"))]
let iter = 0..1 << log_size;

#[cfg(feature = "parallel")]
let iter = (0..1 << log_size).into_par_iter();
let iter = parallel_iter!(0..1 << log_size);

return iter
.map(|i| {
Expand Down
21 changes: 17 additions & 4 deletions crates/prover/src/core/backend/simd/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::UnsafeMutI32;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::utils::bit_reverse;
use crate::parallel_iter;

/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values.
///
Expand All @@ -29,6 +33,7 @@ use crate::core::utils::bit_reverse;
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);

let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
Expand Down Expand Up @@ -81,7 +86,11 @@ pub unsafe fn ifft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

for index_h in 0..1 << (log_size - fft_layers) {
let iter = parallel_iter!(0..1 << (log_size - fft_layers));

let values = UnsafeMutI32(values);
iter.for_each(|index_h| {
let values = values.get();
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
match fft_layers - layer {
Expand All @@ -102,7 +111,7 @@ pub unsafe fn ifft_lower_with_vecwise(
}
}
}
}
});
}

/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
Expand Down Expand Up @@ -131,7 +140,11 @@ pub unsafe fn ifft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let iter = parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize));

let values = UnsafeMutI32(values);
iter.for_each(|index_h| {
let values = values.get();
for layer in (0..fft_layers).step_by(3) {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
Expand All @@ -152,7 +165,7 @@ pub unsafe fn ifft_lower_without_vecwise(
}
}
}
}
});
}

/// Runs the first 5 ifft layers across the entire array.
Expand Down
35 changes: 32 additions & 3 deletions crates/prover/src/core/backend/simd/fft/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::simd::{simd_swizzle, u32x16, u32x8};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::m31::PackedBaseField;
use crate::core::fields::m31::P;
use crate::parallel_iter;

pub mod ifft;
pub mod rfft;
Expand All @@ -10,6 +14,26 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16;

pub const MIN_FFT_LOG_SIZE: u32 = 5;

pub struct UnsafeMutI32(pub *mut u32);
impl UnsafeMutI32 {
pub fn get(&self) -> *mut u32 {
self.0
}
}

unsafe impl Send for UnsafeMutI32 {}
unsafe impl Sync for UnsafeMutI32 {}

pub struct UnsafeConstI32(pub *const u32);
impl UnsafeConstI32 {
pub fn get(&self) -> *const u32 {
self.0
}
}

unsafe impl Send for UnsafeConstI32 {}
unsafe impl Sync for UnsafeConstI32 {}

// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.

Expand All @@ -29,8 +53,13 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5;
/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
let half = log_n_vecs / 2;
for b in 0..1 << (log_n_vecs & 1) {
for a in 0..1 << half {

let iter = parallel_iter!(0..1 << half);

let values = UnsafeMutI32(values);
iter.for_each(|a| {
let values = values.get();
for b in 0..1 << (log_n_vecs & 1) {
for c in 0..1 << half {
let i = (a << (log_n_vecs - half)) | (b << half) | c;
let j = (c << (log_n_vecs - half)) | (b << half) | a;
Expand All @@ -43,7 +72,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
store(values.add(j << 4), val0);
}
}
}
});
}

/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
Expand Down
26 changes: 20 additions & 6 deletions crates/prover/src/core/backend/simd/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ use std::array;
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8};

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::utils::bit_reverse;
use crate::parallel_iter;

/// Performs a Circle Fast Fourier Transform (CFFT) on the given values.
///
Expand Down Expand Up @@ -86,8 +90,13 @@ pub unsafe fn fft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

for index_h in 0..1 << (log_size - fft_layers) {
let mut src = src;
let iter = parallel_iter!(0..1 << (log_size - fft_layers));

let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
iter.for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() {
match fft_layers - layer {
1 => {
Expand Down Expand Up @@ -116,7 +125,7 @@ pub unsafe fn fft_lower_with_vecwise(
fft_layers - VECWISE_FFT_BITS,
index_h,
);
}
});
}

/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
Expand Down Expand Up @@ -147,8 +156,13 @@ pub unsafe fn fft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let mut src = src;
let iter = parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize));

let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
iter.for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
for layer in (0..fft_layers).step_by(3).rev() {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
Expand All @@ -171,7 +185,7 @@ pub unsafe fn fft_lower_without_vecwise(
}
src = dst;
}
}
});
}

/// Runs the last 5 fft layers across the entire array.
Expand Down
13 changes: 13 additions & 0 deletions crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,16 @@ impl<T> DerefMut for ComponentVec<T> {
&mut self.0
}
}

#[macro_export]
macro_rules! parallel_iter {
($i: expr) => {{
#[cfg(not(feature = "parallel"))]
let iter = $i;

#[cfg(feature = "parallel")]
let iter = $i.into_par_iter();

iter
}};
}

0 comments on commit b3b9ee0

Please sign in to comment.