Skip to content

Commit

Permalink
Add support for periodic columns in LogUp-GKR (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
Al-Kindi-0 authored Sep 10, 2024
1 parent ac9561d commit b5f64cc
Show file tree
Hide file tree
Showing 15 changed files with 619 additions and 54 deletions.
6 changes: 3 additions & 3 deletions air/src/air/aux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub struct GkrData<E: FieldElement> {
pub lagrange_kernel_eval_point: LagrangeKernelRandElements<E>,
pub openings_combining_randomness: Vec<E>,
pub openings: Vec<E>,
pub oracles: Vec<LogUpGkrOracle<E::BaseField>>,
pub oracles: Vec<LogUpGkrOracle>,
}

impl<E: FieldElement> GkrData<E> {
Expand All @@ -92,7 +92,7 @@ impl<E: FieldElement> GkrData<E> {
lagrange_kernel_eval_point: LagrangeKernelRandElements<E>,
openings_combining_randomness: Vec<E>,
openings: Vec<E>,
oracles: Vec<LogUpGkrOracle<E::BaseField>>,
oracles: Vec<LogUpGkrOracle>,
) -> Self {
Self {
lagrange_kernel_eval_point,
Expand All @@ -116,7 +116,7 @@ impl<E: FieldElement> GkrData<E> {
&self.openings
}

pub fn oracles(&self) -> &[LogUpGkrOracle<E::BaseField>] {
pub fn oracles(&self) -> &[LogUpGkrOracle] {
&self.oracles
}

Expand Down
90 changes: 82 additions & 8 deletions air/src/air/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ pub trait LogUpGkrEvaluator: Clone + Sync {

/// Gets a list of all oracles involved in LogUp-GKR; this is intended to be used in construction of
/// MLEs.
fn get_oracles(&self) -> &[LogUpGkrOracle<Self::BaseField>];
fn get_oracles(&self) -> &[LogUpGkrOracle];

/// A vector of virtual periodic columns defined by their values in some given cycle.
/// Note that the cycle lengths must be powers of 2.
fn get_periodic_column_values(&self) -> Vec<Vec<Self::BaseField>> {
vec![]
}

/// Returns the number of random values needed to evaluate a query.
fn get_num_rand_values(&self) -> usize;
Expand All @@ -56,7 +62,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync {
/// information returned from `get_oracles()`. However, this implementation is likely to be
/// expensive compared to the hand-written implementation. However, we could provide a test
/// which verifies that `get_oracles()` and `build_query()` methods are consistent.
fn build_query<E>(&self, frame: &EvaluationFrame<E>, periodic_values: &[E], query: &mut [E])
fn build_query<E>(&self, frame: &EvaluationFrame<E>, query: &mut [E])
where
E: FieldElement<BaseField = Self::BaseField>;

Expand All @@ -70,6 +76,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync {
fn evaluate_query<F, E>(
&self,
query: &[F],
periodic_values: &[F],
logup_randomness: &[E],
numerators: &mut [E],
denominators: &mut [E],
Expand Down Expand Up @@ -145,6 +152,22 @@ pub trait LogUpGkrEvaluator: Clone + Sync {
) -> SColumnConstraint<E> {
SColumnConstraint::new(gkr_data, composition_coefficient)
}

/// Returns the periodic values used in the LogUp-GKR statement, either as base field element
/// during circuit evaluation or as extension field element during the run of sum-check for
/// the input layer.
fn build_periodic_values<E>(&self) -> PeriodicTable<E>
where
E: FieldElement<BaseField = Self::BaseField>,
{
let table = self
.get_periodic_column_values()
.iter()
.map(|values| values.iter().map(|x| E::from(*x)).collect())
.collect();

PeriodicTable { table }
}
}

#[derive(Clone, Default)]
Expand Down Expand Up @@ -175,7 +198,7 @@ where

type PublicInputs = P;

fn get_oracles(&self) -> &[LogUpGkrOracle<Self::BaseField>] {
fn get_oracles(&self) -> &[LogUpGkrOracle] {
panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented")
}

Expand All @@ -191,7 +214,7 @@ where
panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented")
}

fn build_query<E>(&self, _frame: &EvaluationFrame<E>, _periodic_values: &[E], _query: &mut [E])
fn build_query<E>(&self, _frame: &EvaluationFrame<E>, _query: &mut [E])
where
E: FieldElement<BaseField = Self::BaseField>,
{
Expand All @@ -201,6 +224,7 @@ where
fn evaluate_query<F, E>(
&self,
_query: &[F],
_periodic_values: &[F],
_rand_values: &[E],
_numerator: &mut [E],
_denominator: &mut [E],
Expand All @@ -220,12 +244,62 @@ where
}

#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)]
pub enum LogUpGkrOracle<B: StarkField> {
pub enum LogUpGkrOracle {
/// A column with a given index in the main trace segment.
CurrentRow(usize),
/// A column with a given index in the main trace segment but shifted upwards.
NextRow(usize),
/// A virtual periodic column defined by its values in a given cycle. Note that the cycle length
/// must be a power of 2.
PeriodicValue(Vec<B>),
}

// PERIODIC COLUMNS FOR LOGUP
// =================================================================================================

/// Stores the periodic columns used in a LogUp-GKR statement.
///
/// Each stored periodic column is interpreted as a multi-linear extension polynomial of the column
/// with the given periodic values. Due to the periodic nature of the values, storing, binding of
/// an argument and evaluating the said multi-linear extension can be all done linearly in the size
/// of the smallest cycle defining the periodic values. Hence we only store the values of this
/// smallest cycle. The cycle is assumed throughout to be a power of 2.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
pub struct PeriodicTable<E: FieldElement> {
pub table: Vec<Vec<E>>,
}

impl<E> PeriodicTable<E>
where
E: FieldElement,
{
pub fn new(table: Vec<Vec<E::BaseField>>) -> Self {
let table = table.iter().map(|col| col.iter().map(|x| E::from(*x)).collect()).collect();

Self { table }
}

pub fn num_columns(&self) -> usize {
self.table.len()
}

pub fn table(&self) -> &[Vec<E>] {
&self.table
}

pub fn fill_periodic_values_at(&self, row: usize, values: &mut [E]) {
self.table
.iter()
.zip(values.iter_mut())
.for_each(|(col, value)| *value = col[row % col.len()])
}

pub fn bind_least_significant_variable(&mut self, round_challenge: E) {
for col in self.table.iter_mut() {
if col.len() > 1 {
let num_evals = col.len() >> 1;
for i in 0..num_evals {
col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]);
}
col.truncate(num_evals)
}
}
}
}
2 changes: 1 addition & 1 deletion air/src/air/logup_gkr/s_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<E: FieldElement> SColumnConstraint<E> {
.mul_base(E::BaseField::ONE / E::BaseField::from(air.trace_length() as u32));

let mut query = vec![E::ZERO; air.get_logup_gkr_evaluator().get_oracles().len()];
air.get_logup_gkr_evaluator().build_query(main_trace_frame, &[], &mut query);
air.get_logup_gkr_evaluator().build_query(main_trace_frame, &mut query);
let batched_claim_at_query = self.gkr_data.compute_batched_query::<E>(&query);
let rhs = s_cur - mean + batched_claim_at_query * l_cur;
let lhs = s_nxt;
Expand Down
2 changes: 1 addition & 1 deletion air/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use logup_gkr::PhantomLogUpGkrEval;
pub use logup_gkr::{
LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame,
LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator,
LogUpGkrOracle,
LogUpGkrOracle, PeriodicTable,
};

mod coefficients;
Expand Down
4 changes: 2 additions & 2 deletions air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ pub use air::{
DeepCompositionCoefficients, EvaluationFrame, GkrData,
LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint,
LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo,
TransitionConstraintDegree, TransitionConstraints,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable,
TraceInfo, TransitionConstraintDegree, TransitionConstraints,
};
2 changes: 1 addition & 1 deletion prover/src/constraints/evaluator/logup_gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ where
let s_cur = aux_frame.current()[s_col_idx];
let s_nxt = aux_frame.next()[s_col_idx];

evaluator.build_query(&main_frame, &[], &mut query);
evaluator.build_query(&main_frame, &mut query);
let batched_query = self.gkr_data.compute_batched_query(&query);

let rhs = s_cur - mean + batched_query * l_cur;
Expand Down
10 changes: 7 additions & 3 deletions prover/src/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,21 +109,25 @@ impl<E: FieldElement> EvaluatedCircuit<E> {
log_up_randomness: &[E],
) -> CircuitLayer<E> {
let num_fractions = evaluator.get_num_fractions();
let periodic_values = evaluator.build_periodic_values();

let mut input_layer_wires =
Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions);
let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols());

let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()];
let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()];
let mut numerators = vec![E::ZERO; num_fractions];
let mut denominators = vec![E::ZERO; num_fractions];
for i in 0..main_trace.main_segment().num_rows() {
let wires_from_trace_row = {
main_trace.read_main_frame(i, &mut main_frame);

evaluator.build_query(&main_frame, &[], &mut query);
periodic_values.fill_periodic_values_at(i, &mut periodic_values_row);
evaluator.build_query(&main_frame, &mut query);

evaluator.evaluate_query(
&query,
&periodic_values_row,
log_up_randomness,
&mut numerators,
&mut denominators,
Expand Down Expand Up @@ -379,7 +383,7 @@ pub fn build_s_column<E: FieldElement>(
for (i, item) in lagrange_kernel_col.iter().enumerate().take(main_segment.num_rows() - 1) {
main_trace.read_main_frame(i, &mut main_frame);

evaluator.build_query(&main_frame, &[], &mut query);
evaluator.build_query(&main_frame, &mut query);
let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item;

result.push(cur_value);
Expand Down
20 changes: 15 additions & 5 deletions prover/src/logup_gkr/prover.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::vec::Vec;

use air::{LogUpGkrEvaluator, LogUpGkrOracle};
use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};
use crypto::{ElementHasher, RandomCoin};
use math::FieldElement;
use sumcheck::{
Expand Down Expand Up @@ -77,9 +77,18 @@ pub fn prove_gkr<E: FieldElement>(
// build the MLEs of the relevant main trace columns
let main_trace_mls =
build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?;
// build the periodic table representing periodic columns as multi-linear extensions
let periodic_table = evaluator.build_periodic_values();

let final_layer_proof =
prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?;
// run the GKR prover for the input layer
let final_layer_proof = prove_input_layer(
evaluator,
logup_randomness,
main_trace_mls,
periodic_table,
gkr_claim,
public_coin,
)?;

Ok(GkrCircuitProof {
circuit_outputs: CircuitOutput { numerators, denominators },
Expand All @@ -97,6 +106,7 @@ fn prove_input_layer<
evaluator: &impl LogUpGkrEvaluator<BaseField = E::BaseField>,
log_up_randomness: Vec<E>,
multi_linear_ext_polys: Vec<MultiLinearPoly<E>>,
periodic_table: PeriodicTable<E>,
claim: GkrClaim<E>,
transcript: &mut C,
) -> Result<FinalLayerProof<E>, GkrProverError> {
Expand All @@ -114,6 +124,7 @@ fn prove_input_layer<
r_batch,
log_up_randomness,
multi_linear_ext_polys,
periodic_table,
transcript,
)?;

Expand All @@ -123,7 +134,7 @@ fn prove_input_layer<
/// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for
/// LogUp-GKR.
fn build_mls_from_main_trace_segment<E: FieldElement>(
oracles: &[LogUpGkrOracle<E::BaseField>],
oracles: &[LogUpGkrOracle],
main_trace: &ColMatrix<<E as FieldElement>::BaseField>,
) -> Result<Vec<MultiLinearPoly<E>>, GkrProverError> {
let mut mls = vec![];
Expand All @@ -146,7 +157,6 @@ fn build_mls_from_main_trace_segment<E: FieldElement>(
let ml = MultiLinearPoly::from_evaluations(values);
mls.push(ml)
},
LogUpGkrOracle::PeriodicValue(_) => unimplemented!(),
};
}
Ok(mls)
Expand Down
2 changes: 1 addition & 1 deletion prover/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ pub trait Trace: Sized {
let s_cur = aux_frame.current()[s_col_idx];
let s_nxt = aux_frame.next()[s_col_idx];

evaluator.build_query(&main_frame, &[], &mut query);
evaluator.build_query(&main_frame, &mut query);
let batched_query = gkr_data.compute_batched_query(&query);

let rhs = s_cur - mean + batched_query * l_cur;
Expand Down
Loading

0 comments on commit b5f64cc

Please sign in to comment.