Skip to content

Commit

Permalink
Parallelize evaluate_constraint_quotients_on_domain.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Sep 18, 2024
1 parent 2c88349 commit 5c158eb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
59 changes: 40 additions & 19 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::iter::zip;
use std::ops::Deref;

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use tracing::{span, Level};

use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
Expand All @@ -22,6 +24,8 @@ use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::{utils, ColumnVec};

const CHUNK_SIZE: usize = 1;

// TODO(andrew): Docs.
// TODO(andrew): Consider better location for this.
#[derive(Debug, Default)]
Expand Down Expand Up @@ -130,7 +134,7 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
}
}

impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponent<E> {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &Trace<'_, SimdBackend>,
Expand Down Expand Up @@ -176,28 +180,45 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };

for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) {
let range = 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS));

#[cfg(not(feature = "parallel"))]
let iter = range.step_by(CHUNK_SIZE).zip(col.chunks_mut(CHUNK_SIZE));

#[cfg(feature = "parallel")]
let iter = range
.into_par_iter()
.step_by(CHUNK_SIZE)
.zip(col.chunks_mut(CHUNK_SIZE));

iter.for_each(|(chunk_idx, mut chunk)| {
let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref());

// Evaluate constrains at row.
let eval = SimdDomainEvaluator::new(
&trace_cols,
vec_row,
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
);
let row_res = self.eval.evaluate(eval).row_res;

// Finalize row.
unsafe {
let denom_inv = VeryPackedBaseField::broadcast(
denom_inv[vec_row
>> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)],
for idx_in_chunk in 0..CHUNK_SIZE {
let vec_row = chunk_idx * CHUNK_SIZE + idx_in_chunk;
// Evaluate constrains at row.
let eval = SimdDomainEvaluator::new(
&trace_cols,
vec_row,
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
);
col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv)
let row_res = self.eval.evaluate(eval).row_res;

// Finalize row.
unsafe {
let denom_inv = VeryPackedBaseField::broadcast(
denom_inv[vec_row
>> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)],
);
chunk.set_packed(
idx_in_chunk,
chunk.packed_at(idx_in_chunk) + row_res * denom_inv,
)
}
}
}
});
}
}

Expand Down
2 changes: 1 addition & 1 deletion poseidon_benchmark.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
LOG_N_INSTANCES=18 RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info \
RUSTFLAGS="-C target-cpu=native -C opt-level=3" \
cargo test test_simd_poseidon_prove -- --nocapture
cargo test test_simd_poseidon_prove --features parallel -- --nocapture

0 comments on commit 5c158eb

Please sign in to comment.