Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State machine prover #845

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
use num_traits::One;
use num_traits::{One, Zero};

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkEval};
use crate::core::fields::qm31::QM31;
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::prover::StarkProof;
use crate::core::vcs::ops::MerkleHasher;

const LOG_CONSTRAINT_DEGREE: u32 = 1;
pub const STATE_SIZE: usize = 2;
/// Random elements to combine the StateMachine state.
pub type StateMachineElements = LookupElements<STATE_SIZE>;
pub type State = [M31; STATE_SIZE];

pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>;
pub type StateMachineOp1Component = FrameworkComponent<StateTransitionEval<1>>;

/// State machine with state of size `STATE_SIZE`.
/// Transition `COORDINATE` of state increments the state by 1 at that offset.
Expand Down Expand Up @@ -47,3 +58,74 @@ impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE>
eval
}
}

pub struct StateMachineStatement0 {
pub n: u32,
pub m: u32,
}
impl StateMachineStatement0 {
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
let sizes = vec![
state_transition_info::<0>()
.mask_offsets
.as_cols_ref()
.map_cols(|_| self.n),
state_transition_info::<1>()
.mask_offsets
.as_cols_ref()
.map_cols(|_| self.m),
];
TreeVec::concat_cols(sizes.into_iter())
}
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_u64(self.n as u64);
channel.mix_u64(self.m as u64);
}
}

pub struct StateMachineStatement1 {
pub x_axis_claimed_sum: SecureField,
pub y_axis_claimed_sum: SecureField,
}
impl StateMachineStatement1 {
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_felts(&[self.x_axis_claimed_sum, self.y_axis_claimed_sum])
}
}

fn state_transition_info<const INDEX: usize>() -> InfoEvaluator {
let component = StateTransitionEval::<INDEX> {
log_n_rows: 1,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
};
component.evaluate(InfoEvaluator::default())
}

pub struct StateMachineComponents {
pub component0: StateMachineOp0Component,
pub component1: StateMachineOp1Component,
}

impl StateMachineComponents {
pub fn components(&self) -> Vec<&dyn Component> {
vec![
&self.component0 as &dyn Component,
&self.component1 as &dyn Component,
]
}

pub fn component_provers(&self) -> Vec<&dyn ComponentProver<SimdBackend>> {
vec![
&self.component0 as &dyn ComponentProver<SimdBackend>,
&self.component1 as &dyn ComponentProver<SimdBackend>,
]
}
}

pub struct StateMachineProof<H: MerkleHasher> {
pub public_input: [State; 2], // Initial and final state.
pub stmt0: StateMachineStatement0,
pub stmt1: StateMachineStatement1,
pub stark_proof: StarkProof<H>,
}
7 changes: 3 additions & 4 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use itertools::Itertools;
use num_traits::{One, Zero};

use super::components::STATE_SIZE;
use super::components::{State, STATE_SIZE};
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES};
Expand All @@ -13,8 +13,6 @@ use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

pub type State = [M31; STATE_SIZE];

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
// `inc_index` dimension.
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size].
Expand All @@ -27,8 +25,9 @@ pub fn gen_trace(
let mut trace = (0..STATE_SIZE)
.map(|_| vec![M31::zero(); 1 << log_size])
.collect_vec();

let mut curr_state = initial_state;

// Add the states in bit reversed circle domain order.
for i in 0..1 << log_size {
for j in 0..STATE_SIZE {
trace[j][i] = curr_state[j];
Expand Down
Loading
Loading