From b5f64cc281b8b95ff76c86c97412b783b2573877 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:10:15 +0200 Subject: [PATCH] Add support for periodic columns in LogUp-GKR (#307) --- air/src/air/aux.rs | 6 +- air/src/air/logup_gkr/mod.rs | 90 ++++- air/src/air/logup_gkr/s_column.rs | 2 +- air/src/air/mod.rs | 2 +- air/src/lib.rs | 4 +- prover/src/constraints/evaluator/logup_gkr.rs | 2 +- prover/src/logup_gkr/mod.rs | 10 +- prover/src/logup_gkr/prover.rs | 20 +- prover/src/trace/mod.rs | 2 +- sumcheck/benches/sum_check_high_degree.rs | 22 +- sumcheck/src/prover/high_degree.rs | 109 +++++- sumcheck/src/verifier/mod.rs | 30 +- winterfell/src/tests/logup_gkr_periodic.rs | 357 ++++++++++++++++++ .../{tests.rs => tests/logup_gkr_simple.rs} | 9 +- winterfell/src/tests/mod.rs | 8 + 15 files changed, 619 insertions(+), 54 deletions(-) create mode 100644 winterfell/src/tests/logup_gkr_periodic.rs rename winterfell/src/{tests.rs => tests/logup_gkr_simple.rs} (97%) create mode 100644 winterfell/src/tests/mod.rs diff --git a/air/src/air/aux.rs b/air/src/air/aux.rs index 33d7d8539..d9fa3c2d5 100644 --- a/air/src/air/aux.rs +++ b/air/src/air/aux.rs @@ -80,7 +80,7 @@ pub struct GkrData { pub lagrange_kernel_eval_point: LagrangeKernelRandElements, pub openings_combining_randomness: Vec, pub openings: Vec, - pub oracles: Vec>, + pub oracles: Vec, } impl GkrData { @@ -92,7 +92,7 @@ impl GkrData { lagrange_kernel_eval_point: LagrangeKernelRandElements, openings_combining_randomness: Vec, openings: Vec, - oracles: Vec>, + oracles: Vec, ) -> Self { Self { lagrange_kernel_eval_point, @@ -116,7 +116,7 @@ impl GkrData { &self.openings } - pub fn oracles(&self) -> &[LogUpGkrOracle] { + pub fn oracles(&self) -> &[LogUpGkrOracle] { &self.oracles } diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index a907fad40..d3e198912 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -35,7 +35,13 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// Gets a list of all oracles involved in LogUp-GKR; this is intended to be used in construction of /// MLEs. - fn get_oracles(&self) -> &[LogUpGkrOracle]; + fn get_oracles(&self) -> &[LogUpGkrOracle]; + + /// A vector of virtual periodic columns defined by their values in some given cycle. + /// Note that the cycle lengths must be powers of 2. + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } /// Returns the number of random values needed to evaluate a query. fn get_num_rand_values(&self) -> usize; @@ -56,7 +62,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// information returned from `get_oracles()`. However, this implementation is likely to be /// expensive compared to the hand-written implementation. However, we could provide a test /// which verifies that `get_oracles()` and `build_query()` methods are consistent. - fn build_query(&self, frame: &EvaluationFrame, periodic_values: &[E], query: &mut [E]) + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) where E: FieldElement; @@ -70,6 +76,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { fn evaluate_query( &self, query: &[F], + periodic_values: &[F], logup_randomness: &[E], numerators: &mut [E], denominators: &mut [E], @@ -145,6 +152,22 @@ pub trait LogUpGkrEvaluator: Clone + Sync { ) -> SColumnConstraint { SColumnConstraint::new(gkr_data, composition_coefficient) } + + /// Returns the periodic values used in the LogUp-GKR statement, either as base field element + /// during circuit evaluation or as extension field element during the run of sum-check for + /// the input layer. + fn build_periodic_values(&self) -> PeriodicTable + where + E: FieldElement, + { + let table = self + .get_periodic_column_values() + .iter() + .map(|values| values.iter().map(|x| E::from(*x)).collect()) + .collect(); + + PeriodicTable { table } + } } #[derive(Clone, Default)] @@ -175,7 +198,7 @@ where type PublicInputs = P; - fn get_oracles(&self) -> &[LogUpGkrOracle] { + fn get_oracles(&self) -> &[LogUpGkrOracle] { panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") } @@ -191,7 +214,7 @@ where panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) + fn build_query(&self, _frame: &EvaluationFrame, _query: &mut [E]) where E: FieldElement, { @@ -201,6 +224,7 @@ where fn evaluate_query( &self, _query: &[F], + _periodic_values: &[F], _rand_values: &[E], _numerator: &mut [E], _denominator: &mut [E], @@ -220,12 +244,62 @@ where } #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] -pub enum LogUpGkrOracle { +pub enum LogUpGkrOracle { /// A column with a given index in the main trace segment. CurrentRow(usize), /// A column with a given index in the main trace segment but shifted upwards. NextRow(usize), - /// A virtual periodic column defined by its values in a given cycle. Note that the cycle length - /// must be a power of 2. - PeriodicValue(Vec), +} + +// PERIODIC COLUMNS FOR LOGUP +// ================================================================================================= + +/// Stores the periodic columns used in a LogUp-GKR statement. +/// +/// Each stored periodic column is interpreted as a multi-linear extension polynomial of the column +/// with the given periodic values. Due to the periodic nature of the values, storing, binding of +/// an argument and evaluating the said multi-linear extension can be all done linearly in the size +/// of the smallest cycle defining the periodic values. Hence we only store the values of this +/// smallest cycle. The cycle is assumed throughout to be a power of 2. +#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)] +pub struct PeriodicTable { + pub table: Vec>, +} + +impl PeriodicTable +where + E: FieldElement, +{ + pub fn new(table: Vec>) -> Self { + let table = table.iter().map(|col| col.iter().map(|x| E::from(*x)).collect()).collect(); + + Self { table } + } + + pub fn num_columns(&self) -> usize { + self.table.len() + } + + pub fn table(&self) -> &[Vec] { + &self.table + } + + pub fn fill_periodic_values_at(&self, row: usize, values: &mut [E]) { + self.table + .iter() + .zip(values.iter_mut()) + .for_each(|(col, value)| *value = col[row % col.len()]) + } + + pub fn bind_least_significant_variable(&mut self, round_challenge: E) { + for col in self.table.iter_mut() { + if col.len() > 1 { + let num_evals = col.len() >> 1; + for i in 0..num_evals { + col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]); + } + col.truncate(num_evals) + } + } + } } diff --git a/air/src/air/logup_gkr/s_column.rs b/air/src/air/logup_gkr/s_column.rs index 29848ceeb..685c6e026 100644 --- a/air/src/air/logup_gkr/s_column.rs +++ b/air/src/air/logup_gkr/s_column.rs @@ -45,7 +45,7 @@ impl SColumnConstraint { .mul_base(E::BaseField::ONE / E::BaseField::from(air.trace_length() as u32)); let mut query = vec![E::ZERO; air.get_logup_gkr_evaluator().get_oracles().len()]; - air.get_logup_gkr_evaluator().build_query(main_trace_frame, &[], &mut query); + air.get_logup_gkr_evaluator().build_query(main_trace_frame, &mut query); let batched_claim_at_query = self.gkr_data.compute_batched_query::(&query); let rhs = s_cur - mean + batched_claim_at_query * l_cur; let lhs = s_nxt; diff --git a/air/src/air/mod.rs b/air/src/air/mod.rs index bedfa5e35..cc2e82d2b 100644 --- a/air/src/air/mod.rs +++ b/air/src/air/mod.rs @@ -33,7 +33,7 @@ use logup_gkr::PhantomLogUpGkrEval; pub use logup_gkr::{ LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, - LogUpGkrOracle, + LogUpGkrOracle, PeriodicTable, }; mod coefficients; diff --git a/air/src/lib.rs b/air/src/lib.rs index 2993306b9..39ef44d18 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -47,6 +47,6 @@ pub use air::{ DeepCompositionCoefficients, EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, - LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo, - TransitionConstraintDegree, TransitionConstraints, + LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable, + TraceInfo, TransitionConstraintDegree, TransitionConstraints, }; diff --git a/prover/src/constraints/evaluator/logup_gkr.rs b/prover/src/constraints/evaluator/logup_gkr.rs index f8fa3ae36..cc7390b73 100644 --- a/prover/src/constraints/evaluator/logup_gkr.rs +++ b/prover/src/constraints/evaluator/logup_gkr.rs @@ -138,7 +138,7 @@ where let s_cur = aux_frame.current()[s_col_idx]; let s_nxt = aux_frame.next()[s_col_idx]; - evaluator.build_query(&main_frame, &[], &mut query); + evaluator.build_query(&main_frame, &mut query); let batched_query = self.gkr_data.compute_batched_query(&query); let rhs = s_cur - mean + batched_query * l_cur; diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 643258ee2..2c4846369 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -109,21 +109,25 @@ impl EvaluatedCircuit { log_up_randomness: &[E], ) -> CircuitLayer { let num_fractions = evaluator.get_num_fractions(); + let periodic_values = evaluator.build_periodic_values(); + let mut input_layer_wires = Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions); let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()]; let mut numerators = vec![E::ZERO; num_fractions]; let mut denominators = vec![E::ZERO; num_fractions]; for i in 0..main_trace.main_segment().num_rows() { let wires_from_trace_row = { main_trace.read_main_frame(i, &mut main_frame); - - evaluator.build_query(&main_frame, &[], &mut query); + periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); + evaluator.build_query(&main_frame, &mut query); evaluator.evaluate_query( &query, + &periodic_values_row, log_up_randomness, &mut numerators, &mut denominators, @@ -379,7 +383,7 @@ pub fn build_s_column( for (i, item) in lagrange_kernel_col.iter().enumerate().take(main_segment.num_rows() - 1) { main_trace.read_main_frame(i, &mut main_frame); - evaluator.build_query(&main_frame, &[], &mut query); + evaluator.build_query(&main_frame, &mut query); let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; result.push(cur_value); diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 9fc8fe175..f1a66cf35 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use air::{LogUpGkrEvaluator, LogUpGkrOracle}; +use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ @@ -77,9 +77,18 @@ pub fn prove_gkr( // build the MLEs of the relevant main trace columns let main_trace_mls = build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + // build the periodic table representing periodic columns as multi-linear extensions + let periodic_table = evaluator.build_periodic_values(); - let final_layer_proof = - prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?; + // run the GKR prover for the input layer + let final_layer_proof = prove_input_layer( + evaluator, + logup_randomness, + main_trace_mls, + periodic_table, + gkr_claim, + public_coin, + )?; Ok(GkrCircuitProof { circuit_outputs: CircuitOutput { numerators, denominators }, @@ -97,6 +106,7 @@ fn prove_input_layer< evaluator: &impl LogUpGkrEvaluator, log_up_randomness: Vec, multi_linear_ext_polys: Vec>, + periodic_table: PeriodicTable, claim: GkrClaim, transcript: &mut C, ) -> Result, GkrProverError> { @@ -114,6 +124,7 @@ fn prove_input_layer< r_batch, log_up_randomness, multi_linear_ext_polys, + periodic_table, transcript, )?; @@ -123,7 +134,7 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. fn build_mls_from_main_trace_segment( - oracles: &[LogUpGkrOracle], + oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { let mut mls = vec![]; @@ -146,7 +157,6 @@ fn build_mls_from_main_trace_segment( let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, - LogUpGkrOracle::PeriodicValue(_) => unimplemented!(), }; } Ok(mls) diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 5fef475e6..2b2e89a9e 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -282,7 +282,7 @@ pub trait Trace: Sized { let s_cur = aux_frame.current()[s_col_idx]; let s_nxt = aux_frame.next()[s_col_idx]; - evaluator.build_query(&main_frame, &[], &mut query); + evaluator.build_query(&main_frame, &mut query); let batched_query = gkr_data.compute_batched_query(&query); let rhs = s_cur - mean + batched_query * l_cur; diff --git a/sumcheck/benches/sum_check_high_degree.rs b/sumcheck/benches/sum_check_high_degree.rs index f32329c80..483890579 100644 --- a/sumcheck/benches/sum_check_high_degree.rs +++ b/sumcheck/benches/sum_check_high_degree.rs @@ -5,7 +5,7 @@ use std::{marker::PhantomData, time::Duration}; -use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle}; +use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin}; use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField}; @@ -37,7 +37,7 @@ fn sum_check_high_degree(c: &mut Criterion) { ) }, |( - (claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4)), + (claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4), periodic_table), evaluator, logup_randomness, transcript, @@ -52,6 +52,7 @@ fn sum_check_high_degree(c: &mut Criterion) { r_batch, logup_randomness, mls, + periodic_table, &mut transcript, ) }, @@ -76,6 +77,7 @@ fn setup_sum_check( MultiLinearPoly, MultiLinearPoly, ), + PeriodicTable, ) { let n = 1 << log_size; let table = MultiLinearPoly::from_evaluations(rand_vector(n)); @@ -83,6 +85,7 @@ fn setup_sum_check( let values_0 = MultiLinearPoly::from_evaluations(rand_vector(n)); let values_1 = MultiLinearPoly::from_evaluations(rand_vector(n)); let values_2 = MultiLinearPoly::from_evaluations(rand_vector(n)); + let periodic_table = PeriodicTable::default(); // this will not generate the correct claim with overwhelming probability but should be fine // for benchmarking @@ -90,12 +93,18 @@ fn setup_sum_check( let r_batch: E = rand_value(); let claim: E = rand_value(); - (claim, r_batch, rand_pt, (table, multiplicity, values_0, values_1, values_2)) + ( + claim, + r_batch, + rand_pt, + (table, multiplicity, values_0, values_1, values_2), + periodic_table, + ) } #[derive(Clone, Default)] pub struct PlainLogUpGkrEval { - oracles: Vec>, + oracles: Vec, _field: PhantomData, } @@ -116,7 +125,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { type PublicInputs = (); - fn get_oracles(&self) -> &[LogUpGkrOracle] { + fn get_oracles(&self) -> &[LogUpGkrOracle] { &self.oracles } @@ -132,7 +141,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { 3 } - fn build_query(&self, frame: &EvaluationFrame, _periodic_values: &[E], query: &mut [E]) + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) where E: FieldElement, { @@ -142,6 +151,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { fn evaluate_query( &self, query: &[F], + _periodic_values: &[F], rand_values: &[E], numerator: &mut [E], denominator: &mut [E], diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 691195925..47be290d7 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; -use air::LogUpGkrEvaluator; +use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; #[cfg(feature = "concurrent")] @@ -160,6 +160,7 @@ pub fn sum_check_prove_higher_degree< r_sum_check: E, log_up_randomness: Vec, mut mls: Vec>, + mut periodic_table: PeriodicTable, coin: &mut impl RandomCoin, ) -> Result, SumCheckProverError> { let num_rounds = mls[0].num_variables(); @@ -176,8 +177,15 @@ pub fn sum_check_prove_higher_degree< let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; // run the first round of the protocol - let round_poly_evals = - sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check); + let round_poly_evals = sumcheck_round( + &eq_mu, + evaluator, + &eq_nu, + &mls, + &periodic_table, + &log_up_randomness, + r_sum_check, + ); let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); // reseed with the s_0 polynomial @@ -198,10 +206,20 @@ pub fn sum_check_prove_higher_degree< .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); eq_nu.bind_least_significant_variable(round_challenge); + // fold each periodic multi-linear using the round challenge + periodic_table.bind_least_significant_variable(round_challenge); + // run the i-th round of the protocol using the folded multi-linears for the new reduced // claim. This basically computes the s_i polynomial. - let round_poly_evals = - sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check); + let round_poly_evals = sumcheck_round( + &eq_mu, + evaluator, + &eq_nu, + &mls, + &periodic_table, + &log_up_randomness, + r_sum_check, + ); // update the claim current_round_claim = new_round_claim; @@ -280,21 +298,28 @@ fn sumcheck_round( evaluator: &impl LogUpGkrEvaluator::BaseField>, eq_ml: &MultiLinearPoly, mls: &[MultiLinearPoly], + periodic_table: &PeriodicTable, log_up_randomness: &[E], r_sum_check: E, ) -> CompressedUnivariatePolyEvals { - let num_ml = mls.len(); + let num_mls = mls.len(); + let num_periodic = periodic_table.num_columns(); let num_vars = mls[0].num_variables(); let num_rounds = num_vars - 1; #[cfg(not(feature = "concurrent"))] let evaluations = { - let mut evals_one = vec![E::ZERO; num_ml]; - let mut evals_zero = vec![E::ZERO; num_ml]; - let mut evals_x = vec![E::ZERO; num_ml]; + let mut evals_one = vec![E::ZERO; num_mls]; + let mut evals_zero = vec![E::ZERO; num_mls]; + let mut evals_x = vec![E::ZERO; num_mls]; + + let mut evals_periodic_one = vec![E::ZERO; num_periodic]; + let mut evals_periodic_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_x = vec![E::ZERO; num_periodic]; let mut eq_x = E::ZERO; - let mut deltas = vec![E::ZERO; num_ml]; + let mut deltas = vec![E::ZERO; num_mls]; + let mut deltas_periodic = vec![E::ZERO; num_periodic]; let mut eq_delta = E::ZERO; let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; @@ -311,9 +336,14 @@ fn sumcheck_round( let eq_at_zero = eq_ml.evaluations()[2 * i]; let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + // add evaluation of periodic columns + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + // compute the evaluation at 1 evaluator.evaluate_query( &evals_one, + &evals_periodic_one, log_up_randomness, &mut numerators, &mut denominators, @@ -327,10 +357,14 @@ fn sumcheck_round( ); // compute the evaluations at 2, ..., d_max points - for i in 0..num_ml { + for i in 0..num_mls { deltas[i] = evals_one[i] - evals_zero[i]; evals_x[i] = evals_one[i]; } + for i in 0..num_periodic { + deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; + evals_periodic_x[i] = evals_periodic_one[i]; + } eq_delta = eq_at_one - eq_at_zero; eq_x = eq_at_one; @@ -338,10 +372,16 @@ fn sumcheck_round( evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { *evx += *delta; }); + evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); eq_x += eq_delta; evaluator.evaluate_query( &evals_x, + &evals_periodic_x, log_up_randomness, &mut numerators, &mut denominators, @@ -371,23 +411,31 @@ fn sumcheck_round( .fold( || { ( - vec![E::ZERO; num_ml], - vec![E::ZERO; num_ml], - vec![E::ZERO; num_ml], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], vec![E::ZERO; evaluator.max_degree()], vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; evaluator.get_num_fractions()], - vec![E::ZERO; num_ml], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_periodic], ) }, |( mut evals_zero, mut evals_one, mut evals_x, + mut evals_periodic_zero, + mut evals_periodic_one, + mut evals_periodic_x, mut poly_evals, mut numerators, mut denominators, mut deltas, + mut deltas_periodic, ), i| { for (j, ml) in mls.iter().enumerate() { @@ -398,9 +446,14 @@ fn sumcheck_round( let eq_at_zero = eq_ml.evaluations()[2 * i]; let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + // add evaluation of periodic columns + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + // compute the evaluation at 1 evaluator.evaluate_query( &evals_one, + &evals_periodic_one, log_up_randomness, &mut numerators, &mut denominators, @@ -414,10 +467,14 @@ fn sumcheck_round( ); // compute the evaluations at 2, ..., d_max points - for i in 0..num_ml { + for i in 0..num_mls { deltas[i] = evals_one[i] - evals_zero[i]; evals_x[i] = evals_one[i]; } + for i in 0..num_periodic { + deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; + evals_periodic_x[i] = evals_periodic_one[i]; + } let eq_delta = eq_at_one - eq_at_zero; let mut eq_x = eq_at_one; @@ -425,10 +482,16 @@ fn sumcheck_round( evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { *evx += *delta; }); + evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); eq_x += eq_delta; evaluator.evaluate_query( &evals_x, + &evals_periodic_x, log_up_randomness, &mut numerators, &mut denominators, @@ -442,7 +505,19 @@ fn sumcheck_round( ); } - (evals_zero, evals_one, evals_x, poly_evals, numerators, denominators, deltas) + ( + evals_zero, + evals_one, + evals_x, + evals_periodic_zero, + evals_periodic_one, + evals_periodic_x, + poly_evals, + numerators, + denominators, + deltas, + deltas_periodic, + ) }, ) .map(|(_, _, _, poly_evals, ..)| poly_evals) diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 887598cc8..900be4c86 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -5,13 +5,13 @@ use alloc::vec::Vec; -use air::LogUpGkrEvaluator; +use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use crate::{ comb_func, evaluate_composition_poly, EqFunction, FinalLayerProof, FinalOpeningClaim, - RoundProof, SumCheckProof, SumCheckRoundClaim, + MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, }; /// Verifies sum-check proofs, as part of the GKR proof, for all GKR layers except for the last one @@ -86,8 +86,14 @@ pub fn verify_sum_check_input_layer( + periodic_columns: PeriodicTable, + eval_point: &[E], +) -> Vec { + let mut evaluations = vec![]; + for col in periodic_columns.table() { + let ml = MultiLinearPoly::from_evaluations(col.to_vec()); + let num_variables = ml.num_variables(); + let point = &eval_point[..num_variables]; + + let evaluation = ml.evaluate(point); + evaluations.push(evaluation) + } + evaluations +} diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs new file mode 100644 index 000000000..18ee38d00 --- /dev/null +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -0,0 +1,357 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, vec, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, FieldExtension, + LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, +}; +use crypto::MerkleTree; +use math::StarkField; + +use super::super::*; +use crate::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, TracePolyTable, +}; + +#[test] +fn test_logup_gkr_periodic() { + let aux_trace_width = 1; + let trace = LogUpGkrPeriodic::new(2_usize.pow(7), aux_trace_width); + let prover = LogUpGkrPeriodicProver::new(aux_trace_width); + + let proof = prover.prove(trace).unwrap(); + + verify::< + LogUpGkrPeriodicAir, + Blake3_256, + DefaultRandomCoin>, + MerkleTree>, + >(proof, (), &AcceptableOptions::MinConjecturedSecurity(0)) + .unwrap() +} + +// LogUpGkrPeriodic +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrPeriodic { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrPeriodic { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + let mut multiplicity = vec![BaseElement::ZERO; trace_len]; + multiplicity.iter_mut().step_by(8).for_each(|m| *m = BaseElement::from(3_u32)); + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_0[8 * i] = BaseElement::from(8 * i as u32); + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_1[8 * i] = BaseElement::from(8 * i as u32); + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_2[8 * i] = BaseElement::from(8 * i as u32); + } + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrPeriodic { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrPeriodicAir { + context: AirContext, +} + +impl Air for LogUpGkrPeriodicAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::with_logup_gkr( + trace_info, + (), + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PeriodicLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PeriodicLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PeriodicLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PeriodicLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![vec![ + Self::BaseField::ONE, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + ]] + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::from(periodic_values[0]); + numerator[2] = E::from(periodic_values[0]); + numerator[3] = E::from(periodic_values[0]); + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} + +// Prover +// ================================================================================================ + +struct LogUpGkrPeriodicProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrPeriodicProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrPeriodicProver { + type BaseField = BaseElement; + type Air = LogUpGkrPeriodicAir; + type Trace = LogUpGkrPeriodic; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + DefaultConstraintEvaluator<'a, LogUpGkrPeriodicAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/winterfell/src/tests.rs b/winterfell/src/tests/logup_gkr_simple.rs similarity index 97% rename from winterfell/src/tests.rs rename to winterfell/src/tests/logup_gkr_simple.rs index 99e52971f..0fec04c96 100644 --- a/winterfell/src/tests.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -12,7 +12,7 @@ use air::{ use crypto::MerkleTree; use math::StarkField; -use super::*; +use super::super::*; use crate::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, @@ -182,7 +182,7 @@ impl Air for LogUpGkrSimpleAir { #[derive(Clone, Default)] pub struct PlainLogUpGkrEval { - oracles: Vec>, + oracles: Vec, _field: PhantomData, } @@ -203,7 +203,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { type PublicInputs = (); - fn get_oracles(&self) -> &[LogUpGkrOracle] { + fn get_oracles(&self) -> &[LogUpGkrOracle] { &self.oracles } @@ -219,7 +219,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { 3 } - fn build_query(&self, frame: &EvaluationFrame, _periodic_values: &[E], query: &mut [E]) + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) where E: FieldElement, { @@ -229,6 +229,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { fn evaluate_query( &self, query: &[F], + _periodic_values: &[F], rand_values: &[E], numerator: &mut [E], denominator: &mut [E], diff --git a/winterfell/src/tests/mod.rs b/winterfell/src/tests/mod.rs new file mode 100644 index 000000000..51881e55e --- /dev/null +++ b/winterfell/src/tests/mod.rs @@ -0,0 +1,8 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +mod logup_gkr_simple; + +mod logup_gkr_periodic;