Skip to content

Commit

Permalink
Simd twiddles
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored and alonh5 committed Sep 24, 2024
1 parent 39763f5 commit 80f3bbb
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 36 deletions.
38 changes: 22 additions & 16 deletions crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,29 +148,16 @@ impl PolyOps for CpuBackend {
CircleEvaluation::new(domain, values)
}

fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
const CHUNK_LOG_SIZE: usize = 12;
const CHUNK_SIZE: usize = 1 << CHUNK_LOG_SIZE;

let root_coset = coset;
let mut twiddles = Vec::with_capacity(coset.size());
for _ in 0..coset.log_size() {
let i0 = twiddles.len();
twiddles.extend(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x)
.collect::<Vec<_>>(),
);
bit_reverse(&mut twiddles[i0..]);
coset = coset.double();
}
twiddles.push(1.into());
let twiddles = slow_precompute_twiddles(coset);

// Inverse twiddles.
// Fallback to the non-chunked version if the domain is not big enough.
if CHUNK_SIZE > coset.size() {
if CHUNK_SIZE > root_coset.size() {
let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect();
return TwiddleTree {
root_coset,
Expand All @@ -195,6 +182,25 @@ impl PolyOps for CpuBackend {
}
}

pub fn slow_precompute_twiddles(mut coset: Coset) -> Vec<BaseField> {
let mut twiddles = Vec::with_capacity(coset.size());
for _ in 0..coset.log_size() {
let i0 = twiddles.len();
twiddles.extend(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x)
.collect::<Vec<_>>(),
);
bit_reverse(&mut twiddles[i0..]);
coset = coset.double();
}
// Pad with an arbitrary value to make the length a power of 2.
twiddles.push(1.into());
twiddles
}

fn fft_layer_loop(
values: &mut [BaseField],
i: usize,
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod accumulation;
mod blake2s;
mod circle;
pub mod circle;
mod fri;
mod grind;
pub mod lookups;
Expand Down
123 changes: 104 additions & 19 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::iter::zip;
use std::mem::transmute;
use std::simd::Simd;

use bytemuck::Zeroable;
use num_traits::One;
Expand All @@ -8,9 +9,11 @@ use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::circle::slow_precompute_twiddles;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::PackedM31;
use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
Expand All @@ -20,6 +23,7 @@ use crate::core::poly::circle::{
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;

impl SimdBackend {
// TODO(Ohad): optimize.
Expand Down Expand Up @@ -287,32 +291,96 @@ impl PolyOps for SimdBackend {
)
}

fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
let mut twiddles = Vec::with_capacity(coset.size());
let mut itwiddles = Vec::with_capacity(coset.size());
/// Precomputes the (doubled) twiddles for a given coset tower.
/// The twiddles are the x values of each coset in bit-reversed order.
/// Note: the coset point are symmetrical over the x-axis so only the first half of the coset is
/// needed.
fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
let root_coset = coset;

// TODO(alont): Optimize.
for layer in &rfft::get_twiddle_dbls(coset) {
twiddles.extend(layer);
if root_coset.size() < N_LANES {
return compute_small_coset_twiddles(root_coset);
}
// Pad by any value, to make the size a power of 2.
twiddles.push(1);
assert_eq!(twiddles.len(), coset.size());
for layer in &ifft::get_itwiddle_dbls(coset) {
itwiddles.extend(layer);

let mut twiddles = Vec::with_capacity(coset.size() / N_LANES);
while coset.log_size() > LOG_N_LANES {
compute_coset_twiddles(coset, &mut twiddles);
coset = coset.double();
}
// Pad by any value, to make the size a power of 2.
itwiddles.push(1);
assert_eq!(itwiddles.len(), coset.size());

// Handle cosets smaller than `N_LANES`.
let remaining_twiddles = slow_precompute_twiddles(coset);

twiddles.push(PackedM31::from_array(
remaining_twiddles.try_into().unwrap(),
));

let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
PackedBaseField::batch_inverse(&twiddles, &mut itwiddles);

let dbl_twiddles = twiddles
.into_iter()
.flat_map(|x| (x.into_simd() * Simd::splat(2)).to_array())
.collect();
let dbl_itwiddles = itwiddles
.into_iter()
.flat_map(|x| (x.into_simd() * Simd::splat(2)).to_array())
.collect();

TwiddleTree {
root_coset: coset,
twiddles,
itwiddles,
root_coset,
twiddles: dbl_twiddles,
itwiddles: dbl_itwiddles,
}
}
}

fn compute_small_coset_twiddles(coset: Coset) -> TwiddleTree<SimdBackend> {
let twiddles = slow_precompute_twiddles(coset);

let dbl_twiddles = twiddles.iter().map(|x| x.0 * 2).collect();
let dbl_itwiddles = twiddles.iter().map(|x| x.inverse().0 * 2).collect();
TwiddleTree {
root_coset: coset,
twiddles: dbl_twiddles,
itwiddles: dbl_itwiddles,
}
}

/// Computes the twiddles of the coset in bit-reversed order. Optimized for SIMD.
fn compute_coset_twiddles(coset: Coset, twiddles: &mut Vec<PackedM31>) {
let log_size = coset.log_size() - 1;
assert!(log_size >= LOG_N_LANES);

// Compute the first `N_LANES` circle points.
let initial_points = std::array::from_fn(|i| coset.at(bit_reverse_index(i, log_size)));
let mut current = CirclePoint {
x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)),
y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)),
};

// Precompute the steps needed to compute the next circle points in bit reversed order.
let mut steps = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
for i in 0..(log_size - LOG_N_LANES) {
let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES);
let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES);
let step = coset.step.mul(new_mul as u128) - coset.step.mul(prev_mul as u128);
steps[i as usize] = step;
}

for i in 0u32..1 << (log_size - LOG_N_LANES) {
// Extract twiddle and compute the next `N_LANES` circle points.
let x = current.x;
let step_index = i.trailing_ones() as usize;
let step = CirclePoint {
x: PackedM31::broadcast(steps[step_index].x),
y: PackedM31::broadcast(steps[step_index].y),
};
current = current + step;
twiddles.push(x);
}
}

fn slow_eval_at_point(
poly: &CirclePoly<SimdBackend>,
point: CirclePoint<SecureField>,
Expand Down Expand Up @@ -340,13 +408,14 @@ fn slow_eval_at_point(

#[cfg(test)]
mod tests {
use itertools::Itertools;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::simd::circle::slow_eval_at_point;
use crate::core::backend::simd::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::backend::{Column, CpuBackend};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps};
Expand Down Expand Up @@ -445,4 +514,20 @@ mod tests {
assert_eq!(eval, slow_eval_at_point(&poly, p), "log_size = {log_size}");
}
}

#[test]
fn test_optimized_precompute_twiddles() {
let coset = CanonicCoset::new(10).half_coset();
let twiddles = SimdBackend::precompute_twiddles(coset);
let expected_twiddles = CpuBackend::precompute_twiddles(coset);

assert_eq!(
twiddles.twiddles,
expected_twiddles
.twiddles
.iter()
.map(|x| x.0 * 2)
.collect_vec()
);
}
}

0 comments on commit 80f3bbb

Please sign in to comment.