-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cbe1dd7
commit 7242e27
Showing
4 changed files
with
382 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
pub mod blake; | ||
pub mod plonk; | ||
pub mod poseidon; | ||
pub mod state_machine; | ||
pub mod wide_fibonacci; | ||
pub mod xor; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use num_traits::One; | ||
|
||
use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; | ||
use crate::constraint_framework::{EvalAtRow, FrameworkEval}; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::lookups::utils::Fraction; | ||
|
||
const LOG_CONSTRAINT_DEGREE: u32 = 1; | ||
pub const STATE_SIZE: usize = 2; | ||
/// Random elements to combine the StateMachine state. | ||
pub type StateMachineElements = LookupElements<STATE_SIZE>; | ||
|
||
/// State machine with state of size `STATE_SIZE`. | ||
/// Transition `COORDINATE` of state increments the state by 1 at that offset. | ||
#[derive(Clone)] | ||
pub struct StateTransitionEval<const COORDINATE: usize> { | ||
pub log_n_rows: u32, | ||
pub lookup_elements: StateMachineElements, | ||
pub total_sum: QM31, | ||
} | ||
|
||
impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> { | ||
fn log_size(&self) -> u32 { | ||
self.log_n_rows | ||
} | ||
fn max_constraint_log_degree_bound(&self) -> u32 { | ||
self.log_n_rows + LOG_CONSTRAINT_DEGREE | ||
} | ||
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E { | ||
let [is_first] = eval.next_interaction_mask(2, [0]); | ||
let mut logup: LogupAtRow<E> = LogupAtRow::new(1, self.total_sum, None, is_first); | ||
|
||
let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask()); | ||
let input_denom: E::EF = self.lookup_elements.combine(&input_state); | ||
|
||
let mut output_state = input_state; | ||
output_state[COORDINATE] += E::F::one(); | ||
let output_denom: E::EF = self.lookup_elements.combine(&output_state); | ||
|
||
// Add to the total sum (1/input_denom - 1/output_denom). | ||
logup.write_frac( | ||
&mut eval, | ||
Fraction::new(output_denom - input_denom, output_denom * input_denom), | ||
); | ||
|
||
logup.finalize(&mut eval); | ||
eval | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
use itertools::Itertools; | ||
use num_traits::One; | ||
|
||
use super::components::STATE_SIZE; | ||
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; | ||
use crate::core::backend::simd::column::BaseColumn; | ||
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; | ||
use crate::core::backend::simd::qm31::PackedQM31; | ||
use crate::core::backend::simd::SimdBackend; | ||
use crate::core::backend::Column; | ||
use crate::core::fields::m31::M31; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; | ||
use crate::core::poly::BitReversedOrder; | ||
use crate::core::ColumnVec; | ||
|
||
pub type State = [M31; STATE_SIZE]; | ||
|
||
// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the | ||
// `inc_index` dimension. | ||
// E.g. [x, y] -> [x, y + 1] -> [x, y + 2] -> [x, y + 1 << log_size]. | ||
pub fn gen_trace( | ||
log_size: u32, | ||
initial_state: State, | ||
inc_index: usize, | ||
) -> ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>> { | ||
let n_lanes = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32)); | ||
let domain = CanonicCoset::new(log_size).circle_domain(); | ||
|
||
// Prepare the state for the first packed row. | ||
let mut packed_state = initial_state.map(PackedM31::broadcast); | ||
let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32))); | ||
packed_state[inc_index] += inc; | ||
|
||
let mut trace = (0..STATE_SIZE) | ||
.map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) | ||
.collect_vec(); | ||
for i in 0..(1 << (log_size - LOG_N_LANES)) { | ||
for j in 0..STATE_SIZE { | ||
trace[j].data[i] = packed_state[j]; | ||
} | ||
// Increment the state to the next packed row. | ||
packed_state[inc_index] += n_lanes; | ||
} | ||
trace | ||
.into_iter() | ||
.map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval)) | ||
.collect_vec() | ||
} | ||
|
||
pub fn gen_interaction_trace( | ||
log_size: u32, | ||
initial_state: [M31; STATE_SIZE], | ||
inc_index: usize, | ||
lookup_elements: &LookupElements<STATE_SIZE>, | ||
) -> ( | ||
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>, | ||
QM31, | ||
) { | ||
let ones = PackedM31::broadcast(M31::one()); | ||
let n_lanes_minus_one = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32)) - ones; | ||
|
||
// Prepare the state. | ||
let mut packed_state = initial_state.map(PackedM31::broadcast); | ||
let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32))); | ||
packed_state[inc_index] += inc; | ||
|
||
let mut logup_gen = LogupTraceGenerator::new(log_size); | ||
let mut col_gen = logup_gen.new_col(); | ||
|
||
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { | ||
let input_denom: PackedQM31 = lookup_elements.combine(&packed_state); | ||
packed_state[inc_index] += ones; | ||
let output_denom: PackedQM31 = lookup_elements.combine(&packed_state); | ||
packed_state[inc_index] += n_lanes_minus_one; | ||
col_gen.write_frac( | ||
vec_row, | ||
output_denom - input_denom, | ||
input_denom * output_denom, | ||
); | ||
} | ||
col_gen.finalize_col(); | ||
|
||
logup_gen.finalize_last() | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use itertools::Itertools; | ||
use num_traits::One; | ||
|
||
use crate::core::backend::Column; | ||
use crate::core::channel::Blake2sChannel; | ||
use crate::core::fields::m31::M31; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; | ||
use crate::core::fields::FieldExpOps; | ||
use crate::examples::state_machine::components::StateMachineElements; | ||
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace}; | ||
|
||
#[test] | ||
fn test_gen_trace() { | ||
let log_size = 8; | ||
let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)]; | ||
let inc_index = 1; | ||
let row = 123; | ||
|
||
let trace = gen_trace(log_size, initial_state, inc_index); | ||
|
||
assert_eq!(trace.len(), 2); | ||
assert_eq!(trace[0].at(row), initial_state[0]); | ||
assert_eq!( | ||
trace[1].at(row), | ||
initial_state[1] + M31::from_u32_unchecked(row as u32) | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_gen_interaction_trace() { | ||
let log_size = 8; | ||
let inc_index = 1; | ||
// Prepare state and next state. | ||
let state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(12)]; | ||
let mut next_state = state; | ||
next_state[inc_index] += M31::one(); | ||
|
||
let lookup_elements = StateMachineElements::dummy(); | ||
let comb_state: QM31 = lookup_elements.combine(&state); | ||
let comb_next_state: QM31 = lookup_elements.combine(&next_state); | ||
|
||
let (trace, _) = gen_interaction_trace(log_size, state, inc_index, &lookup_elements); | ||
let first_log_up_row = QM31::from_m31_array( | ||
trace | ||
.iter() | ||
.map(|col| col.at(0)) | ||
.collect_vec() | ||
.try_into() | ||
.unwrap(), | ||
); | ||
|
||
assert_eq!(trace.len(), SECURE_EXTENSION_DEGREE); // One quadradic extension column. | ||
assert_eq!( | ||
first_log_up_row, | ||
comb_state.inverse() - comb_next_state.inverse() | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_state_machine_total_sum() { | ||
let log_n_rows = 8; | ||
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default()); | ||
let inc_index = 0; | ||
|
||
let initial_state = [M31::from(123), M31::from(456)]; | ||
let initial_state_comb: QM31 = lookup_elements.combine(&initial_state); | ||
|
||
let mut last_state = initial_state; | ||
last_state[inc_index] += M31::from_u32_unchecked(1 << log_n_rows); | ||
let last_state_comb: QM31 = lookup_elements.combine(&last_state); | ||
|
||
let (_, total_sum) = | ||
gen_interaction_trace(log_n_rows, initial_state, inc_index, &lookup_elements); | ||
|
||
// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`. | ||
assert_eq!( | ||
total_sum * initial_state_comb * last_state_comb, | ||
last_state_comb - initial_state_comb | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
pub mod components; | ||
pub mod gen; | ||
|
||
use components::{StateMachineElements, StateTransitionEval}; | ||
use gen::{gen_interaction_trace, gen_trace, State}; | ||
use itertools::Itertools; | ||
|
||
use crate::constraint_framework::constant_columns::gen_is_first; | ||
use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator}; | ||
use crate::core::air::Component; | ||
use crate::core::backend::simd::m31::LOG_N_LANES; | ||
use crate::core::backend::simd::SimdBackend; | ||
use crate::core::channel::Blake2sChannel; | ||
use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; | ||
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps}; | ||
use crate::core::prover::{prove, verify, StarkProof, VerificationError}; | ||
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; | ||
|
||
pub type StateMachineOp0Component = FrameworkComponent<StateTransitionEval<0>>; | ||
|
||
#[allow(unused)] | ||
pub fn prove_state_machine( | ||
log_n_rows: u32, | ||
initial_state: State, | ||
config: PcsConfig, | ||
channel: &mut Blake2sChannel, | ||
) -> ( | ||
StateMachineOp0Component, | ||
StarkProof<Blake2sMerkleHasher>, | ||
TreeVec<Vec<CirclePoly<SimdBackend>>>, | ||
) { | ||
assert!(log_n_rows >= LOG_N_LANES); | ||
|
||
// Precompute twiddles. | ||
let twiddles = SimdBackend::precompute_twiddles( | ||
CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1) | ||
.circle_domain() | ||
.half_coset, | ||
); | ||
|
||
// Setup protocol. | ||
let commitment_scheme = | ||
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); | ||
|
||
// Trace. | ||
let trace_op0 = gen_trace(log_n_rows, initial_state, 0); | ||
let mut tree_builder = commitment_scheme.tree_builder(); | ||
tree_builder.extend_evals(trace_op0); | ||
tree_builder.commit(channel); | ||
|
||
// Draw lookup element. | ||
let lookup_elements = StateMachineElements::draw(channel); | ||
|
||
// Interaction trace. | ||
let (interaction_trace_op0, total_sum_op0) = | ||
gen_interaction_trace(log_n_rows, initial_state, 0, &lookup_elements); | ||
let mut tree_builder = commitment_scheme.tree_builder(); | ||
tree_builder.extend_evals(interaction_trace_op0); | ||
tree_builder.commit(channel); | ||
|
||
// Constant trace. | ||
let mut tree_builder = commitment_scheme.tree_builder(); | ||
tree_builder.extend_evals(vec![gen_is_first(log_n_rows)]); | ||
tree_builder.commit(channel); | ||
|
||
let trace_polys = commitment_scheme | ||
.trees | ||
.as_ref() | ||
.map(|t| t.polynomials.iter().cloned().collect_vec()); | ||
|
||
// Prove constraints. | ||
let component_op0 = StateMachineOp0Component::new( | ||
&mut TraceLocationAllocator::default(), | ||
StateTransitionEval { | ||
log_n_rows, | ||
lookup_elements, | ||
total_sum: total_sum_op0, | ||
}, | ||
); | ||
|
||
let proof = prove(&[&component_op0], channel, commitment_scheme).unwrap(); | ||
|
||
(component_op0, proof, trace_polys) | ||
} | ||
|
||
pub fn verify_state_machine( | ||
config: PcsConfig, | ||
channel: &mut Blake2sChannel, | ||
component: StateMachineOp0Component, | ||
proof: StarkProof<Blake2sMerkleHasher>, | ||
) -> Result<(), VerificationError> { | ||
let commitment_scheme = &mut CommitmentSchemeVerifier::<Blake2sMerkleChannel>::new(config); | ||
|
||
// Decommit. | ||
// Retrieve the expected column sizes in each commitment interaction, from the AIR. | ||
let sizes = component.trace_log_degree_bounds(); | ||
// Trace columns. | ||
commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); | ||
// Interaction columns. | ||
commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); | ||
// Constant columns. | ||
commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); | ||
|
||
verify(&[&component], channel, commitment_scheme, proof) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use num_traits::Zero; | ||
|
||
use super::components::STATE_SIZE; | ||
use super::{prove_state_machine, verify_state_machine}; | ||
use crate::constraint_framework::{assert_constraints, FrameworkEval}; | ||
use crate::core::channel::Blake2sChannel; | ||
use crate::core::fields::m31::M31; | ||
use crate::core::fields::qm31::QM31; | ||
use crate::core::pcs::PcsConfig; | ||
use crate::core::poly::circle::CanonicCoset; | ||
|
||
#[test] | ||
fn test_state_machine_constraints() { | ||
let log_n_rows = 8; | ||
let config = PcsConfig::default(); | ||
|
||
// Initial and last state. | ||
let initial_state = [M31::zero(); STATE_SIZE]; | ||
let last_state = [M31::from_u32_unchecked(1 << log_n_rows), M31::zero()]; | ||
|
||
// Setup protocol. | ||
let channel = &mut Blake2sChannel::default(); | ||
let (component, _, trace_polys) = | ||
prove_state_machine(log_n_rows, initial_state, config, channel); | ||
|
||
let interaction_elements = component.lookup_elements.clone(); | ||
let initial_state_comb: QM31 = interaction_elements.combine(&initial_state); | ||
let last_state_comb: QM31 = interaction_elements.combine(&last_state); | ||
|
||
// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`. | ||
assert_eq!( | ||
component.total_sum * initial_state_comb * last_state_comb, | ||
last_state_comb - initial_state_comb | ||
); | ||
|
||
// Assert constraints. | ||
assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |eval| { | ||
component.evaluate(eval); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn test_state_machine_prove() { | ||
let log_n_rows = 8; | ||
let config = PcsConfig::default(); | ||
let initial_state = [M31::zero(); STATE_SIZE]; | ||
let prover_channel = &mut Blake2sChannel::default(); | ||
let (component_op0, proof, _) = | ||
prove_state_machine(log_n_rows, initial_state, config, prover_channel); | ||
|
||
let verifier_channel = &mut Blake2sChannel::default(); | ||
verify_state_machine(config, verifier_channel, component_op0, proof).unwrap(); | ||
} | ||
} |