From 820a2efca805869eda1a16ef3fbf24af86cfd67e Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Wed, 18 Sep 2024 15:09:57 +0300 Subject: [PATCH] State machine AIR --- crates/prover/src/examples/mod.rs | 1 + .../src/examples/state_machine/components.rs | 49 +++++ .../prover/src/examples/state_machine/gen.rs | 171 ++++++++++++++++++ .../prover/src/examples/state_machine/mod.rs | 162 +++++++++++++++++ 4 files changed, 383 insertions(+) create mode 100644 crates/prover/src/examples/state_machine/components.rs create mode 100644 crates/prover/src/examples/state_machine/gen.rs create mode 100644 crates/prover/src/examples/state_machine/mod.rs diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index 330662de9..4a3511b51 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,5 +1,6 @@ pub mod blake; pub mod plonk; pub mod poseidon; +pub mod state_machine; pub mod wide_fibonacci; pub mod xor; diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs new file mode 100644 index 000000000..0a10ec4c7 --- /dev/null +++ b/crates/prover/src/examples/state_machine/components.rs @@ -0,0 +1,49 @@ +use num_traits::One; + +use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; +use crate::constraint_framework::{EvalAtRow, FrameworkEval}; +use crate::core::fields::qm31::QM31; +use crate::core::lookups::utils::Fraction; + +const LOG_CONSTRAINT_DEGREE: u32 = 1; +pub const STATE_SIZE: usize = 2; +/// Random elements to combine the StateMachine state. +pub type StateMachineElements = LookupElements; + +/// State machine with state of size `STATE_SIZE`. +/// Transition `COORDINATE` of state increments the state by 1 at that offset. +#[derive(Clone)] +pub struct StateTransitionEval { + pub log_n_rows: u32, + pub lookup_elements: StateMachineElements, + pub total_sum: QM31, +} + +impl FrameworkEval for StateTransitionEval { + fn log_size(&self) -> u32 { + self.log_n_rows + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_n_rows + LOG_CONSTRAINT_DEGREE + } + fn evaluate(&self, mut eval: E) -> E { + let [is_first] = eval.next_interaction_mask(2, [0]); + let mut logup: LogupAtRow = LogupAtRow::new(1, self.total_sum, None, is_first); + + let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask()); + let input_denom: E::EF = self.lookup_elements.combine(&input_state); + + let mut output_state = input_state; + output_state[COORDINATE] += E::F::one(); + let output_denom: E::EF = self.lookup_elements.combine(&output_state); + + // Add to the total sum (1/input_denom - 1/output_denom). + logup.write_frac( + &mut eval, + Fraction::new(output_denom - input_denom, output_denom * input_denom), + ); + + logup.finalize(&mut eval); + eval + } +} diff --git a/crates/prover/src/examples/state_machine/gen.rs b/crates/prover/src/examples/state_machine/gen.rs new file mode 100644 index 000000000..c4220dedd --- /dev/null +++ b/crates/prover/src/examples/state_machine/gen.rs @@ -0,0 +1,171 @@ +use itertools::Itertools; +use num_traits::One; + +use super::components::STATE_SIZE; +use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use crate::core::backend::simd::qm31::PackedQM31; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::M31; +use crate::core::fields::qm31::QM31; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; + +pub type State = [M31; STATE_SIZE]; + +// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the +// `inc_index` dimension. +// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size]. +pub fn gen_trace( + log_size: u32, + initial_state: State, + inc_index: usize, +) -> ColumnVec> { + let n_lanes = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32)); + let domain = CanonicCoset::new(log_size).circle_domain(); + + // Prepare the state for the first packed row. + let mut packed_state = initial_state.map(PackedM31::broadcast); + let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32))); + packed_state[inc_index] += inc; + + let mut trace = (0..STATE_SIZE) + .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) + .collect_vec(); + for i in 0..(1 << (log_size - LOG_N_LANES)) { + for j in 0..STATE_SIZE { + trace[j].data[i] = packed_state[j]; + } + // Increment the state to the next packed row. + packed_state[inc_index] += n_lanes; + } + trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec() +} + +pub fn gen_interaction_trace( + log_size: u32, + initial_state: [M31; STATE_SIZE], + inc_index: usize, + lookup_elements: &LookupElements, +) -> ( + ColumnVec>, + QM31, +) { + let ones = PackedM31::broadcast(M31::one()); + let n_lanes_minus_one = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32)) - ones; + + // Prepare the state. + let mut packed_state = initial_state.map(PackedM31::broadcast); + let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32))); + packed_state[inc_index] += inc; + + let mut logup_gen = LogupTraceGenerator::new(log_size); + let mut col_gen = logup_gen.new_col(); + + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let input_denom: PackedQM31 = lookup_elements.combine(&packed_state); + packed_state[inc_index] += ones; + let output_denom: PackedQM31 = lookup_elements.combine(&packed_state); + packed_state[inc_index] += n_lanes_minus_one; + col_gen.write_frac( + vec_row, + output_denom - input_denom, + input_denom * output_denom, + ); + } + col_gen.finalize_col(); + + let (trace, [total_sum]) = logup_gen.finalize([(1 << log_size) - 1]); + (trace, total_sum) +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use num_traits::One; + + use crate::core::backend::Column; + use crate::core::channel::Blake2sChannel; + use crate::core::fields::m31::M31; + use crate::core::fields::qm31::QM31; + use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; + use crate::core::fields::FieldExpOps; + use crate::examples::state_machine::components::StateMachineElements; + use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace}; + + #[test] + fn test_gen_trace() { + let log_size = 8; + let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)]; + let inc_index = 1; + let row = 123; + + let trace = gen_trace(log_size, initial_state, inc_index); + + assert_eq!(trace.len(), 2); + assert_eq!(trace[0].at(row), initial_state[0]); + assert_eq!( + trace[1].at(row), + initial_state[1] + M31::from_u32_unchecked(row as u32) + ); + } + + #[test] + fn test_gen_interaction_trace() { + let log_size = 8; + let inc_index = 1; + // Prepare state and next state. + let state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(12)]; + let mut next_state = state; + next_state[inc_index] += M31::one(); + + let lookup_elements = StateMachineElements::dummy(); + let comb_state: QM31 = lookup_elements.combine(&state); + let comb_next_state: QM31 = lookup_elements.combine(&next_state); + + let (trace, _) = gen_interaction_trace(log_size, state, inc_index, &lookup_elements); + let first_log_up_row = QM31::from_m31_array( + trace + .iter() + .map(|col| col.at(0)) + .collect_vec() + .try_into() + .unwrap(), + ); + + assert_eq!(trace.len(), SECURE_EXTENSION_DEGREE); // One quadradic extension column. + assert_eq!( + first_log_up_row, + comb_state.inverse() - comb_next_state.inverse() + ); + } + + #[test] + fn test_state_machine_total_sum() { + let log_n_rows = 8; + let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default()); + let inc_index = 0; + + let initial_state = [M31::from(123), M31::from(456)]; + let initial_state_comb: QM31 = lookup_elements.combine(&initial_state); + + let mut last_state = initial_state; + last_state[inc_index] += M31::from_u32_unchecked(1 << log_n_rows); + let last_state_comb: QM31 = lookup_elements.combine(&last_state); + + let (_, total_sum) = + gen_interaction_trace(log_n_rows, initial_state, inc_index, &lookup_elements); + + // Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`. + assert_eq!( + total_sum * initial_state_comb * last_state_comb, + last_state_comb - initial_state_comb + ); + } +} diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs new file mode 100644 index 000000000..4758592de --- /dev/null +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -0,0 +1,162 @@ +pub mod components; +pub mod gen; + +use components::{StateMachineElements, StateTransitionEval}; +use gen::{gen_interaction_trace, gen_trace, State}; +use itertools::Itertools; + +use crate::constraint_framework::constant_columns::gen_is_first; +use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator}; +use crate::core::air::Component; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::channel::Blake2sChannel; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; +use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps}; +use crate::core::prover::{prove, verify, StarkProof, VerificationError}; +use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; + +pub type StateMachineOp0Component = FrameworkComponent>; + +#[allow(unused)] +pub fn prove_state_machine( + log_n_rows: u32, + initial_state: State, + config: PcsConfig, + channel: &mut Blake2sChannel, +) -> ( + StateMachineOp0Component, + StarkProof, + TreeVec>>, +) { + assert!(log_n_rows >= LOG_N_LANES); + + // Precompute twiddles. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let commitment_scheme = + &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + + // Trace. + let trace_op0 = gen_trace(log_n_rows, initial_state, 0); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace_op0); + tree_builder.commit(channel); + + // Draw lookup element. + let lookup_elements = StateMachineElements::draw(channel); + + // Interaction trace. + let (interaction_trace_op0, total_sum_op0) = + gen_interaction_trace(log_n_rows, initial_state, 0, &lookup_elements); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(interaction_trace_op0); + tree_builder.commit(channel); + + // Constant trace. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(vec![gen_is_first(log_n_rows)]); + tree_builder.commit(channel); + + let trace_polys = commitment_scheme + .trees + .as_ref() + .map(|t| t.polynomials.iter().cloned().collect_vec()); + + // Prove constraints. + let component_op0 = StateMachineOp0Component::new( + &mut TraceLocationAllocator::default(), + StateTransitionEval { + log_n_rows, + lookup_elements, + total_sum: total_sum_op0, + }, + ); + + let proof = prove(&[&component_op0], channel, commitment_scheme).unwrap(); + + (component_op0, proof, trace_polys) +} + +pub fn verify_state_machine( + config: PcsConfig, + channel: &mut Blake2sChannel, + component: StateMachineOp0Component, + proof: StarkProof, +) -> Result<(), VerificationError> { + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + // Decommit. + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = component.trace_log_degree_bounds(); + // Trace columns. + commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + // Interaction columns. + commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + // Constant columns. + commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); + + verify(&[&component], channel, commitment_scheme, proof) +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + + use super::components::STATE_SIZE; + use super::{prove_state_machine, verify_state_machine}; + use crate::constraint_framework::{assert_constraints, FrameworkEval}; + use crate::core::channel::Blake2sChannel; + use crate::core::fields::m31::M31; + use crate::core::fields::qm31::QM31; + use crate::core::pcs::PcsConfig; + use crate::core::poly::circle::CanonicCoset; + + #[test] + fn test_state_machine_constraints() { + let log_n_rows = 8; + let config = PcsConfig::default(); + + // Initial and last state. + let initial_state = [M31::zero(); STATE_SIZE]; + let last_state = [M31::from_u32_unchecked(1 << log_n_rows), M31::zero()]; + + // Setup protocol. + let channel = &mut Blake2sChannel::default(); + let (component, _, trace_polys) = + prove_state_machine(log_n_rows, initial_state, config, channel); + + let interaction_elements = component.lookup_elements.clone(); + let initial_state_comb: QM31 = interaction_elements.combine(&initial_state); + let last_state_comb: QM31 = interaction_elements.combine(&last_state); + + // Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`. + assert_eq!( + component.total_sum * initial_state_comb * last_state_comb, + last_state_comb - initial_state_comb + ); + + // Assert constraints. + assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| { + component.evaluate(eval); + }); + } + + #[test] + fn test_state_machine_prove() { + let log_n_rows = 8; + let config = PcsConfig::default(); + let initial_state = [M31::zero(); STATE_SIZE]; + let prover_channel = &mut Blake2sChannel::default(); + let (component_op0, proof, _) = + prove_state_machine(log_n_rows, initial_state, config, prover_channel); + + let verifier_channel = &mut Blake2sChannel::default(); + verify_state_machine(config, verifier_channel, component_op0, proof).unwrap(); + } +}