Skip to content

Commit

Permalink
Abstracted eval.write_fracs to eval.add_to_relations.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 10, 2024
1 parent f76e199 commit 601dd1b
Show file tree
Hide file tree
Showing 14 changed files with 159 additions and 81 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl<E: EvalAtRow> Drop for LogupAtRow<E> {
pub struct LookupElements<const N: usize> {
pub z: SecureField,
pub alpha: SecureField,
alpha_powers: [SecureField; N],
pub alpha_powers: [SecureField; N],
}
impl<const N: usize> LookupElements<N> {
pub fn draw(channel: &mut impl Channel) -> Self {
Expand Down
82 changes: 70 additions & 12 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,30 @@ pub trait EvalAtRow {
/// Adds `elems` to `relation` with `multiplicity`.
fn add_to_relation<Relation: RelationType<Self::F, Self::EF>>(
&mut self,
relation: Relation,
multiplicity: usize,
relation: &Relation,
multiplicity: Self::EF,
elems: &[Self::F],
) {
let denom = relation.combine(elems);
self.write_frac(Fraction::new(
Self::EF::from(SecureField::from(multiplicity)),
denom,
));
self.write_frac(Fraction::new(multiplicity, denom));
}

/// Adds `elems[01]` to `relation[01]` with `multiplicity[01]`, batched.
/// TODO(alont): Generalize this to n elements if more than 2-batching is used.
fn add_to_relation_batched<Relation: RelationType<Self::F, Self::EF>>(
&mut self,
relation0: &Relation,
multiplicity0: Self::EF,
elems0: &[Self::F],
relation1: &Relation,
multiplicity1: Self::EF,
elems1: &[Self::F],
) {
let denom0 = relation0.combine(elems0);
let denom1 = relation1.combine(elems1);
self.write_frac(
Fraction::new(multiplicity0, denom0) + Fraction::new(multiplicity1, denom1),
);
}

// TODO(alont): Remove these once LogupAtRow is no longer used.
Expand Down Expand Up @@ -181,22 +196,65 @@ macro_rules! logup_proxy {
}
pub(crate) use logup_proxy;

pub trait RelationType<F, EF>
pub trait RelationEFTraitBound<F: Clone>:
Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = Self> + Sub<Self, Output = Self>
{
}

impl<F, EF> RelationEFTraitBound<F> for EF
where
F: Clone,
EF: Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
{
}

pub trait RelationType<F: Clone, EF: RelationEFTraitBound<F>> {
fn combine(&self, values: &[F]) -> EF {
values
.iter()
.zip(self.get_alpha_powers())
.fold(EF::zero(), |acc, (value, power)| {
acc + power.clone() * value.clone()
acc + EF::from(*power) * value.clone()
})
- self.get_z()
- self.get_z().into()
}

fn get_z(&self) -> EF;
fn get_alpha_powers(&self) -> &[EF];
fn name(&self) -> &str;
fn get_z(&self) -> SecureField;
fn get_alpha_powers(&self) -> &[SecureField];
fn get_name(&self) -> &str;
}

macro_rules! relation {
($name:tt, $size:tt) => {
#[derive(Clone, Debug, PartialEq)]
pub struct $name(crate::constraint_framework::logup::LookupElements<$size>);

impl $name {
pub fn dummy() -> Self {
Self(crate::constraint_framework::logup::LookupElements::dummy())
}
pub fn draw(channel: &mut impl crate::core::channel::Channel) -> Self {
Self(crate::constraint_framework::logup::LookupElements::draw(
channel,
))
}
}

impl<F: Clone, EF: crate::constraint_framework::RelationEFTraitBound<F>>
crate::constraint_framework::RelationType<F, EF> for $name
{
fn get_z(&self) -> crate::core::fields::qm31::SecureField {
self.0.z
}

fn get_alpha_powers(&self) -> &[crate::core::fields::qm31::SecureField] {
&self.0.alpha_powers
}

fn get_name(&self) -> &str {
stringify!($name)
}
}
};
}
pub(crate) use relation;
1 change: 1 addition & 0 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl XorAccums {
}
}

// TODO(alont): Get these out of the struct and give them names.
#[derive(Clone)]
pub struct BlakeXorElements {
xor12: XorElements,
Expand Down
21 changes: 10 additions & 11 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::{Fraction, Reciprocal};
use crate::core::lookups::utils::Reciprocal;
use crate::examples::blake::{Fu32, STATE_SIZE};

const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15);
Expand Down Expand Up @@ -67,17 +67,16 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
);

// Yield `Round(input_v, output_v, message)`.
self.eval.write_frac(Fraction::new(
self.eval.add_to_relation(
self.round_lookup_elements,
-E::EF::one(),
self.round_lookup_elements.combine(
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
));
&chain![
input_v.iter().cloned().flat_map(Fu32::to_felts),
v.iter().cloned().flat_map(Fu32::to_felts),
m.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
);

self.eval.finalize_logup();
self.eval
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/round/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{span, Level};

use super::{BlakeXorElements, RoundElements};
use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::ORIGINAL_TRACE_IDX;
use crate::constraint_framework::{RelationType, ORIGINAL_TRACE_IDX};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
Expand Down
9 changes: 5 additions & 4 deletions crates/prover/src/examples/blake/round/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput};
use num_traits::Zero;

use super::{BlakeXorElements, N_ROUND_INPUT_FELTS};
use crate::constraint_framework::logup::LookupElements;
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator,
};
use crate::core::fields::qm31::SecureField;

pub type BlakeRoundComponent = FrameworkComponent<BlakeRoundEval>;

pub type RoundElements = LookupElements<N_ROUND_INPUT_FELTS>;
relation!(RoundElements, N_ROUND_INPUT_FELTS);

pub struct BlakeRoundEval {
pub log_size: u32,
Expand All @@ -32,7 +33,7 @@ impl FrameworkEval for BlakeRoundEval {
eval,
xor_lookup_elements: &self.xor_lookup_elements,
round_lookup_elements: &self.round_lookup_elements,
total_sum: self.total_sum.into(),
total_sum: self.total_sum,
log_size: self.log_size,
};
blake_eval.eval()
Expand Down
31 changes: 18 additions & 13 deletions crates/prover/src/examples/blake/scheduler/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use itertools::{chain, Itertools};
use num_traits::Zero;
use num_traits::{One, Zero};

use super::BlakeElements;
use crate::constraint_framework::EvalAtRow;
use crate::constraint_framework::{EvalAtRow, RelationType};
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::{Fraction, Reciprocal};
use crate::core::lookups::utils::Fraction;
use crate::core::vcs::blake2s_ref::SIGMA;
use crate::examples::blake::round::RoundElements;
use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE};
Expand All @@ -24,20 +24,25 @@ pub fn eval_blake_scheduler_constraints<E: EvalAtRow>(
// Schedule.
for [i, j] in (0..N_ROUNDS).array_chunks::<2>() {
// Use triplet in round lookup.
let [denom_i, denom_j] = [i, j].map(|idx| {
let [elems_i, elems_j] = [i, j].map(|idx| {
let input_state = &states[idx];
let output_state = &states[idx + 1];
let round_messages = SIGMA[idx].map(|k| messages[k as usize].clone());
round_lookup_elements.combine::<E::F, E::EF>(
&chain![
input_state.iter().cloned().flat_map(Fu32::to_felts),
output_state.iter().cloned().flat_map(Fu32::to_felts),
round_messages.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec(),
)
chain![
input_state.iter().cloned().flat_map(Fu32::to_felts),
output_state.iter().cloned().flat_map(Fu32::to_felts),
round_messages.iter().cloned().flat_map(Fu32::to_felts)
]
.collect_vec()
});
eval.write_frac(Reciprocal::new(denom_i) + Reciprocal::new(denom_j));
eval.add_to_relation_batched(
round_lookup_elements,
E::EF::one(),
&elems_i,
round_lookup_elements,
E::EF::one(),
&elems_j,
);
}

let input_state = &states[0];
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/scheduler/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tracing::{span, Level};

use super::{blake_scheduler_info, BlakeElements};
use crate::constraint_framework::logup::LogupTraceGenerator;
use crate::constraint_framework::ORIGINAL_TRACE_IDX;
use crate::constraint_framework::{RelationType, ORIGINAL_TRACE_IDX};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::qm31::PackedSecureField;
Expand Down
7 changes: 4 additions & 3 deletions crates/prover/src/examples/blake/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ use num_traits::Zero;

use super::round::RoundElements;
use super::N_ROUND_INPUT_FELTS;
use crate::constraint_framework::logup::LookupElements;
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator,
};
use crate::core::fields::qm31::SecureField;

pub type BlakeSchedulerComponent = FrameworkComponent<BlakeSchedulerEval>;

pub type BlakeElements = LookupElements<N_ROUND_INPUT_FELTS>;
relation!(BlakeElements, N_ROUND_INPUT_FELTS);

pub struct BlakeSchedulerEval {
pub log_size: u32,
Expand Down
36 changes: 18 additions & 18 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use tracing::{span, Level};
use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn};
use crate::constraint_framework::{
assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator,
assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval,
TraceLocationAllocator,
};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::LOG_N_LANES;
Expand All @@ -15,7 +16,6 @@ use crate::core::backend::Column;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
Expand All @@ -25,10 +25,12 @@ use crate::core::ColumnVec;

pub type PlonkComponent = FrameworkComponent<PlonkEval>;

relation!(PlonkLookupElements, 2);

#[derive(Clone)]
pub struct PlonkEval {
pub log_n_rows: u32,
pub lookup_elements: LookupElements<2>,
pub lookup_elements: PlonkLookupElements,
pub claimed_sum: ClaimedPrefixSum,
pub total_sum: SecureField,
pub base_trace_location: TreeSubspan,
Expand Down Expand Up @@ -65,17 +67,16 @@ impl FrameworkEval for PlonkEval {
+ (E::F::one() - op) * a_val.clone() * b_val.clone(),
);

let denom_a: E::EF = self.lookup_elements.combine(&[a_wire, a_val]);
let denom_b: E::EF = self.lookup_elements.combine(&[b_wire, b_val]);
eval.add_to_relation_batched(
&self.lookup_elements,
E::EF::one(),
&[a_wire, a_val],
&self.lookup_elements,
E::EF::one(),
&[b_wire, b_val],
);

eval.write_frac(Fraction::new(
denom_a.clone() + denom_b.clone(),
denom_a * denom_b,
));
eval.write_frac(Fraction::new(
(-mult).into(),
self.lookup_elements.combine(&[c_wire, c_val]),
));
eval.add_to_relation(&self.lookup_elements, (-mult).into(), &[c_wire, c_val]);

eval.finalize_logup();
eval
Expand Down Expand Up @@ -195,12 +196,12 @@ pub fn prove_fibonacci_plonk(
span.exit();

// Draw lookup element.
let lookup_elements = LookupElements::draw(channel);
let lookup_elements = PlonkLookupElements::draw(channel);

// Interaction trace.
let span = span!(Level::INFO, "Interaction").entered();
let (trace, [total_sum, claimed_sum]) =
gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements);
gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements.0);
let mut tree_builder = commitment_scheme.tree_builder();
let interaction_trace_location = tree_builder.extend_evals(trace);
tree_builder.commit(channel);
Expand Down Expand Up @@ -256,14 +257,13 @@ pub fn prove_fibonacci_plonk(
mod tests {
use std::env;

use crate::constraint_framework::logup::LookupElements;
use crate::core::air::Component;
use crate::core::channel::Blake2sChannel;
use crate::core::fri::FriConfig;
use crate::core::pcs::{CommitmentSchemeVerifier, PcsConfig};
use crate::core::prover::verify;
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
use crate::examples::plonk::prove_fibonacci_plonk;
use crate::examples::plonk::{prove_fibonacci_plonk, PlonkLookupElements};

#[test_log::test]
fn test_simd_plonk_prove() {
Expand Down Expand Up @@ -291,7 +291,7 @@ mod tests {
// Trace columns.
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel);
// Draw lookup element.
let lookup_elements = LookupElements::<2>::draw(channel);
let lookup_elements = PlonkLookupElements::draw(channel);
assert_eq!(lookup_elements, component.lookup_elements);
// Interaction columns.
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel);
Expand Down
Loading

0 comments on commit 601dd1b

Please sign in to comment.