Skip to content

Commit

Permalink
State machine with two components
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Sep 24, 2024
1 parent f86b74a commit bb2e53b
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 64 deletions.
88 changes: 85 additions & 3 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
use num_traits::One;
use num_traits::{One, Zero};

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkEval};
use crate::core::fields::qm31::QM31;
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::prover::StarkProof;
use crate::core::vcs::ops::MerkleHasher;

const LOG_CONSTRAINT_DEGREE: u32 = 1;
pub const STATE_SIZE: usize = 2;
/// Random elements to combine the StateMachine state.
pub type StateMachineElements = LookupElements<STATE_SIZE>;
pub type State = [M31; STATE_SIZE];

pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>;
pub type StateMachineOp1Component = FrameworkComponent<StateTransitionEval<1>>;

/// State machine with state of size `STATE_SIZE`.
/// Transition `COORDINATE` of state increments the state by 1 at that offset.
Expand Down Expand Up @@ -47,3 +58,74 @@ impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE>
eval
}
}

pub struct StateMachineStatement0 {
pub n: u32,
pub m: u32,
}
impl StateMachineStatement0 {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let sizes = vec![
state_transition_info::<0>()
.mask_offsets
.as_cols_ref()
.map_cols(|_| self.n),
state_transition_info::<1>()
.mask_offsets
.as_cols_ref()
.map_cols(|_| self.m),
];
TreeVec::concat_cols(sizes.into_iter())
}
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_u64(self.n as u64);
channel.mix_u64(self.m as u64);
}
}

pub struct StateMachineStatement1 {
pub x_axis_claimed_sum: SecureField,
pub y_axis_claimed_sum: SecureField,
}
impl StateMachineStatement1 {
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_felts(&[self.x_axis_claimed_sum, self.y_axis_claimed_sum])
}
}

fn state_transition_info<const INDEX: usize>() -> InfoEvaluator {
let component = StateTransitionEval::<INDEX> {
log_n_rows: 1,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
};
component.evaluate(InfoEvaluator::default())
}

pub struct StateMachineComponents {
pub component0: StateMachineOp0Component,
pub component1: StateMachineOp1Component,
}

impl StateMachineComponents {
pub fn components(&self) -> Vec<&dyn Component> {
vec![
&self.component0 as &dyn Component,
&self.component1 as &dyn Component,
]
}

pub fn component_provers(&self) -> Vec<&dyn ComponentProver<SimdBackend>> {
vec![
&self.component0 as &dyn ComponentProver<SimdBackend>,
&self.component1 as &dyn ComponentProver<SimdBackend>,
]
}
}

pub struct StateMachineProof<H: MerkleHasher> {
pub public_input: [State; 2], // Initial and final state.
pub stmt0: StateMachineStatement0,
pub stmt1: StateMachineStatement1,
pub stark_proof: StarkProof<H>,
}
17 changes: 11 additions & 6 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use itertools::Itertools;
use num_traits::{One, Zero};

use super::components::STATE_SIZE;
use super::components::{State, STATE_SIZE};
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES};
Expand All @@ -11,10 +11,9 @@ use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::ColumnVec;

pub type State = [M31; STATE_SIZE];

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
// `inc_index` dimension.
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size].
Expand All @@ -27,11 +26,14 @@ pub fn gen_trace(
let mut trace = (0..STATE_SIZE)
.map(|_| vec![M31::zero(); 1 << log_size])
.collect_vec();

let mut curr_state = initial_state;

// Add the states in bit reversed circle domain order.
for i in 0..1 << log_size {
for j in 0..STATE_SIZE {
trace[j][i] = curr_state[j];
let bit_rev_index =
bit_reverse_index(coset_index_to_circle_domain_index(i, log_size), log_size);
trace[j][bit_rev_index] = curr_state[j];
}
// Increment the state to the next state row.
curr_state[inc_index] += M31::one();
Expand Down Expand Up @@ -89,6 +91,7 @@ mod tests {
use crate::core::fields::qm31::QM31;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::examples::state_machine::components::StateMachineElements;
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace};

Expand All @@ -103,8 +106,10 @@ mod tests {

assert_eq!(trace.len(), 2);
assert_eq!(trace[0].at(row), initial_state[0]);
let bit_rev_row =
bit_reverse_index(coset_index_to_circle_domain_index(row, log_size), log_size);
assert_eq!(
trace[1].at(row),
trace[1].at(bit_rev_row),
initial_state[1] + M31::from_u32_unchecked(row as u32)
);
}
Expand Down
Loading

0 comments on commit bb2e53b

Please sign in to comment.