Skip to content

Commit

Permalink
Simd twiddles
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored and alonh5 committed Sep 22, 2024
1 parent 2b715dd commit e666dfa
Showing 1 changed file with 84 additions and 18 deletions.
102 changes: 84 additions & 18 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use std::iter::zip;
use std::mem::transmute;

use bytemuck::{cast_slice, Zeroable};
use itertools::Itertools;
use num_traits::One;

use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::{Col, CpuBackend};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::backend::simd::m31::PackedM31;
use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::poly::circle::{
Expand All @@ -20,6 +22,7 @@ use crate::core::poly::circle::{
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, bit_reverse_index};

impl SimdBackend {
// TODO(Ohad): optimize.
Expand Down Expand Up @@ -275,32 +278,95 @@ impl PolyOps for SimdBackend {
)
}

// NOTE(review): this span is a diff hunk shown without +/- markers. Two signatures of
// `precompute_twiddles` appear below, so lines from the previous implementation and from the
// new one look interleaved. Comments describe the lines as shown; confirm against the
// checked-out file before relying on them.
fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
let mut twiddles = Vec::with_capacity(coset.size());
let mut itwiddles = Vec::with_capacity(coset.size());
/// Builds the `TwiddleTree` for `coset`: a `twiddles` list and an `itwiddles` list whose
/// entries are the raw field values multiplied by 2, together with the root coset.
#[allow(clippy::int_plus_one)]
fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
// `coset` is repeatedly replaced by `coset.double()` below, so keep the caller's coset for
// the returned tree.
let root_coset = coset;

// TODO(spapini): Optimize.
for layer in &rfft::get_twiddle_dbls(coset) {
twiddles.extend(layer);
// Generate xs for descending cosets, each bit reversed.
let mut xs = Vec::with_capacity(coset.size() / N_LANES);
// Packed path: runs while `coset.log_size() - 1 >= LOG_N_LANES`, which is also the
// precondition asserted inside `gen_coset_xs`.
while coset.log_size() - 1 >= LOG_N_LANES {
gen_coset_xs(coset, &mut xs);
coset = coset.double();
}

// Scalar path for the remaining cosets: take the x coordinate of the first
// `coset.size() / 2` points of each one and bit-reverse only the slice just appended.
let mut extra = Vec::with_capacity(N_LANES);
while coset.log_size() > 0 {
let start = extra.len();
extra.extend(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x)
.collect_vec(),
);
bit_reverse(&mut extra[start..]);
coset = coset.double();
}
// Pad by any value, to make the size a power of 2.
twiddles.push(1);
assert_eq!(twiddles.len(), coset.size());
for layer in &ifft::get_itwiddle_dbls(coset) {
itwiddles.extend(layer);
extra.push(M31::one());

// Fewer than `N_LANES` values in total: nothing was packed, so compute both lists with
// scalar operations and return early.
if extra.len() < N_LANES {
let twiddles = extra.iter().map(|x| x.0 * 2).collect();
let itwiddles = extra.iter().map(|x| x.inverse().0 * 2).collect();
return TwiddleTree {
root_coset,
twiddles,
itwiddles,
};
}
// Pad by any value, to make the size a power of 2.
itwiddles.push(1);
assert_eq!(itwiddles.len(), coset.size());

// The scalar tail becomes one final packed vector; `try_into().unwrap()` panics unless
// `extra` converts to the array type `from_array` expects.
xs.push(PackedM31::from_array(extra.try_into().unwrap()));

// NOTE(review): `ixs` starts uninitialized; this presumably relies on `batch_inverse`
// writing every element, and on `ixs` having the same packed length as `xs` — confirm.
let mut ixs = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
PackedBaseField::batch_inverse(&xs, &mut ixs);

// Unpack the vectors back into scalars, storing each raw value multiplied by 2.
let twiddles = xs
.into_iter()
.flat_map(|x| x.to_array().map(|x| x.0 * 2))
.collect();
let itwiddles = ixs
.into_iter()
.flat_map(|x| x.to_array().map(|x| x.0 * 2))
.collect();

TwiddleTree {
root_coset: coset,
root_coset,
twiddles,
itwiddles,
}
}
}

/// Appends packed x coordinates of points of `coset` to `res`.
///
/// With `half_log_size = coset.log_size() - 1`, the function pushes
/// `1 << (half_log_size - LOG_N_LANES)` packed vectors. The first vector holds, in lane
/// `lane`, the x coordinate of `coset.at(bit_reverse_index(lane, half_log_size))`. Every
/// following vector is obtained by adding one precomputed point (broadcast to all lanes) to
/// the previous packed point.
///
/// # Panics
///
/// Panics if `coset.log_size() - 1 < LOG_N_LANES`.
#[allow(clippy::int_plus_one)]
fn gen_coset_xs(coset: Coset, res: &mut Vec<PackedM31>) {
    let half_log_size = coset.log_size() - 1;
    assert!(half_log_size >= LOG_N_LANES);
    // Number of index bits that select the packed vector (the rest select the lane).
    let log_n_vecs = half_log_size - LOG_N_LANES;

    // One starting point per lane.
    let lane_points = std::array::from_fn(|lane| coset.at(bit_reverse_index(lane, half_log_size)));
    let mut cursor = CirclePoint {
        x: PackedM31::from_array(lane_points.each_ref().map(|point| point.x)),
        y: PackedM31::from_array(lane_points.each_ref().map(|point| point.y)),
    };

    // `steps[bit]` is the point difference between the bit-reversed positions of `1 << bit`
    // and `(1 << bit) - 1`. Entries at `log_n_vecs` and above keep the `zero()` value.
    let mut steps = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
    for bit in 0..log_n_vecs {
        let from = bit_reverse_index((1 << bit) - 1, log_n_vecs);
        let to = bit_reverse_index(1 << bit, log_n_vecs);
        steps[bit as usize] = coset.step.mul(to as u128) - coset.step.mul(from as u128);
    }

    for vec_index in 0u32..1 << log_n_vecs {
        // Emit the current xs, then advance by the step selected by the number of trailing
        // one bits of the vector index.
        res.push(cursor.x);
        let step = steps[vec_index.trailing_ones() as usize];
        cursor = cursor
            + CirclePoint {
                x: PackedM31::broadcast(step.x),
                y: PackedM31::broadcast(step.y),
            };
    }
}

fn slow_eval_at_point(
poly: &CirclePoly<SimdBackend>,
point: CirclePoint<SecureField>,
Expand Down

0 comments on commit e666dfa

Please sign in to comment.