From 62727146b0e53ab6fa9e41072c11c4086da50fb0 Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Tue, 3 Sep 2024 18:53:10 +0300 Subject: [PATCH] Revert logup without is_first --- .../prover/src/constraint_framework/logup.rs | 37 ++++++++----------- crates/prover/src/examples/blake/air.rs | 7 ++++ crates/prover/src/examples/blake/round/mod.rs | 5 ++- .../src/examples/blake/scheduler/mod.rs | 6 ++- .../src/examples/blake/xor_table/gen.rs | 7 +++- .../src/examples/blake/xor_table/mod.rs | 3 +- crates/prover/src/examples/plonk/mod.rs | 26 ++++++++----- crates/prover/src/examples/poseidon/mod.rs | 23 ++++++++++-- 8 files changed, 71 insertions(+), 43 deletions(-) diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index a608d89b0..89982a5c5 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -25,23 +25,25 @@ use crate::core::ColumnVec; pub struct LogupAtRow { /// The index of the interaction used for the cumulative sum columns. pub interaction: usize, - /// A constant to subtract from each row, to make the totall sum of the last column zero. - /// In other words, claimed_sum / 2^log_size. - /// This is used to make the constraint uniform. - pub cumsum_shift: SecureField, + /// The claimed sum of all the fractions. + pub claimed_sum: SecureField, /// The evaluation of the last cumulative sum column. pub prev_col_cumsum: E::EF, cur_frac: Option>, is_finalized: bool, + /// The value of the `is_first` constant column at current row. + /// See [`super::constant_columns::gen_is_first()`]. + pub is_first: E::F, } impl LogupAtRow { - pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self { + pub fn new(interaction: usize, claimed_sum: SecureField, is_first: E::F) -> Self { Self { interaction, - cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size), + claimed_sum, prev_col_cumsum: E::EF::zero(), cur_frac: None, is_finalized: false, + is_first, } } @@ -64,12 +66,11 @@ impl LogupAtRow { let [cur_cumsum, prev_row_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0, -1]); - let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum; - // Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift. - // This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint - // uniform - apply on all rows. - let fixed_diff = diff + self.cumsum_shift; - eval.add_constraint(fixed_diff * frac.denominator - frac.numerator); + // Fix `prev_row_cumsum` by subtracting `claimed_sum` if this is the first row. + let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.claimed_sum; + let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum; + + eval.add_constraint(diff * frac.denominator - frac.numerator); self.is_finalized = true; } @@ -163,21 +164,13 @@ impl LogupTraceGenerator { SecureField, ) { // Compute claimed sum. - let mut last_col_coords = self.trace.pop().unwrap().columns; + let last_col_coords = self.trace.pop().unwrap().columns; let packed_sums: [PackedBaseField; SECURE_EXTENSION_DEGREE] = last_col_coords .each_ref() .map(|c| c.data.iter().copied().sum()); let base_sums = packed_sums.map(|s| s.pointwise_sum()); let claimed_sum = SecureField::from_m31_array(base_sums); - // Shift the last column to make the sum zero. - let cumsum_shift = claimed_sum / BaseField::from_u32_unchecked(1 << self.log_size); - last_col_coords.iter_mut().enumerate().for_each(|(i, c)| { - c.data - .iter_mut() - .for_each(|x| *x -= PackedBaseField::broadcast(cumsum_shift.to_m31_array()[i])) - }); - // Prefix sum the last column. let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum); self.trace.push(SecureColumnByCoords { @@ -259,7 +252,7 @@ mod tests { #[test] #[should_panic] fn test_logup_not_finalized_panic() { - let mut logup = LogupAtRow::::new(1, SecureField::one(), 7); + let mut logup = LogupAtRow::::new(1, SecureField::one(), BaseField::one()); logup.write_frac( &mut InfoEvaluator::default(), Fraction::new(SecureField::one(), SecureField::one()), diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index ca583abe3..a91e02700 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -8,6 +8,7 @@ use tracing::{span, Level}; use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; use super::xor_table::{XorTableComponent, XorTableEval}; +use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::TraceLocationAllocator; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::m31::LOG_N_LANES; @@ -362,10 +363,16 @@ where span.exit(); // Constant trace. + // TODO(ShaharS): share is_first column between components when constant columns support this. let span = span!(Level::INFO, "Constant Trace").entered(); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals( chain![ + vec![gen_is_first(log_size)], + ROUND_LOG_SPLIT + .iter() + .map(|l| gen_is_first(log_size + l)) + .collect_vec(), xor_table::generate_constant_trace::<12, 4>(), xor_table::generate_constant_trace::<9, 2>(), xor_table::generate_constant_trace::<8, 2>(), diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index cf8311339..68156cde5 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -27,12 +27,13 @@ impl FrameworkEval for BlakeRoundEval { fn max_constraint_log_degree_bound(&self) -> u32 { self.log_size + 1 } - fn evaluate(&self, eval: E) -> E { + fn evaluate(&self, mut eval: E) -> E { + let [is_first] = eval.next_interaction_mask(2, [0]); let blake_eval = constraints::BlakeRoundEval { eval, xor_lookup_elements: &self.xor_lookup_elements, round_lookup_elements: &self.round_lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size), + logup: LogupAtRow::new(1, self.claimed_sum, is_first), }; blake_eval.eval() } diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index e8a8c32f3..e43c8d88d 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -29,11 +29,12 @@ impl FrameworkEval for BlakeSchedulerEval { self.log_size + 1 } fn evaluate(&self, mut eval: E) -> E { + let [is_first] = eval.next_interaction_mask(2, [0]); eval_blake_scheduler_constraints( &mut eval, &self.blake_lookup_elements, &self.round_lookup_elements, - LogupAtRow::new(1, self.claimed_sum, self.log_size), + LogupAtRow::new(1, self.claimed_sum, is_first), ); eval } @@ -55,6 +56,7 @@ mod tests { use itertools::Itertools; + use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::FrameworkEval; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::round::RoundElements; @@ -86,7 +88,7 @@ mod tests { &blake_lookup_elements, ); - let trace = TreeVec::new(vec![trace, interaction_trace]); + let trace = TreeVec::new(vec![trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]); let trace_polys = trace.map_cols(|c| c.interpolate()); let component = BlakeSchedulerEval { diff --git a/crates/prover/src/examples/blake/xor_table/gen.rs b/crates/prover/src/examples/blake/xor_table/gen.rs index 195a6ca46..bf6e43ad6 100644 --- a/crates/prover/src/examples/blake/xor_table/gen.rs +++ b/crates/prover/src/examples/blake/xor_table/gen.rs @@ -4,6 +4,7 @@ use itertools::Itertools; use tracing::{span, Level}; use super::{column_bits, limb_bits, XorAccumulator, XorElements}; +use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; @@ -157,12 +158,14 @@ pub fn generate_constant_trace( }) .collect(); - [a_col, b_col, c_col] + let mut constant_trace = [a_col, b_col, c_col] .map(|x| { CircleEvaluation::new( CanonicCoset::new(column_bits::()).circle_domain(), x, ) }) - .to_vec() + .to_vec(); + constant_trace.insert(0, gen_is_first(column_bits::())); + constant_trace } diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 877a65114..802d17282 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -103,10 +103,11 @@ impl FrameworkEval column_bits::() + 1 } fn evaluate(&self, mut eval: E) -> E { + let [is_first] = eval.next_interaction_mask(2, [0]); let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { eval, lookup_elements: &self.lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size()), + logup: LogupAtRow::new(1, self.claimed_sum, is_first), }; xor_eval.eval() } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index f2340e681..ae7a29b7c 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -1,7 +1,8 @@ -use itertools::{chain, Itertools}; +use itertools::Itertools; use num_traits::One; use tracing::{span, Level}; +use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; use crate::constraint_framework::{ assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, @@ -44,7 +45,8 @@ impl FrameworkEval for PlonkEval { } fn evaluate(&self, mut eval: E) -> E { - let mut logup = LogupAtRow::<_>::new(1, self.claimed_sum, self.log_n_rows); + let [is_first] = eval.next_interaction_mask(2, [0]); + let mut logup = LogupAtRow::<_>::new(1, self.claimed_sum, is_first); let [a_wire] = eval.next_interaction_mask(2, [0]); let [b_wire] = eval.next_interaction_mask(2, [0]); @@ -204,14 +206,18 @@ pub fn prove_fibonacci_plonk( // Constant trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - let constants_trace_location = tree_builder.extend_evals(chain!([ - circuit.a_wire, - circuit.b_wire, - circuit.c_wire, - circuit.op - ] - .into_iter() - .map(|col| CircleEvaluation::new(CanonicCoset::new(log_n_rows).circle_domain(), col)))); + let is_first = gen_is_first(log_n_rows); + let mut constant_trace = [circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op] + .into_iter() + .map(|col| { + CircleEvaluation::::new( + CanonicCoset::new(log_n_rows).circle_domain(), + col, + ) + }) + .collect_vec(); + constant_trace.insert(0, is_first); + let constants_trace_location = tree_builder.extend_evals(constant_trace); tree_builder.commit(channel); span.exit(); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index d25cc1865..33072ec91 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -5,6 +5,7 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; use tracing::{span, Level}; +use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; use crate::constraint_framework::{ EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, @@ -59,7 +60,8 @@ impl FrameworkEval for PoseidonEval { self.log_n_rows + LOG_EXPAND } fn evaluate(&self, mut eval: E) -> E { - let logup = LogupAtRow::new(1, self.claimed_sum, self.log_n_rows); + let [is_first] = eval.next_interaction_mask(2, [0]); + let logup = LogupAtRow::new(1, self.claimed_sum, is_first); eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements); eval } @@ -366,6 +368,14 @@ pub fn prove_poseidon( tree_builder.commit(channel); span.exit(); + // Constant trace. + let span = span!(Level::INFO, "Constant").entered(); + let mut tree_builder = commitment_scheme.tree_builder(); + let constant_trace = vec![gen_is_first(log_n_rows)]; + tree_builder.extend_evals(constant_trace); + tree_builder.commit(channel); + span.exit(); + // Prove constraints. let component = PoseidonComponent::new( &mut TraceLocationAllocator::default(), @@ -387,8 +397,9 @@ mod tests { use itertools::Itertools; use num_traits::One; - use crate::constraint_framework::assert_constraints; + use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; + use crate::constraint_framework::{assert_constraints, EvalAtRow}; use crate::core::air::Component; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; @@ -463,13 +474,14 @@ mod tests { let (trace1, claimed_sum) = gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements); - let traces = TreeVec::new(vec![trace0, trace1]); + let traces = TreeVec::new(vec![trace0, trace1, vec![gen_is_first(LOG_N_ROWS)]]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { + let [is_first] = eval.next_interaction_mask(2, [0]); eval_poseidon_constraints( &mut eval, - LogupAtRow::new(1, claimed_sum, LOG_N_ROWS), + LogupAtRow::new(1, claimed_sum, is_first), &lookup_elements, ); }); @@ -512,6 +524,9 @@ mod tests { // Interaction columns. commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + // Constant columns. + commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); + verify(&[&component], channel, commitment_scheme, proof).unwrap(); } }