From b3b9ee02927d0a5af504aae6f1d50af29ce467db Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Thu, 5 Sep 2024 07:49:56 +0300 Subject: [PATCH] Parallel fft --- .../prover/src/core/backend/simd/blake2s.rs | 7 ++-- .../prover/src/core/backend/simd/fft/ifft.rs | 21 ++++++++--- .../prover/src/core/backend/simd/fft/mod.rs | 35 +++++++++++++++++-- .../prover/src/core/backend/simd/fft/rfft.rs | 26 ++++++++++---- crates/prover/src/core/mod.rs | 13 +++++++ 5 files changed, 84 insertions(+), 18 deletions(-) diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index fbcfe89e2..06def9ebe 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -18,6 +18,7 @@ use crate::core::fields::m31::BaseField; use crate::core::vcs::blake2_hash::Blake2sHash; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; +use crate::parallel_iter; const IV: [u32; 8] = [ 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, @@ -51,11 +52,7 @@ impl MerkleOps for SimdBackend { columns: &[&Col], ) -> Vec { if log_size < LOG_N_LANES { - #[cfg(not(feature = "parallel"))] - let iter = 0..1 << log_size; - - #[cfg(feature = "parallel")] - let iter = (0..1 << log_size).into_par_iter(); + let iter = parallel_iter!(0..1 << log_size); return iter .map(|i| { diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index eb34da490..526f2f481 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -3,14 +3,18 @@ use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::simd::fft::UnsafeMutI32; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; use crate::core::utils::bit_reverse; +use crate::parallel_iter; /// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. /// @@ -29,6 +33,7 @@ use crate::core::utils::bit_reverse; /// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize); + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; if log_n_elements <= CACHED_FFT_LOG_SIZE as usize { ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); @@ -81,7 +86,11 @@ pub unsafe fn ifft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..1 << (log_size - fft_layers) { + let iter = parallel_iter!(0..1 << (log_size - fft_layers)); + + let values = UnsafeMutI32(values); + iter.for_each(|index_h| { + let values = values.get(); ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { match fft_layers - layer { @@ -102,7 +111,7 @@ pub unsafe fn ifft_lower_with_vecwise( } } } - } + }); } /// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of @@ -131,7 +140,11 @@ pub unsafe fn ifft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { + let iter = parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)); + + let values = UnsafeMutI32(values); + iter.for_each(|index_h| { + let values = values.get(); for layer in (0..fft_layers).step_by(3) { let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { @@ -152,7 +165,7 @@ pub unsafe fn ifft_lower_without_vecwise( } } } - } + }); } /// Runs the first 5 ifft layers across the entire array. diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ca44979e8..e06517496 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -1,7 +1,11 @@ use std::simd::{simd_swizzle, u32x16, u32x8}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + use super::m31::PackedBaseField; use crate::core::fields::m31::P; +use crate::parallel_iter; pub mod ifft; pub mod rfft; @@ -10,6 +14,26 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16; pub const MIN_FFT_LOG_SIZE: u32 = 5; +pub struct UnsafeMutI32(pub *mut u32); +impl UnsafeMutI32 { + pub fn get(&self) -> *mut u32 { + self.0 + } +} + +unsafe impl Send for UnsafeMutI32 {} +unsafe impl Sync for UnsafeMutI32 {} + +pub struct UnsafeConstI32(pub *const u32); +impl UnsafeConstI32 { + pub fn get(&self) -> *const u32 { + self.0 + } +} + +unsafe impl Send for UnsafeConstI32 {} +unsafe impl Sync for UnsafeConstI32 {} + // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce // it somewhere. @@ -29,8 +53,13 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5; /// Behavior is undefined if `values` does not have the same alignment as [`u32x16`]. pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { let half = log_n_vecs / 2; - for b in 0..1 << (log_n_vecs & 1) { - for a in 0..1 << half { + + let iter = parallel_iter!(0..1 << half); + + let values = UnsafeMutI32(values); + iter.for_each(|a| { + let values = values.get(); + for b in 0..1 << (log_n_vecs & 1) { for c in 0..1 << half { let i = (a << (log_n_vecs - half)) | (b << half) | c; let j = (c << (log_n_vecs - half)) | (b << half) | a; @@ -43,7 +72,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { store(values.add(j << 4), val0); } } - } + }); } /// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers. diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index 6d51fd09d..7aea28ac6 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -4,13 +4,17 @@ use std::array; use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8}; use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32}; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::utils::bit_reverse; +use crate::parallel_iter; /// Performs a Circle Fast Fourier Transform (CFFT) on the given values. /// @@ -86,8 +90,13 @@ pub unsafe fn fft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..1 << (log_size - fft_layers) { - let mut src = src; + let iter = parallel_iter!(0..1 << (log_size - fft_layers)); + + let src = UnsafeConstI32(src); + let dst = UnsafeMutI32(dst); + iter.for_each(|index_h| { + let mut src = src.get(); + let dst = dst.get(); for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { match fft_layers - layer { 1 => { @@ -116,7 +125,7 @@ pub unsafe fn fft_lower_with_vecwise( fft_layers - VECWISE_FFT_BITS, index_h, ); - } + }); } /// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of @@ -147,8 +156,13 @@ pub unsafe fn fft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { - let mut src = src; + let iter = parallel_iter!(0..1 << (log_size - fft_layers - LOG_N_LANES as usize)); + + let src = UnsafeConstI32(src); + let dst = UnsafeMutI32(dst); + iter.for_each(|index_h| { + let mut src = src.get(); + let dst = dst.get(); for layer in (0..fft_layers).step_by(3).rev() { let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { @@ -171,7 +185,7 @@ pub unsafe fn fft_lower_without_vecwise( } src = dst; } - } + }); } /// Runs the last 5 fft layers across the entire array. diff --git a/crates/prover/src/core/mod.rs b/crates/prover/src/core/mod.rs index a00aad687..9ba0ea2f9 100644 --- a/crates/prover/src/core/mod.rs +++ b/crates/prover/src/core/mod.rs @@ -57,3 +57,16 @@ impl DerefMut for ComponentVec { &mut self.0 } } + +#[macro_export] +macro_rules! parallel_iter { + ($i: expr) => {{ + #[cfg(not(feature = "parallel"))] + let iter = $i; + + #[cfg(feature = "parallel")] + let iter = $i.into_par_iter(); + + iter + }}; +}