Skip to content

Commit

Permalink
Fall back to CPU in small constraint eval. (#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Sep 24, 2024
1 parent 9085df4 commit e1acffa
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 48 deletions.
2 changes: 1 addition & 1 deletion crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ wasm-bindgen-test = "0.3.43"
features = ["html_reports"]
version = "0.5.1"

# Default features cause compile error:
# Default features cause compile error:
# "Rayon cannot be used when targeting wasi32. Try disabling default features."
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.criterion]
default-features = false
Expand Down
32 changes: 31 additions & 1 deletion crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::ops::Deref;
use itertools::Itertools;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
Expand All @@ -16,6 +17,7 @@ use crate::core::circle::CirclePoint;
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::{TreeSubspan, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
Expand Down Expand Up @@ -173,7 +175,35 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]);
accum.random_coeff_powers.reverse();

let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
let _span = span!(Level::INFO, "Constraint point-wise eval").entered();

if trace_domain.log_size() < LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS {
// Fall back to CPU if the trace is too small.
let mut col = accum.col.to_cpu();

for row in 0..(1 << eval_domain.log_size()) {
let trace_cols = trace.as_cols_ref().map_cols(|c| c.to_cpu());
let trace_cols = trace_cols.as_cols_ref();

// Evaluate constrains at row.
let eval = CpuDomainEvaluator::new(
&trace_cols,
row,
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
);
let row_res = self.eval.evaluate(eval).row_res;

// Finalize row.
let denom_inv = denom_inv[row >> trace_domain.log_size()];
col.set(row, col.at(row) + row_res * denom_inv)
}
let col = SecureColumnByCoords::from_cpu(col);
*accum.col = col;
return;
}

let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };

for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) {
Expand Down
91 changes: 91 additions & 0 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::ops::Mul;

use num_traits::Zero;

use super::EvalAtRow;
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::offset_bit_reversed_circle_domain_index;

/// Evaluates constraints at an evaluation domain points.
pub struct CpuDomainEvaluator<'a> {
pub trace_eval: &'a TreeVec<Vec<&'a CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
pub column_index_per_interaction: Vec<usize>,
pub row: usize,
pub random_coeff_powers: &'a [SecureField],
pub row_res: SecureField,
pub constraint_index: usize,
pub domain_log_size: u32,
pub eval_domain_log_size: u32,
}

impl<'a> CpuDomainEvaluator<'a> {
#[allow(dead_code)]
pub fn new(
trace_eval: &'a TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
row: usize,
random_coeff_powers: &'a [SecureField],
domain_log_size: u32,
eval_log_size: u32,
) -> Self {
Self {
trace_eval,
column_index_per_interaction: vec![0; trace_eval.len()],
row,
random_coeff_powers,
row_res: SecureField::zero(),
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
}
}
}

impl<'a> EvalAtRow for CpuDomainEvaluator<'a> {
type F = BaseField;
type EF = SecureField;

// TODO(spapini): Remove all boundary checks.
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
let col_index = self.column_index_per_interaction[interaction];
self.column_index_per_interaction[interaction] += 1;
offsets.map(|off| {
// If the offset is 0, we can just return the value directly from this row.
if off == 0 {
let col = &self.trace_eval[interaction][col_index];
return col[self.row];
}
// Otherwise, we need to look up the value at the offset.
// Since the domain is bit-reversed circle domain ordered, we need to look up the value
// at the bit-reversed natural order index at an offset.
let row = offset_bit_reversed_circle_domain_index(
self.row,
self.domain_log_size,
self.eval_domain_log_size,
off,
);
self.trace_eval[interaction][col_index][row]
})
}

fn add_constraint<G>(&mut self, constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>,
{
self.row_res += self.random_coeff_powers[self.constraint_index] * constraint;
self.constraint_index += 1;
}

fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}
}
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod assert;
mod component;
pub mod constant_columns;
mod cpu_domain;
mod info;
pub mod logup;
mod point;
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::iter::zip;
use std::mem::transmute;

use bytemuck::{cast_slice, Zeroable};
use bytemuck::Zeroable;
use num_traits::One;

use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
Expand Down Expand Up @@ -334,7 +334,7 @@ fn slow_eval_at_point(
// Swap content of a,c.
a.swap_with_slice(&mut c[0..n0]);
}
fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings)
fold(poly.coeffs.as_slice(), &mappings)
}

#[cfg(test)]
Expand Down
10 changes: 10 additions & 0 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ impl BaseColumn {
res
}

pub fn from_cpu(values: Vec<BaseField>) -> Self {
values.into_iter().collect()
}

/// Returns a vector of `BaseColumnMutSlice`s, each mutably owning
/// `chunk_size` `PackedBaseField`s (i.e, `chuck_size` * `N_LANES` elements).
pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec<BaseColumnMutSlice<'_>> {
Expand Down Expand Up @@ -400,6 +404,12 @@ impl SecureColumnByCoords<SimdBackend> {
.map(|(a, b, c, d)| SecureColumnByCoordsMutSlice([a, b, c, d]))
.collect_vec()
}

pub fn from_cpu(cpu: SecureColumnByCoords<CpuBackend>) -> Self {
Self {
columns: cpu.columns.map(BaseColumn::from_cpu),
}
}
}

impl FromIterator<SecureField> for SecureColumnByCoords<SimdBackend> {
Expand Down
89 changes: 45 additions & 44 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,59 +150,60 @@ mod tests {
}

#[test_log::test]
fn test_wide_fib_prove() {
const LOG_N_INSTANCES: u32 = 6;
let config = PcsConfig::default();
// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor)
.circle_domain()
.half_coset,
);

// Setup protocol.
let prover_channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
config, &twiddles,
fn test_wide_fib_prove_with_blake() {
for log_n_instances in 4..=6 {
let config = PcsConfig::default();
// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor)
.circle_domain()
.half_coset,
);

// Trace.
let trace = generate_test_trace(LOG_N_INSTANCES);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace);
tree_builder.commit(prover_channel);

// Prove constraints.
let component = WideFibonacciComponent::new(
&mut TraceLocationAllocator::default(),
WideFibonacciEval::<FIB_SEQUENCE_LENGTH> {
log_n_rows: LOG_N_INSTANCES,
},
);

let proof = prove::<SimdBackend, Blake2sMerkleChannel>(
&[&component],
prover_channel,
commitment_scheme,
)
.unwrap();

// Verify.
let verifier_channel = &mut Blake2sChannel::default();
let commitment_scheme = &mut CommitmentSchemeVerifier::<Blake2sMerkleChannel>::new(config);
// Setup protocol.
let prover_channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<SimdBackend, Blake2sMerkleChannel>::new(
config, &twiddles,
);

// Trace.
let trace = generate_test_trace(log_n_instances);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace);
tree_builder.commit(prover_channel);

// Prove constraints.
let component = WideFibonacciComponent::new(
&mut TraceLocationAllocator::default(),
WideFibonacciEval::<FIB_SEQUENCE_LENGTH> {
log_n_rows: log_n_instances,
},
);

// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();
commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel);
verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap();
let proof = prove::<SimdBackend, Blake2sMerkleChannel>(
&[&component],
prover_channel,
commitment_scheme,
)
.unwrap();

// Verify.
let verifier_channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeVerifier::<Blake2sMerkleChannel>::new(config);

// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();
commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel);
verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap();
}
}

#[test]
#[cfg(not(target_arch = "wasm32"))]
fn test_wide_fib_prove_with_poseidon() {
const LOG_N_INSTANCES: u32 = 6;

let config = PcsConfig::default();
// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
Expand Down

0 comments on commit e1acffa

Please sign in to comment.