Skip to content

Commit

Permalink
Fall back to CPU in small FRI.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Sep 24, 2024
1 parent e1acffa commit d14ba6b
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 9 deletions.
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/cpu/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ impl FriOps for CpuBackend {
) -> LineEvaluation<Self> {
fold_line(eval, alpha)
}

fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self, BitReversedOrder>,
Expand Down
15 changes: 12 additions & 3 deletions crates/prover/src/core/backend/simd/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fri::{self, FriOps};
use crate::core::fri::{self, fold_circle_into_line, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
Expand Down Expand Up @@ -63,7 +63,16 @@ impl FriOps for SimdBackend {
twiddles: &TwiddleTree<Self>,
) {
let log_size = src.len().ilog2();
assert!(log_size > LOG_N_LANES, "Evaluation too small");
if log_size <= LOG_N_LANES {
// Fall back to CPU implementation.
let mut cpu_dst = dst.to_cpu();
fold_circle_into_line(&mut cpu_dst, &src.to_cpu(), alpha);
*dst = LineEvaluation::new(
cpu_dst.domain(),
SecureColumnByCoords::from_cpu(cpu_dst.values),
);
return;
}

let domain = src.domain;
let alpha_sq = alpha * alpha;
Expand Down Expand Up @@ -249,7 +258,7 @@ mod tests {
};
let avx_eval = SecureEvaluation::new(domain, avx_column.clone());
let cpu_eval =
SecureEvaluation::<CpuBackend, BitReversedOrder>::new(domain, avx_eval.to_cpu());
SecureEvaluation::<CpuBackend, BitReversedOrder>::new(domain, avx_eval.values.to_cpu());
let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval);
let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval);

Expand Down
8 changes: 8 additions & 0 deletions crates/prover/src/core/poly/circle/secure_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ impl<B: FieldOps<BaseField>, EvalOrder> SecureEvaluation<B, EvalOrder> {
let Self { domain, values, .. } = self;
values.columns.map(|c| CircleEvaluation::new(domain, c))
}

pub fn to_cpu(&self) -> SecureEvaluation<CpuBackend, EvalOrder> {
SecureEvaluation {
domain: self.domain,
values: self.values.to_cpu(),
_eval_order: PhantomData,
}
}
}

impl<B: FieldOps<BaseField>, EvalOrder> Deref for SecureEvaluation<B, EvalOrder> {
Expand Down
28 changes: 23 additions & 5 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use itertools::Itertools;

use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::m31::PackedBaseField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -47,8 +47,6 @@ pub fn generate_trace<const N: usize>(
log_size: u32,
inputs: &[FibInput],
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
assert!(log_size >= LOG_N_LANES);
assert_eq!(inputs.len(), 1 << (log_size - LOG_N_LANES));
let mut trace = (0..N)
.map(|_| Col::<SimdBackend, BaseField>::zeros(1 << log_size))
.collect_vec();
Expand All @@ -72,7 +70,7 @@ pub fn generate_trace<const N: usize>(
#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::One;
use num_traits::{One, Zero};

use super::WideFibonacciEval;
use crate::constraint_framework::{
Expand Down Expand Up @@ -101,6 +99,26 @@ mod tests {
fn generate_test_trace(
log_n_instances: u32,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
if log_n_instances < LOG_N_LANES {
let n_instances = 1 << log_n_instances;
let inputs = vec![FibInput {
a: PackedBaseField::from_array(std::array::from_fn(|j| {
if j < n_instances {
BaseField::one()
} else {
BaseField::zero()
}
})),
b: PackedBaseField::from_array(std::array::from_fn(|j| {
if j < n_instances {
BaseField::from_u32_unchecked((j) as u32)
} else {
BaseField::zero()
}
})),
}];
return generate_trace::<FIB_SEQUENCE_LENGTH>(log_n_instances, &inputs);
}
let inputs = (0..(1 << (log_n_instances - LOG_N_LANES)))
.map(|i| FibInput {
a: PackedBaseField::one(),
Expand Down Expand Up @@ -151,7 +169,7 @@ mod tests {

#[test_log::test]
fn test_wide_fib_prove_with_blake() {
for log_n_instances in 4..=6 {
for log_n_instances in 2..=6 {
let config = PcsConfig::default();
// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
Expand Down
1 change: 0 additions & 1 deletion crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ mod tests {
}

#[test]
#[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"]
fn eq_constraints_with_4_variables() {
const N_VARIABLES: usize = 4;
const EQ_EVAL_TRACE: usize = 0;
Expand Down

0 comments on commit d14ba6b

Please sign in to comment.