From 92540c0093f6ded8d47a489ba27f0fba43327309 Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Tue, 10 Sep 2024 16:51:37 +0300 Subject: [PATCH] Add constraint for Logup claimed cumsum --- .../prover/src/constraint_framework/logup.rs | 40 +++++++++++++++++-- crates/prover/src/examples/blake/round/mod.rs | 2 +- .../src/examples/blake/scheduler/mod.rs | 2 +- .../src/examples/blake/xor_table/mod.rs | 2 +- crates/prover/src/examples/plonk/mod.rs | 17 +++++--- crates/prover/src/examples/poseidon/mod.rs | 4 +- 6 files changed, 53 insertions(+), 14 deletions(-) diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 21ce14b4a..03f55bda5 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -21,6 +21,10 @@ use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use crate::core::ColumnVec; +/// Represents the value of the prefix sum column at some index. +/// Should be used to eliminate padded rows for the logup sum. +pub type ClaimedPrefixSum = (SecureField, usize); + /// Evaluates constraints for batched logups. /// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. pub struct LogupAtRow { @@ -28,6 +32,10 @@ pub struct LogupAtRow { pub interaction: usize, /// The total sum of all the fractions. pub total_sum: SecureField, + /// The claimed sum of the relevant fractions. + /// This is used for padding the component with default rows. Padding should be in bit-reverse. + /// None if the claimed_sum is the total_sum. + pub claimed_sum: Option, /// The evaluation of the last cumulative sum column. pub prev_col_cumsum: E::EF, cur_frac: Option>, @@ -37,10 +45,16 @@ pub struct LogupAtRow { pub is_first: E::F, } impl LogupAtRow { - pub fn new(interaction: usize, total_sum: SecureField, is_first: E::F) -> Self { + pub fn new( + interaction: usize, + total_sum: SecureField, + claimed_sum: Option, + is_first: E::F, + ) -> Self { Self { interaction, total_sum, + claimed_sum, prev_col_cumsum: E::EF::zero(), cur_frac: None, is_finalized: false, @@ -64,9 +78,26 @@ impl LogupAtRow { let frac = self.cur_frac.unwrap(); - let [cur_cumsum, prev_row_cumsum] = - eval.next_extension_interaction_mask(self.interaction, [0, -1]); + // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset + // from the is_first column when constant columns are supported. + let (cur_cumsum, prev_row_cumsum) = match self.claimed_sum { + Some((claimed_sum, claimed_row_index)) => { + let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = eval + .next_extension_interaction_mask( + self.interaction, + [0, -1, claimed_row_index as isize], + ); + // Constrain that the claimed_sum in case that it is not equal to the total_sum. + eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first); + (cur_cumsum, prev_row_cumsum) + } + None => { + let [cur_cumsum, prev_row_cumsum] = + eval.next_extension_interaction_mask(self.interaction, [0, -1]); + (cur_cumsum, prev_row_cumsum) + } + }; // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.total_sum; let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum; @@ -277,7 +308,8 @@ mod tests { #[test] #[should_panic] fn test_logup_not_finalized_panic() { - let mut logup = LogupAtRow::::new(1, SecureField::one(), BaseField::one()); + let mut logup = + LogupAtRow::::new(1, SecureField::one(), None, BaseField::one()); logup.write_frac( &mut InfoEvaluator::default(), Fraction::new(SecureField::one(), SecureField::one()), diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index 807d85525..ec35987a4 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -33,7 +33,7 @@ impl FrameworkEval for BlakeRoundEval { eval, xor_lookup_elements: &self.xor_lookup_elements, round_lookup_elements: &self.round_lookup_elements, - logup: LogupAtRow::new(1, self.total_sum, is_first), + logup: LogupAtRow::new(1, self.total_sum, None, is_first), }; blake_eval.eval() } diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index e795ebfd8..8d389f09e 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -34,7 +34,7 @@ impl FrameworkEval for BlakeSchedulerEval { &mut eval, &self.blake_lookup_elements, &self.round_lookup_elements, - LogupAtRow::new(1, self.total_sum, is_first), + LogupAtRow::new(1, self.total_sum, None, is_first), ); eval } diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 802d17282..7be8f9156 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -107,7 +107,7 @@ impl FrameworkEval let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { eval, lookup_elements: &self.lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, is_first), + logup: LogupAtRow::new(1, self.claimed_sum, None, is_first), }; xor_eval.eval() } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 8117d7f3f..6d55cd484 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -3,7 +3,9 @@ use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::constant_columns::gen_is_first; -use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; +use crate::constraint_framework::logup::{ + ClaimedPrefixSum, LogupAtRow, LogupTraceGenerator, LookupElements, +}; use crate::constraint_framework::{ assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, }; @@ -29,6 +31,7 @@ pub type PlonkComponent = FrameworkComponent; pub struct PlonkEval { pub log_n_rows: u32, pub lookup_elements: LookupElements<2>, + pub claimed_sum: ClaimedPrefixSum, pub total_sum: SecureField, pub base_trace_location: TreeSubspan, pub interaction_trace_location: TreeSubspan, @@ -46,7 +49,7 @@ impl FrameworkEval for PlonkEval { fn evaluate(&self, mut eval: E) -> E { let [is_first] = eval.next_interaction_mask(2, [0]); - let mut logup = LogupAtRow::<_>::new(1, self.total_sum, is_first); + let mut logup = LogupAtRow::<_>::new(1, self.total_sum, Some(self.claimed_sum), is_first); let [a_wire] = eval.next_interaction_mask(2, [0]); let [b_wire] = eval.next_interaction_mask(2, [0]); @@ -113,11 +116,12 @@ pub fn gen_trace( pub fn gen_interaction_trace( log_size: u32, + padding_offset: usize, circuit: &PlonkCircuitTrace, lookup_elements: &LookupElements<2>, ) -> ( ColumnVec>, - SecureField, + [SecureField; 2], ) { let _span = span!(Level::INFO, "Generate interaction trace").entered(); let mut logup_gen = LogupTraceGenerator::new(log_size); @@ -141,7 +145,7 @@ pub fn gen_interaction_trace( } col_gen.finalize_col(); - logup_gen.finalize_last() + logup_gen.finalize_at([(1 << log_size) - 1, padding_offset]) } #[allow(unused)] @@ -156,6 +160,7 @@ pub fn prove_fibonacci_plonk( for _ in 0..(1 << log_n_rows) { fib_values.push(fib_values[fib_values.len() - 1] + fib_values[fib_values.len() - 2]); } + let padding_offset = 17; let range = 0..(1 << log_n_rows); let mut circuit = PlonkCircuitTrace { mult: range.clone().map(|_| 2.into()).collect(), @@ -197,7 +202,8 @@ pub fn prove_fibonacci_plonk( // Interaction trace. let span = span!(Level::INFO, "Interaction").entered(); - let (trace, total_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements); + let (trace, [total_sum, claimed_sum]) = + gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements); let mut tree_builder = commitment_scheme.tree_builder(); let interaction_trace_location = tree_builder.extend_evals(trace); tree_builder.commit(channel); @@ -227,6 +233,7 @@ pub fn prove_fibonacci_plonk( PlonkEval { log_n_rows, lookup_elements, + claimed_sum: (claimed_sum, padding_offset), total_sum, base_trace_location, interaction_trace_location, diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 125a94fb6..46c99dd92 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -61,7 +61,7 @@ impl FrameworkEval for PoseidonEval { } fn evaluate(&self, mut eval: E) -> E { let [is_first] = eval.next_interaction_mask(2, [0]); - let logup = LogupAtRow::new(1, self.total_sum, is_first); + let logup = LogupAtRow::new(1, self.total_sum, None, is_first); eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements); eval } @@ -482,7 +482,7 @@ mod tests { let [is_first] = eval.next_interaction_mask(2, [0]); eval_poseidon_constraints( &mut eval, - LogupAtRow::new(1, total_sum, is_first), + LogupAtRow::new(1, total_sum, None, is_first), &lookup_elements, ); });