Skip to content

Commit

Permalink
Add build_trace functions for MLE eval component
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 24, 2024
1 parent 18b014c commit af52728
Showing 1 changed file with 144 additions and 96 deletions.
240 changes: 144 additions & 96 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
//! Multilinear extension (MLE) eval at point constraints.
use std::array;
use std::iter::zip;

use itertools::{chain, zip_eq};
use num_traits::{One, Zero};

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::EvalAtRow;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
use crate::core::poly::circle::{CanonicCoset, SecureEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

Expand Down Expand Up @@ -39,7 +49,7 @@ pub struct MleEvalPoint<const N_VARIABLES: usize> {
// Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`.
eq_carry_quotients: [SecureField; N_VARIABLES],
// Point `p`.
p: [SecureField; N_VARIABLES],
_p: [SecureField; N_VARIABLES],
}

impl<const N_VARIABLES: usize> MleEvalPoint<N_VARIABLES> {
Expand All @@ -58,7 +68,7 @@ impl<const N_VARIABLES: usize> MleEvalPoint<N_VARIABLES> {
denom_assignment[i] = one;
eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1])
}),
p,
_p: p,
}
}
}
Expand Down Expand Up @@ -108,6 +118,68 @@ fn eval_prefix_sum_constraints<E: EvalAtRow>(
eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift);
}

/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// ---------------------------------------------------------
/// | EqEvals (basis) | MLE terms (prefix sum) |
/// ---------------------------------------------------------
/// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 |
/// ---------------------------------------------------------
/// ```
pub fn build_trace(
mle: &Mle<SimdBackend, SecureField>,
eval_point: &[SecureField],
claim: SecureField,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals();
let mle_terms = hadamard_product(mle, &eq_evals);

let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns;
let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns;

#[cfg(test)]
assert_eq!(claim, mle.eval_at_point(eval_point));
let shift = claim / BaseField::from(mle.len());
let packed_shift_coords = PackedSecureField::broadcast(shift).into_packed_m31s();
let mut shifted_mle_terms_cols = mle_terms_cols.clone();
zip(&mut shifted_mle_terms_cols, packed_shift_coords)
.for_each(|(col, shift_coord)| col.data.iter_mut().for_each(|v| *v -= shift_coord));
let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum);

let log_trace_domain_size = mle.n_variables() as u32;
let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain();

chain![eq_evals_cols, shifted_prefix_sum_cols]
.map(|c| CircleEvaluation::new(trace_domain, c))
.collect()
}

/// Generates a trace.
///
/// Trace structure:
/// 1. Is first selector column (see [gen_is_first]).
/// 2. Eq carry quotients column (see [gen_carry_quotient_trace]).
///
/// ```text
/// ------------------------------------------------
/// | is first selector | eq carry quotients |
/// ------------------------------------------------
/// | c0 | c1 | c2 | c3 | c4 |
/// ------------------------------------------------
/// ```
pub fn build_constant_trace<const N_VARIABLES: usize>(
eval_point: &[SecureField; N_VARIABLES],
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let log_size = N_VARIABLES as u32;
let mut res = Vec::new();
res.push(gen_is_first(log_size));
res.extend(gen_carry_quotient_trace(eval_point).into_coordinate_evals());
res
}

/// Returns succinct Eq carry quotients column.
///
/// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point
Expand All @@ -117,13 +189,14 @@ fn eval_prefix_sum_constraints<E: EvalAtRow>(
///
/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain
pub fn gen_carry_quotient_trace<const N_VARIABLES: usize>(
eval_point: &MleEvalPoint<N_VARIABLES>,
eval_point: &[SecureField; N_VARIABLES],
) -> SecureEvaluation<SimdBackend, BitReversedOrder> {
let last_variable = *eval_point.p.last().unwrap();
let last_variable = *eval_point.last().unwrap();
let zero = SecureField::zero();
let one = SecureField::one();

let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients;
let mle_eval_point = MleEvalPoint::new(*eval_point);
let mut half_coset0_carry_quotients = mle_eval_point.eq_carry_quotients;
*half_coset0_carry_quotients.last_mut().unwrap() *=
eq(&[one], &[last_variable]) / eq(&[zero], &[last_variable]);
let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse());
Expand Down Expand Up @@ -152,59 +225,69 @@ pub fn gen_carry_quotient_trace<const N_VARIABLES: usize>(
SecureEvaluation::new(domain, col)
}

/// Returns the element-wise product of `a` and `b`.
fn hadamard_product(
a: &Col<SimdBackend, SecureField>,
b: &Col<SimdBackend, SecureField>,
) -> Col<SimdBackend, SecureField> {
assert_eq!(a.len(), b.len());
SecureColumn {
data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(),
length: a.len(),
}
}

#[cfg(test)]
mod tests {
use std::array;
use std::iter::{repeat, zip};

use itertools::{chain, zip_eq, Itertools};
use itertools::{chain, Itertools};
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::{
eval_eq_constraints, eval_mle_eval_constraints, eval_prefix_sum_constraints,
gen_carry_quotient_trace, MleEvalPoint,
eval_eq_constraints, eval_mle_eval_constraints, eval_prefix_sum_constraints, MleEvalPoint,
};
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::{assert_constraints, EvalAtRow};
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::mle::Mle;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};

const EVAL_TRACE: usize = 0;
const CONST_TRACE: usize = 1;
use crate::examples::xor::gkr_lookups::mle_eval::{build_constant_trace, build_trace};

#[test]
fn test_mle_eval_constraints_with_log_size_5() {
const N_VARIABLES: usize = 5;
const COEFFS_COL_TRACE: usize = 0;
const EVAL_TRACE: usize = 1;
const CONST_TRACE: usize = 2;
let mut rng = SmallRng::seed_from_u64(0);
let log_size = N_VARIABLES as u32;
let size = 1 << log_size;
let mle = Mle::new((0..size).map(|_| rng.gen::<SecureField>()).collect());
let mle_coeffs = (0..size).map(|_| rng.gen::<SecureField>()).collect();
let mle = Mle::<SimdBackend, SecureField>::new(mle_coeffs);
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let claim = mle.eval_at_point(&eval_point);
let mle_eval_point = MleEvalPoint::new(eval_point);
let mle_eval_trace = build_trace(&mle, &eval_point, claim);
let mle_coeffs_col_trace = build_mle_coeffs_trace(mle);
let claim_shift = claim / BaseField::from(size);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let constants_trace = build_constant_trace(&eval_point);
let traces = TreeVec::new(vec![mle_coeffs_col_trace, mle_eval_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(log_size);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(COEFFS_COL_TRACE, [0]);
eval_mle_eval_constraints(
EVAL_TRACE,
CONST_TRACE,
Expand All @@ -220,56 +303,59 @@ mod tests {
#[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"]
fn eq_constraints_with_4_variables() {
const N_VARIABLES: usize = 4;
const EVAL_TRACE: usize = 0;
const CONST_TRACE: usize = 1;
let mut rng = SmallRng::seed_from_u64(0);
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point));
let constants_trace = build_constant_trace(&eval_point);
let traces = TreeVec::new(vec![trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
eval_eq_constraints(EVAL_TRACE, CONST_TRACE, &mut eval, mle_eval_point);
});
}

#[test]
fn eq_constraints_with_5_variables() {
const N_VARIABLES: usize = 5;
const EVAL_TRACE: usize = 0;
const CONST_TRACE: usize = 1;
let mut rng = SmallRng::seed_from_u64(0);
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point));
let constants_trace = build_constant_trace(&eval_point);
let traces = TreeVec::new(vec![trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
eval_eq_constraints(EVAL_TRACE, CONST_TRACE, &mut eval, mle_eval_point);
});
}

#[test]
fn eq_constraints_with_8_variables() {
const N_VARIABLES: usize = 8;
const EVAL_TRACE: usize = 0;
const CONST_TRACE: usize = 1;
let mut rng = SmallRng::seed_from_u64(0);
let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect());
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let mle_eval_point = MleEvalPoint::<N_VARIABLES>::new(eval_point);
let base_trace = gen_base_trace(&mle, &eval_point);
let constants_trace = gen_constants_trace(&mle_eval_point);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let mle_eval_point = MleEvalPoint::new(eval_point);
let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point));
let constants_trace = build_constant_trace(&eval_point);
let traces = TreeVec::new(vec![trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(N_VARIABLES as u32);
let trace_domain = CanonicCoset::new(eval_point.len() as u32);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]);
eval_eq_constraints(EVAL_TRACE, CONST_TRACE, &mut eval, mle_eval_point);
});
}
Expand All @@ -287,49 +373,10 @@ mod tests {

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let [row_diff] = eval.next_extension_interaction_mask(0, [0]);
eval_prefix_sum_constraints(EVAL_TRACE, &mut eval, row_diff, cumulative_sum_shift)
eval_prefix_sum_constraints(0, &mut eval, row_diff, cumulative_sum_shift)
});
}

/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// -------------------------------------------------------------------------------------
/// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) |
/// -------------------------------------------------------------------------------------
/// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 |
/// -------------------------------------------------------------------------------------
/// ```
fn gen_base_trace(
mle: &Mle<SimdBackend, SecureField>,
eval_point: &[SecureField],
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let mle_coeffs = mle.clone().into_evals();
let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals();
let mle_terms = hadamard_product(&mle_coeffs, &eq_evals);

let mle_coeff_cols = mle_coeffs.into_secure_column_by_coords().columns;
let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns;
let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns;

let claim = mle.eval_at_point(eval_point);
let shift = claim / BaseField::from(mle.len());
let packed_shifts = PackedSecureField::broadcast(shift).into_packed_m31s();
let mut shifted_mle_terms_cols = mle_terms_cols.clone();
zip(&mut shifted_mle_terms_cols, packed_shifts)
.for_each(|(col, shift)| col.data.iter_mut().for_each(|v| *v -= shift));
let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum);

let log_trace_domain_size = mle.n_variables() as u32;
let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain();

chain![mle_coeff_cols, eq_evals_cols, shifted_prefix_sum_cols]
.map(|c| CircleEvaluation::new(trace_domain, c))
.collect()
}

/// Generates a trace.
///
/// Trace structure:
Expand Down Expand Up @@ -370,25 +417,26 @@ mod tests {
.collect()
}

/// Returns the element-wise product of `a` and `b`.
fn hadamard_product(
a: &Col<SimdBackend, SecureField>,
b: &Col<SimdBackend, SecureField>,
) -> Col<SimdBackend, SecureField> {
assert_eq!(a.len(), b.len());
SecureColumn {
data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(),
length: a.len(),
}
}

fn gen_constants_trace<const N_VARIABLES: usize>(
eval_point: &MleEvalPoint<N_VARIABLES>,
/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// -----------------------------
/// | MLE coeffs col |
/// -----------------------------
/// | c0 | c1 | c2 | c3 |
/// -----------------------------
/// ```
fn build_mle_coeffs_trace(
mle: Mle<SimdBackend, SecureField>,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let log_size = N_VARIABLES as u32;
let mut constants_trace = Vec::new();
constants_trace.push(gen_is_first(log_size));
constants_trace.extend(gen_carry_quotient_trace(eval_point).into_coordinate_evals());
constants_trace
let log_size = mle.n_variables() as u32;
let trace_domain = CanonicCoset::new(log_size).circle_domain();
let mle_coeffs_col_by_coords = mle.into_evals().into_secure_column_by_coords();
SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords)
.into_coordinate_evals()
.into_iter()
.collect()
}
}

0 comments on commit af52728

Please sign in to comment.