Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State machine AIR #841

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/prover/src/examples/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod blake;
pub mod plonk;
pub mod poseidon;
pub mod state_machine;
pub mod wide_fibonacci;
pub mod xor;
49 changes: 49 additions & 0 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use num_traits::One;

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkEval};
use crate::core::fields::qm31::QM31;
use crate::core::lookups::utils::Fraction;

const LOG_CONSTRAINT_DEGREE: u32 = 1;
pub const STATE_SIZE: usize = 2;
/// Random elements to combine the StateMachine state.
pub type StateMachineElements = LookupElements<STATE_SIZE>;

/// State machine with state of size `STATE_SIZE`.
/// Transition `COORDINATE` of state increments the state by 1 at that offset.
#[derive(Clone)]
pub struct StateTransitionEval<const COORDINATE: usize> {
pub log_n_rows: u32,
pub lookup_elements: StateMachineElements,
pub total_sum: QM31,
}

impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> {
fn log_size(&self) -> u32 {
self.log_n_rows
}
fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_n_rows + LOG_CONSTRAINT_DEGREE
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup: LogupAtRow<E> = LogupAtRow::new(1, self.total_sum, None, is_first);

let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask());
let input_denom: E::EF = self.lookup_elements.combine(&input_state);

let mut output_state = input_state;
output_state[COORDINATE] += E::F::one();
let output_denom: E::EF = self.lookup_elements.combine(&output_state);

logup.write_frac(
&mut eval,
Fraction::new(E::EF::one(), input_denom)
+ Fraction::new(-E::EF::one(), output_denom.clone()),
);

logup.finalize(&mut eval);
eval
}
}
135 changes: 135 additions & 0 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use itertools::Itertools;
use num_traits::{One, Zero};

use super::components::STATE_SIZE;
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedQM31;
use crate::core::backend::simd::SimdBackend;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

pub type State = [M31; STATE_SIZE];

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
// `inc_index` dimension.
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size].
pub fn gen_trace(
log_size: u32,
initial_state: State,
inc_index: usize,
) -> ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>> {
let domain = CanonicCoset::new(log_size).circle_domain();
let mut trace = (0..STATE_SIZE)
.map(|_| vec![M31::zero(); 1 << log_size])
.collect_vec();

let mut curr_state = initial_state;
for i in 0..1 << log_size {
for j in 0..STATE_SIZE {
trace[j][i] = curr_state[j];
}
// Increment the state to the next state row.
curr_state[inc_index] += M31::one();
}

trace
.into_iter()
.map(|col| {
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
domain,
BaseColumn::from_iter(col),
)
})
.collect_vec()
}

pub fn gen_interaction_trace(
log_size: u32,
trace: &ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
inc_index: usize,
lookup_elements: &LookupElements<STATE_SIZE>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
) {
let ones = PackedM31::broadcast(M31::one());
let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();

for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let mut packed_state: [PackedM31; STATE_SIZE] = trace
.iter()
.map(|col| col.data[vec_row])
.collect_vec()
.try_into()
.unwrap();
let input_denom: PackedQM31 = lookup_elements.combine(&packed_state);
packed_state[inc_index] += ones;
let output_denom: PackedQM31 = lookup_elements.combine(&packed_state);
col_gen.write_frac(
vec_row,
output_denom - input_denom,
input_denom * output_denom,
);
}
col_gen.finalize_col();

logup_gen.finalize_last()
}

#[cfg(test)]
mod tests {
use crate::core::backend::Column;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::examples::state_machine::components::StateMachineElements;
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace};

#[test]
fn test_gen_trace() {
let log_size = 8;
let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)];
let inc_index = 1;
let row = 123;

let trace = gen_trace(log_size, initial_state, inc_index);

assert_eq!(trace.len(), 2);
assert_eq!(trace[0].at(row), initial_state[0]);
assert_eq!(
trace[1].at(row),
initial_state[1] + M31::from_u32_unchecked(row as u32)
);
}

#[test]
fn test_gen_interaction_trace() {
let log_size = 8;
let inc_index = 1;
// Prepare the first and the last states.
let first_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(12)];
let mut last_state = first_state;
last_state[inc_index] += M31::from_u32_unchecked(1 << log_size);

let trace = gen_trace(log_size, first_state, inc_index);
let lookup_elements = StateMachineElements::dummy();
let first_state_comb: QM31 = lookup_elements.combine(&first_state);
let last_state_comb: QM31 = lookup_elements.combine(&last_state);

let (interaction_trace, total_sum) =
gen_interaction_trace(log_size, &trace, inc_index, &lookup_elements);

assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column.
assert_eq!(
total_sum,
first_state_comb.inverse() - last_state_comb.inverse()
);
}
}
162 changes: 162 additions & 0 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
pub mod components;
pub mod gen;

use components::{StateMachineElements, StateTransitionEval};
use gen::{gen_interaction_trace, gen_trace, State};
use itertools::Itertools;

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator};
use crate::core::air::Component;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Blake2sChannel;
use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps};
use crate::core::prover::{prove, verify, StarkProof, VerificationError};
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};

pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>;

#[allow(unused)]
pub fn prove_state_machine(
log_n_rows: u32,
initial_state: State,
config: PcsConfig,
channel: &mut Blake2sChannel,
) -> (
StateMachineOp0Component,
StarkProof<Blake2sMerkleHasher>,
TreeVec<Vec<CirclePoly<SimdBackend>>>,
) {
assert!(log_n_rows >= LOG_N_LANES);

// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1)
.circle_domain()
.half_coset,
);

// Setup protocol.
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let trace_op0 = gen_trace(log_n_rows, initial_state, 0);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace_op0.clone());
tree_builder.commit(channel);

// Draw lookup element.
let lookup_elements = StateMachineElements::draw(channel);

// Interaction trace.
let (interaction_trace_op0, total_sum_op0) =
gen_interaction_trace(log_n_rows, &trace_op0, 0, &lookup_elements);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(interaction_trace_op0);
tree_builder.commit(channel);

// Constant trace.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(vec![gen_is_first(log_n_rows)]);
tree_builder.commit(channel);

let trace_polys = commitment_scheme
.trees
.as_ref()
.map(|t| t.polynomials.iter().cloned().collect_vec());

// Prove constraints.
let component_op0 = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum: total_sum_op0,
},
);

let proof = prove(&[&component_op0], channel, commitment_scheme).unwrap();

(component_op0, proof, trace_polys)
}

pub fn verify_state_machine(
config: PcsConfig,
channel: &mut Blake2sChannel,
component: StateMachineOp0Component,
proof: StarkProof<Blake2sMerkleHasher>,
) -> Result<(), VerificationError> {
let commitment_scheme = &mut CommitmentSchemeVerifier::<Blake2sMerkleChannel>::new(config);

// Decommit.
// Retrieve the expected column sizes in each commitment interaction, from the AIR.
let sizes = component.trace_log_degree_bounds();
// Trace columns.
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel);
// Interaction columns.
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel);
// Constant columns.
commitment_scheme.commit(proof.commitments[2], &sizes[2], channel);

verify(&[&component], channel, commitment_scheme, proof)
}

#[cfg(test)]
mod tests {
use num_traits::Zero;

use super::components::STATE_SIZE;
use super::{prove_state_machine, verify_state_machine};
use crate::constraint_framework::{assert_constraints, FrameworkEval};
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::pcs::PcsConfig;
use crate::core::poly::circle::CanonicCoset;

#[test]
fn test_state_machine_constraints() {
let log_n_rows = 8;
let config = PcsConfig::default();

// Initial and last state.
let initial_state = [M31::zero(); STATE_SIZE];
let last_state = [M31::from_u32_unchecked(1 << log_n_rows), M31::zero()];

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let (component, _, trace_polys) =
prove_state_machine(log_n_rows, initial_state, config, channel);

let interaction_elements = component.lookup_elements.clone();
let initial_state_comb: QM31 = interaction_elements.combine(&initial_state);
let last_state_comb: QM31 = interaction_elements.combine(&last_state);

// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`.
assert_eq!(
component.total_sum * initial_state_comb * last_state_comb,
last_state_comb - initial_state_comb
);

// Assert constraints.
assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| {
component.evaluate(eval);
});
}

#[test]
fn test_state_machine_prove() {
let log_n_rows = 8;
let config = PcsConfig::default();
let initial_state = [M31::zero(); STATE_SIZE];
let prover_channel = &mut Blake2sChannel::default();
let (component_op0, proof, _) =
prove_state_machine(log_n_rows, initial_state, config, prover_channel);

let verifier_channel = &mut Blake2sChannel::default();
verify_state_machine(config, verifier_channel, component_op0, proof).unwrap();
}
}
Loading