Skip to content

Commit

Permalink
State machine AIR
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Sep 24, 2024
1 parent 257a05a commit 6911522
Show file tree
Hide file tree
Showing 4 changed files with 382 additions and 0 deletions.
1 change: 1 addition & 0 deletions crates/prover/src/examples/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod blake;
pub mod plonk;
pub mod poseidon;
pub mod state_machine;
pub mod wide_fibonacci;
pub mod xor;
49 changes: 49 additions & 0 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use num_traits::One;

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkEval};
use crate::core::fields::qm31::QM31;
use crate::core::lookups::utils::Fraction;

const LOG_CONSTRAINT_DEGREE: u32 = 1;
pub const STATE_SIZE: usize = 2;
/// Random elements to combine the StateMachine state.
pub type StateMachineElements = LookupElements<STATE_SIZE>;

/// State machine with state of size `STATE_SIZE`.
/// Transition `COORDINATE` of state increments the state by 1 at that offset.
#[derive(Clone)]
pub struct StateTransitionEval<const COORDINATE: usize> {
pub log_n_rows: u32,
pub lookup_elements: StateMachineElements,
pub total_sum: QM31,
}

impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> {
fn log_size(&self) -> u32 {
self.log_n_rows
}
fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_n_rows + LOG_CONSTRAINT_DEGREE
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup: LogupAtRow<E> = LogupAtRow::new(1, self.total_sum, None, is_first);

let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask());
let input_denom: E::EF = self.lookup_elements.combine(&input_state);

let mut output_state = input_state;
output_state[COORDINATE] += E::F::one();
let output_denom: E::EF = self.lookup_elements.combine(&output_state);

// Add to the total sum (1/input_denom - 1/output_denom).
logup.write_frac(
&mut eval,
Fraction::new(output_denom - input_denom, output_denom * input_denom),
);

logup.finalize(&mut eval);
eval
}
}
170 changes: 170 additions & 0 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
use itertools::Itertools;
use num_traits::One;

use super::components::STATE_SIZE;
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedQM31;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

pub type State = [M31; STATE_SIZE];

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
// `inc_index` dimension.
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size].
pub fn gen_trace(
log_size: u32,
initial_state: State,
inc_index: usize,
) -> ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>> {
let n_lanes = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32));
let domain = CanonicCoset::new(log_size).circle_domain();

// Prepare the state for the first packed row.
let mut packed_state = initial_state.map(PackedM31::broadcast);
let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32)));
packed_state[inc_index] += inc;

let mut trace = (0..STATE_SIZE)
.map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) })
.collect_vec();
for i in 0..(1 << (log_size - LOG_N_LANES)) {
for j in 0..STATE_SIZE {
trace[j].data[i] = packed_state[j];
}
// Increment the state to the next packed row.
packed_state[inc_index] += n_lanes;
}
trace
.into_iter()
.map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval))
.collect_vec()
}

pub fn gen_interaction_trace(
log_size: u32,
initial_state: [M31; STATE_SIZE],
inc_index: usize,
lookup_elements: &LookupElements<STATE_SIZE>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
) {
let ones = PackedM31::broadcast(M31::one());
let n_lanes_minus_one = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32)) - ones;

// Prepare the state.
let mut packed_state = initial_state.map(PackedM31::broadcast);
let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32)));
packed_state[inc_index] += inc;

let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();

for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let input_denom: PackedQM31 = lookup_elements.combine(&packed_state);
packed_state[inc_index] += ones;
let output_denom: PackedQM31 = lookup_elements.combine(&packed_state);
packed_state[inc_index] += n_lanes_minus_one;
col_gen.write_frac(
vec_row,
output_denom - input_denom,
input_denom * output_denom,
);
}
col_gen.finalize_col();

logup_gen.finalize_last()
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::One;

use crate::core::backend::Column;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::examples::state_machine::components::StateMachineElements;
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace};

#[test]
fn test_gen_trace() {
let log_size = 8;
let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)];
let inc_index = 1;
let row = 123;

let trace = gen_trace(log_size, initial_state, inc_index);

assert_eq!(trace.len(), 2);
assert_eq!(trace[0].at(row), initial_state[0]);
assert_eq!(
trace[1].at(row),
initial_state[1] + M31::from_u32_unchecked(row as u32)
);
}

#[test]
fn test_gen_interaction_trace() {
let log_size = 8;
let inc_index = 1;
// Prepare state and next state.
let state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(12)];
let mut next_state = state;
next_state[inc_index] += M31::one();

let lookup_elements = StateMachineElements::dummy();
let comb_state: QM31 = lookup_elements.combine(&state);
let comb_next_state: QM31 = lookup_elements.combine(&next_state);

let (trace, _) = gen_interaction_trace(log_size, state, inc_index, &lookup_elements);
let first_log_up_row = QM31::from_m31_array(
trace
.iter()
.map(|col| col.at(0))
.collect_vec()
.try_into()
.unwrap(),
);

assert_eq!(trace.len(), SECURE_EXTENSION_DEGREE); // One quadradic extension column.
assert_eq!(
first_log_up_row,
comb_state.inverse() - comb_next_state.inverse()
);
}

#[test]
fn test_state_machine_total_sum() {
let log_n_rows = 8;
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());
let inc_index = 0;

let initial_state = [M31::from(123), M31::from(456)];
let initial_state_comb: QM31 = lookup_elements.combine(&initial_state);

let mut last_state = initial_state;
last_state[inc_index] += M31::from_u32_unchecked(1 << log_n_rows);
let last_state_comb: QM31 = lookup_elements.combine(&last_state);

let (_, total_sum) =
gen_interaction_trace(log_n_rows, initial_state, inc_index, &lookup_elements);

// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`.
assert_eq!(
total_sum * initial_state_comb * last_state_comb,
last_state_comb - initial_state_comb
);
}
}
162 changes: 162 additions & 0 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
pub mod components;
pub mod gen;

use components::{StateMachineElements, StateTransitionEval};
use gen::{gen_interaction_trace, gen_trace, State};
use itertools::Itertools;

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator};
use crate::core::air::Component;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Blake2sChannel;
use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps};
use crate::core::prover::{prove, verify, StarkProof, VerificationError};
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};

pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>;

#[allow(unused)]
pub fn prove_state_machine(
log_n_rows: u32,
initial_state: State,
config: PcsConfig,
channel: &mut Blake2sChannel,
) -> (
StateMachineOp0Component,
StarkProof<Blake2sMerkleHasher>,
TreeVec<Vec<CirclePoly<SimdBackend>>>,
) {
assert!(log_n_rows >= LOG_N_LANES);

// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1)
.circle_domain()
.half_coset,
);

// Setup protocol.
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let trace_op0 = gen_trace(log_n_rows, initial_state, 0);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace_op0);
tree_builder.commit(channel);

// Draw lookup element.
let lookup_elements = StateMachineElements::draw(channel);

// Interaction trace.
let (interaction_trace_op0, total_sum_op0) =
gen_interaction_trace(log_n_rows, initial_state, 0, &lookup_elements);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(interaction_trace_op0);
tree_builder.commit(channel);

// Constant trace.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(vec![gen_is_first(log_n_rows)]);
tree_builder.commit(channel);

let trace_polys = commitment_scheme
.trees
.as_ref()
.map(|t| t.polynomials.iter().cloned().collect_vec());

// Prove constraints.
let component_op0 = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum: total_sum_op0,
},
);

let proof = prove(&[&component_op0], channel, commitment_scheme).unwrap();

(component_op0, proof, trace_polys)
}

pub fn verify_state_machine(
config: PcsConfig,
channel: &mut Blake2sChannel,
component: StateMachineOp0Component,
proof: StarkProof<Blake2sMerkleHasher>,
) -> Result<(), VerificationError> {
let commitment_scheme = &mut CommitmentSchemeVerifier::<Blake2sMerkleChannel>::new(config);

// Decommit.
// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();
// Trace columns.
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel);
// Interaction columns.
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel);
// Constant columns.
commitment_scheme.commit(proof.commitments[2], &sizes[2], channel);

verify(&[&component], channel, commitment_scheme, proof)
}

#[cfg(test)]
mod tests {
use num_traits::Zero;

use super::components::STATE_SIZE;
use super::{prove_state_machine, verify_state_machine};
use crate::constraint_framework::{assert_constraints, FrameworkEval};
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::pcs::PcsConfig;
use crate::core::poly::circle::CanonicCoset;

#[test]
fn test_state_machine_constraints() {
let log_n_rows = 8;
let config = PcsConfig::default();

// Initial and last state.
let initial_state = [M31::zero(); STATE_SIZE];
let last_state = [M31::from_u32_unchecked(1 << log_n_rows), M31::zero()];

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let (component, _, trace_polys) =
prove_state_machine(log_n_rows, initial_state, config, channel);

let interaction_elements = component.lookup_elements.clone();
let initial_state_comb: QM31 = interaction_elements.combine(&initial_state);
let last_state_comb: QM31 = interaction_elements.combine(&last_state);

// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`.
assert_eq!(
component.total_sum * initial_state_comb * last_state_comb,
last_state_comb - initial_state_comb
);

// Assert constraints.
assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| {
component.evaluate(eval);
});
}

#[test]
fn test_state_machine_prove() {
let log_n_rows = 8;
let config = PcsConfig::default();
let initial_state = [M31::zero(); STATE_SIZE];
let prover_channel = &mut Blake2sChannel::default();
let (component_op0, proof, _) =
prove_state_machine(log_n_rows, initial_state, config, prover_channel);

let verifier_channel = &mut Blake2sChannel::default();
verify_state_machine(config, verifier_channel, component_op0, proof).unwrap();
}
}

0 comments on commit 6911522

Please sign in to comment.