diff --git a/crates/prover/src/core/backend/cpu/circle.rs b/crates/prover/src/core/backend/cpu/circle.rs
index c37ffe248..6e94078cd 100644
--- a/crates/prover/src/core/backend/cpu/circle.rs
+++ b/crates/prover/src/core/backend/cpu/circle.rs
@@ -148,29 +148,16 @@ impl PolyOps for CpuBackend {
         CircleEvaluation::new(domain, values)
     }
 
-    fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
+    fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
         const CHUNK_LOG_SIZE: usize = 12;
         const CHUNK_SIZE: usize = 1 << CHUNK_LOG_SIZE;
 
         let root_coset = coset;
-        let mut twiddles = Vec::with_capacity(coset.size());
-        for _ in 0..coset.log_size() {
-            let i0 = twiddles.len();
-            twiddles.extend(
-                coset
-                    .iter()
-                    .take(coset.size() / 2)
-                    .map(|p| p.x)
-                    .collect::<Vec<_>>(),
-            );
-            bit_reverse(&mut twiddles[i0..]);
-            coset = coset.double();
-        }
-        twiddles.push(1.into());
+        let twiddles = slow_precompute_twiddles(coset);
 
         // Inverse twiddles.
         // Fallback to the non-chunked version if the domain is not big enough.
-        if CHUNK_SIZE > coset.size() {
+        if CHUNK_SIZE > root_coset.size() {
             let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect();
             return TwiddleTree {
                 root_coset,
@@ -195,6 +182,25 @@ impl PolyOps for CpuBackend {
     }
 }
 
+pub fn slow_precompute_twiddles(mut coset: Coset) -> Vec<BaseField> {
+    let mut twiddles = Vec::with_capacity(coset.size());
+    for _ in 0..coset.log_size() {
+        let i0 = twiddles.len();
+        twiddles.extend(
+            coset
+                .iter()
+                .take(coset.size() / 2)
+                .map(|p| p.x)
+                .collect::<Vec<_>>(),
+        );
+        bit_reverse(&mut twiddles[i0..]);
+        coset = coset.double();
+    }
+    // Pad with an arbitrary value to make the length a power of 2.
+    twiddles.push(1.into());
+    twiddles
+}
+
 fn fft_layer_loop(
     values: &mut [BaseField],
     i: usize,
diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs
index 4fe51e34c..cfa514e4c 100644
--- a/crates/prover/src/core/backend/cpu/mod.rs
+++ b/crates/prover/src/core/backend/cpu/mod.rs
@@ -1,6 +1,6 @@
 mod accumulation;
 mod blake2s;
-mod circle;
+pub mod circle;
 mod fri;
 mod grind;
 pub mod lookups;
diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs
index 9118f1769..1adb5e18e 100644
--- a/crates/prover/src/core/backend/simd/circle.rs
+++ b/crates/prover/src/core/backend/simd/circle.rs
@@ -1,5 +1,6 @@
 use std::iter::zip;
 use std::mem::transmute;
+use std::simd::Simd;
 
 use bytemuck::Zeroable;
 use num_traits::One;
@@ -8,9 +9,11 @@ use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
 use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
 use super::qm31::PackedSecureField;
 use super::SimdBackend;
+use crate::core::backend::cpu::circle::slow_precompute_twiddles;
 use crate::core::backend::simd::column::BaseColumn;
+use crate::core::backend::simd::m31::PackedM31;
 use crate::core::backend::{Col, Column, CpuBackend};
-use crate::core::circle::{CirclePoint, Coset};
+use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
 use crate::core::fields::m31::BaseField;
 use crate::core::fields::qm31::SecureField;
 use crate::core::fields::{Field, FieldExpOps};
@@ -20,6 +23,7 @@ use crate::core::poly::circle::{
 use crate::core::poly::twiddles::TwiddleTree;
 use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
 use crate::core::poly::BitReversedOrder;
+use crate::core::utils::bit_reverse_index;
 
 impl SimdBackend {
     // TODO(Ohad): optimize.
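Annotation (not part of the patch): the CPU change above extracts the naive twiddle computation into a `pub fn slow_precompute_twiddles`, so the SIMD backend can reuse it for the small tail cosets. The buffer it returns concatenates one layer per coset doubling, each layer holding the bit-reversed x-coordinates of half the coset, plus one padding element so the total length is exactly `coset.size()`. A minimal sketch of that layout invariant; the helper `twiddle_layer_lengths` is hypothetical, not from the crate:

// Sketch: layer sizes in the buffer returned by `slow_precompute_twiddles`.
// A coset of size 2^n contributes layers of 2^(n-1), 2^(n-2), ..., 1
// twiddles, and one padding element brings the total to exactly 2^n.
fn twiddle_layer_lengths(log_size: u32) -> Vec<usize> {
    let mut lens: Vec<usize> = (0..log_size)
        .map(|k| 1usize << (log_size - 1 - k))
        .collect();
    lens.push(1); // arbitrary padding value (the patch pushes 1)
    lens
}

fn main() {
    let lens = twiddle_layer_lengths(3);
    assert_eq!(lens, vec![4, 2, 1, 1]);
    assert_eq!(lens.iter().sum::<usize>(), 1 << 3);
}

The power-of-two total is what lets the SIMD backend repack the same buffer into whole `PackedM31` lanes below.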
@@ -287,32 +291,96 @@ impl PolyOps for SimdBackend {
         )
     }
 
-    fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
-        let mut twiddles = Vec::with_capacity(coset.size());
-        let mut itwiddles = Vec::with_capacity(coset.size());
+    /// Precomputes the (doubled) twiddles for a given coset tower.
+    /// The twiddles are the x values of each coset in bit-reversed order.
+    /// Note: the coset points are symmetric about the x-axis, so only the first half of the
+    /// coset is needed.
+    fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
+        let root_coset = coset;
 
-        // TODO(alont): Optimize.
-        for layer in &rfft::get_twiddle_dbls(coset) {
-            twiddles.extend(layer);
+        if root_coset.size() < N_LANES {
+            return compute_small_coset_twiddles(root_coset);
         }
-        // Pad by any value, to make the size a power of 2.
-        twiddles.push(1);
-        assert_eq!(twiddles.len(), coset.size());
-        for layer in &ifft::get_itwiddle_dbls(coset) {
-            itwiddles.extend(layer);
+
+        let mut twiddles = Vec::with_capacity(coset.size() / N_LANES);
+        while coset.log_size() > LOG_N_LANES {
+            compute_coset_twiddles(coset, &mut twiddles);
+            coset = coset.double();
         }
-        // Pad by any value, to make the size a power of 2.
-        itwiddles.push(1);
-        assert_eq!(itwiddles.len(), coset.size());
+
+        // Handle cosets smaller than `N_LANES`.
+        let remaining_twiddles = slow_precompute_twiddles(coset);
+
+        twiddles.push(PackedM31::from_array(
+            remaining_twiddles.try_into().unwrap(),
+        ));
+
+        let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
+        PackedBaseField::batch_inverse(&twiddles, &mut itwiddles);
+
+        let dbl_twiddles = twiddles
+            .into_iter()
+            .flat_map(|x| (x.into_simd() * Simd::splat(2)).to_array())
+            .collect();
+        let dbl_itwiddles = itwiddles
+            .into_iter()
+            .flat_map(|x| (x.into_simd() * Simd::splat(2)).to_array())
+            .collect();
 
         TwiddleTree {
-            root_coset: coset,
-            twiddles,
-            itwiddles,
+            root_coset,
+            twiddles: dbl_twiddles,
+            itwiddles: dbl_itwiddles,
         }
     }
 }
 
+fn compute_small_coset_twiddles(coset: Coset) -> TwiddleTree<SimdBackend> {
+    let twiddles = slow_precompute_twiddles(coset);
+
+    let dbl_twiddles = twiddles.iter().map(|x| x.0 * 2).collect();
+    let dbl_itwiddles = twiddles.iter().map(|x| x.inverse().0 * 2).collect();
+    TwiddleTree {
+        root_coset: coset,
+        twiddles: dbl_twiddles,
+        itwiddles: dbl_itwiddles,
+    }
+}
+
+/// Computes the twiddles of the coset in bit-reversed order. Optimized for SIMD.
+fn compute_coset_twiddles(coset: Coset, twiddles: &mut Vec<PackedM31>) {
+    let log_size = coset.log_size() - 1;
+    assert!(log_size >= LOG_N_LANES);
+
+    // Compute the first `N_LANES` circle points.
+    let initial_points = std::array::from_fn(|i| coset.at(bit_reverse_index(i, log_size)));
+    let mut current = CirclePoint {
+        x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)),
+        y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)),
+    };
+
+    // Precompute the steps needed to compute the next circle points in bit-reversed order.
+    let mut steps = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
+    for i in 0..(log_size - LOG_N_LANES) {
+        let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES);
+        let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES);
+        let step = coset.step.mul(new_mul as u128) - coset.step.mul(prev_mul as u128);
+        steps[i as usize] = step;
+    }
+
+    for i in 0u32..1 << (log_size - LOG_N_LANES) {
+        // Extract twiddle and compute the next `N_LANES` circle points.
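Annotation (not part of the patch): `compute_coset_twiddles` avoids a full point-by-point walk by exploiting a property of bit-reversed enumeration: when scanning i = 0, 1, 2, ..., the difference `bit_rev(i+1) - bit_rev(i)` depends only on `i.trailing_ones()`, because incrementing i flips exactly its trailing ones plus the next bit. That is why a small table of precomputed point differences (`steps`) suffices to advance `current` through the whole coset. A standalone sketch of the index identity over plain integers; all names here are illustrative, not from the crate:

// Sketch: adding the delta selected by `i.trailing_ones()` reproduces
// bit_reverse_index(i + 1) from bit_reverse_index(i). The patch applies the
// same walk to packed circle points, N_LANES twiddles per iteration.
fn bit_reverse_index(i: usize, log_size: u32) -> usize {
    i.reverse_bits() >> (usize::BITS - log_size)
}

fn main() {
    let log_size = 4;
    // deltas[t] = bit_rev(2^t) - bit_rev(2^t - 1), as signed integers.
    let deltas: Vec<i64> = (0..log_size)
        .map(|t| {
            bit_reverse_index(1 << t, log_size) as i64
                - bit_reverse_index((1 << t) - 1, log_size) as i64
        })
        .collect();
    let mut current = bit_reverse_index(0, log_size) as i64;
    for i in 0u32..(1 << log_size) - 1 {
        current += deltas[i.trailing_ones() as usize];
        assert_eq!(current as usize, bit_reverse_index(i as usize + 1, log_size));
    }
}

Note also that `steps` is sized by `M31_CIRCLE_LOG_ORDER` and zero-initialized, so the final iteration's out-of-range `trailing_ones` index appears to land on `CirclePoint::zero()`, which leaves `current` unchanged after the last twiddle is pushed.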
+        let x = current.x;
+        let step_index = i.trailing_ones() as usize;
+        let step = CirclePoint {
+            x: PackedM31::broadcast(steps[step_index].x),
+            y: PackedM31::broadcast(steps[step_index].y),
+        };
+        current = current + step;
+        twiddles.push(x);
+    }
+}
+
 fn slow_eval_at_point(
     poly: &CirclePoly<SimdBackend>,
     point: CirclePoint<SecureField>,
@@ -340,13 +408,14 @@ fn slow_eval_at_point(
 
 #[cfg(test)]
 mod tests {
+    use itertools::Itertools;
     use rand::rngs::SmallRng;
     use rand::{Rng, SeedableRng};
 
     use crate::core::backend::simd::circle::slow_eval_at_point;
     use crate::core::backend::simd::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
     use crate::core::backend::simd::SimdBackend;
-    use crate::core::backend::Column;
+    use crate::core::backend::{Column, CpuBackend};
     use crate::core::circle::CirclePoint;
     use crate::core::fields::m31::BaseField;
     use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps};
@@ -445,4 +514,20 @@ mod tests {
             assert_eq!(eval, slow_eval_at_point(&poly, p), "log_size = {log_size}");
         }
     }
+
+    #[test]
+    fn test_optimized_precompute_twiddles() {
+        let coset = CanonicCoset::new(10).half_coset();
+        let twiddles = SimdBackend::precompute_twiddles(coset);
+        let expected_twiddles = CpuBackend::precompute_twiddles(coset);
+
+        assert_eq!(
+            twiddles.twiddles,
+            expected_twiddles
+                .twiddles
+                .iter()
+                .map(|x| x.0 * 2)
+                .collect_vec()
+        );
+    }
 }
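Annotation (not part of the patch): the new test checks the forward twiddles only, comparing the SIMD backend's doubled twiddles against the CPU backend's values times two. Since both backends invert the same underlying values (the SIMD path via `batch_inverse`, the CPU path via per-element `inverse`), a companion check for the inverse twiddles could follow the same pattern. A hypothetical sketch, reusing the imports of the test module above:

#[test]
fn test_optimized_precompute_itwiddles() {
    let coset = CanonicCoset::new(10).half_coset();
    let twiddles = SimdBackend::precompute_twiddles(coset);
    let expected_twiddles = CpuBackend::precompute_twiddles(coset);

    // The SIMD itwiddles are stored doubled, like the forward twiddles.
    assert_eq!(
        twiddles.itwiddles,
        expected_twiddles
            .itwiddles
            .iter()
            .map(|x| x.0 * 2)
            .collect_vec()
    );
}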