Skip to content

Commit

Permalink
Fall back to CPU in small FRI.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Sep 18, 2024
1 parent 1a9755a commit 035ae28
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/cpu/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ impl FriOps for CpuBackend {
) -> LineEvaluation<Self> {
fold_line(eval, alpha)
}

fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self, BitReversedOrder>,
Expand Down
12 changes: 9 additions & 3 deletions crates/prover/src/core/backend/simd/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fri::{self, FriOps};
use crate::core::fri::{self, fold_circle_into_line, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
Expand Down Expand Up @@ -63,7 +63,13 @@ impl FriOps for SimdBackend {
twiddles: &TwiddleTree<Self>,
) {
let log_size = src.len().ilog2();
assert!(log_size > LOG_N_LANES, "Evaluation too small");
if log_size <= LOG_N_LANES {
// Fall back to CPU implementation.
let mut cpu_dst = dst.to_cpu();
fold_circle_into_line(&mut cpu_dst, &src.to_cpu(), alpha);
*dst = LineEvaluation::new(cpu_dst.domain(), cpu_dst.values.to_simd());
return;
}

let domain = src.domain;
let alpha_sq = alpha * alpha;
Expand Down Expand Up @@ -249,7 +255,7 @@ mod tests {
};
let avx_eval = SecureEvaluation::new(domain, avx_column.clone());
let cpu_eval =
SecureEvaluation::<CpuBackend, BitReversedOrder>::new(domain, avx_eval.to_cpu());
SecureEvaluation::<CpuBackend, BitReversedOrder>::new(domain, avx_eval.values.to_cpu());
let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval);
let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval);

Expand Down
8 changes: 8 additions & 0 deletions crates/prover/src/core/poly/circle/secure_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ impl<B: FieldOps<BaseField>, EvalOrder> SecureEvaluation<B, EvalOrder> {
let Self { domain, values, .. } = self;
values.columns.map(|c| CircleEvaluation::new(domain, c))
}

/// Returns a copy of this evaluation backed by [`CpuBackend`].
///
/// The domain is carried over unchanged, the values are converted with their own
/// `to_cpu`, and the `EvalOrder` type parameter is preserved.
pub fn to_cpu(&self) -> SecureEvaluation<CpuBackend, EvalOrder> {
    let domain = self.domain;
    let values = self.values.to_cpu();
    SecureEvaluation {
        domain,
        values,
        // Zero-sized marker that only carries the `EvalOrder` type parameter.
        _eval_order: PhantomData,
    }
}
}

impl<B: FieldOps<BaseField>, EvalOrder> Deref for SecureEvaluation<B, EvalOrder> {
Expand Down
31 changes: 25 additions & 6 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use itertools::Itertools;

use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::m31::PackedBaseField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -47,8 +47,6 @@ pub fn generate_trace<const N: usize>(
log_size: u32,
inputs: &[FibInput],
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
assert!(log_size >= LOG_N_LANES);
assert_eq!(inputs.len(), 1 << (log_size - LOG_N_LANES));
let mut trace = (0..N)
.map(|_| Col::<SimdBackend, BaseField>::zeros(1 << log_size))
.collect_vec();
Expand All @@ -72,7 +70,7 @@ pub fn generate_trace<const N: usize>(
#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::One;
use num_traits::{One, Zero};
use test_case::test_case;

use super::WideFibonacciEval;
Expand Down Expand Up @@ -102,6 +100,26 @@ mod tests {
fn generate_test_trace(
log_n_instances: u32,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
if log_n_instances < LOG_N_LANES {
let n_instances = 1 << log_n_instances;
let inputs = vec![FibInput {
a: PackedBaseField::from_array(std::array::from_fn(|j| {
if j < n_instances {
BaseField::one()
} else {
BaseField::zero()
}
})),
b: PackedBaseField::from_array(std::array::from_fn(|j| {
if j < n_instances {
BaseField::from_u32_unchecked((j) as u32)
} else {
BaseField::zero()
}
})),
}];
return generate_trace::<FIB_SEQUENCE_LENGTH>(log_n_instances, &inputs);
}
let inputs = (0..(1 << (log_n_instances - LOG_N_LANES)))
.map(|i| FibInput {
a: PackedBaseField::one(),
Expand Down Expand Up @@ -151,7 +169,7 @@ mod tests {
}

#[test_case(6; "SIMD")]
#[test_case(4; "CPU fall back")]
#[test_case(3; "CPU fall back 3")]
#[test_log::test]
fn test_wide_fib_prove_with_blake(log_n_instances: u32) {
let config = PcsConfig::default();
Expand Down Expand Up @@ -201,7 +219,8 @@ mod tests {
}

#[test_case(6; "SIMD")]
#[test_case(4; "CPU fall back")]
#[test_case(3; "CPU fall back 3")]
#[test_case(2; "CPU fall back 2")]
#[cfg(not(target_arch = "wasm32"))]
fn test_wide_fib_prove_with_poseidon(log_n_instances: u32) {
let config = PcsConfig::default();
Expand Down

0 comments on commit 035ae28

Please sign in to comment.