Skip to content

Commit

Permalink
WIP - gen_interaction_trace get the original trace
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Sep 24, 2024
1 parent 69e7fd2 commit ae9a904
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 56 deletions.
100 changes: 47 additions & 53 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use itertools::Itertools;
use num_traits::One;
use num_traits::{One, Zero};

use super::components::{State, STATE_SIZE};
use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
use crate::core::backend::simd::m31::{PackedM31, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedQM31;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::ColumnVec;

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
Expand All @@ -22,55 +22,58 @@ pub fn gen_trace(
initial_state: State,
inc_index: usize,
) -> ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>> {
let n_lanes = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32));
let domain = CanonicCoset::new(log_size).circle_domain();

// Prepare the state for the first packed row.
let mut packed_state = initial_state.map(PackedM31::broadcast);
let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32)));
packed_state[inc_index] += inc;

let mut trace = (0..STATE_SIZE)
.map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) })
.map(|_| vec![M31::zero(); 1 << log_size])
.collect_vec();
for i in 0..(1 << (log_size - LOG_N_LANES)) {

let mut curr_state = initial_state;
for i in 0..1 << log_size {
for j in 0..STATE_SIZE {
trace[j].data[i] = packed_state[j];
let bit_rev_index =
bit_reverse_index(coset_index_to_circle_domain_index(i, log_size), log_size);
trace[j][bit_rev_index] = curr_state[j];
}
// Increment the state to the next packed row.
packed_state[inc_index] += n_lanes;
// Increment the state to the next state row.
curr_state[inc_index] += M31::one();
}

trace
.into_iter()
.map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval))
.map(|col| {
CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(
domain,
BaseColumn::from_iter(col),
)
})
.collect_vec()
}

pub fn gen_interaction_trace(
log_size: u32,
initial_state: [M31; STATE_SIZE],
trace: [BaseColumn; STATE_SIZE],
// claimed_offset: usize,
inc_index: usize,
lookup_elements: &LookupElements<STATE_SIZE>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
// [QM31; 2],
QM31,
) {
let ones = PackedM31::broadcast(M31::one());
let n_lanes_minus_one = PackedM31::broadcast(M31::from_u32_unchecked(N_LANES as u32)) - ones;

// Prepare the state.
let mut packed_state = initial_state.map(PackedM31::broadcast);
let inc = PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked((i) as u32)));
packed_state[inc_index] += inc;

let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();

for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let mut packed_state: [PackedM31; STATE_SIZE] = trace
.iter()
.map(|col| col.data[vec_row])
.collect_vec()
.try_into()
.unwrap();
let input_denom: PackedQM31 = lookup_elements.combine(&packed_state);
packed_state[inc_index] += ones;
let output_denom: PackedQM31 = lookup_elements.combine(&packed_state);
packed_state[inc_index] += n_lanes_minus_one;
col_gen.write_frac(
vec_row,
output_denom - input_denom,
Expand All @@ -88,11 +91,11 @@ mod tests {
use num_traits::One;

use crate::core::backend::Column;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::examples::state_machine::components::StateMachineElements;
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace};

Expand All @@ -107,8 +110,10 @@ mod tests {

assert_eq!(trace.len(), 2);
assert_eq!(trace[0].at(row), initial_state[0]);
let bit_rev_row =
bit_reverse_index(coset_index_to_circle_domain_index(row, log_size), log_size);
assert_eq!(
trace[1].at(row),
trace[1].at(bit_rev_row),
initial_state[1] + M31::from_u32_unchecked(row as u32)
);
}
Expand All @@ -122,47 +127,36 @@ mod tests {
let mut next_state = state;
next_state[inc_index] += M31::one();

let trace = gen_trace(log_size, state, inc_index);

let lookup_elements = StateMachineElements::dummy();
let comb_state: QM31 = lookup_elements.combine(&state);
let comb_next_state: QM31 = lookup_elements.combine(&next_state);

let (trace, _) = gen_interaction_trace(log_size, state, inc_index, &lookup_elements);
let first_log_up_row = QM31::from_m31_array(
let (interaction_trace, _) = gen_interaction_trace(
log_size,
trace
.iter()
.map(|eval| eval.values.clone())
.collect_vec()
.try_into()
.unwrap(),
inc_index,
&lookup_elements,
);
let first_log_up_row = QM31::from_m31_array(
interaction_trace
.iter()
.map(|col| col.at(0))
.collect_vec()
.try_into()
.unwrap(),
);

assert_eq!(trace.len(), SECURE_EXTENSION_DEGREE); // One quadradic extension column.
assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One quadradic extension column.
assert_eq!(
first_log_up_row,
comb_state.inverse() - comb_next_state.inverse()
);
}

#[test]
fn test_state_machine_total_sum() {
let log_n_rows = 8;
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());
let inc_index = 0;

let initial_state = [M31::from(123), M31::from(456)];
let initial_state_comb: QM31 = lookup_elements.combine(&initial_state);

let mut last_state = initial_state;
last_state[inc_index] += M31::from_u32_unchecked(1 << log_n_rows);
let last_state_comb: QM31 = lookup_elements.combine(&last_state);

let (_, total_sum) =
gen_interaction_trace(log_n_rows, initial_state, inc_index, &lookup_elements);

// Assert total sum is `(1 / initial_state_comb) - (1 / last_state_comb)`.
assert_eq!(
total_sum * initial_state_comb * last_state_comb,
last_state_comb - initial_state_comb
);
}
}
11 changes: 8 additions & 3 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub fn prove_state_machine(
assert!(log_n_rows >= LOG_N_LANES);
let x_axis_log_rows = log_n_rows;
let y_axis_log_rows = log_n_rows - 1;

let mut intermediate_state = initial_state;
intermediate_state[0] += M31::from_u32_unchecked(1 << x_axis_log_rows);
let mut final_state = intermediate_state;
Expand All @@ -60,6 +61,9 @@ pub fn prove_state_machine(
};
stmt0.mix_into(channel);

let trace0_copy = [trace_op0[0].values.clone(), trace_op0[1].values.clone()];
let trace1_copy = [trace_op1[0].values.clone(), trace_op1[1].values.clone()];

let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(chain![trace_op0, trace_op1].collect_vec());
tree_builder.commit(channel);
Expand All @@ -69,9 +73,9 @@ pub fn prove_state_machine(

// Interaction trace.
let (interaction_trace_op0, total_sum_op0) =
gen_interaction_trace(x_axis_log_rows, initial_state, 0, &lookup_elements);
gen_interaction_trace(x_axis_log_rows, trace0_copy, 0, &lookup_elements);
let (interaction_trace_op1, total_sum_op1) =
gen_interaction_trace(y_axis_log_rows, intermediate_state, 1, &lookup_elements);
gen_interaction_trace(y_axis_log_rows, trace1_copy, 1, &lookup_elements);

let stmt1 = StateMachineStatement1 {
x_axis_claimed_sum: total_sum_op0,
Expand Down Expand Up @@ -185,12 +189,13 @@ mod tests {
let initial_state = [M31::zero(); STATE_SIZE];

let trace = gen_trace(log_n_rows, initial_state, 0);
let trace_copy = [trace[0].values.clone(), trace[1].values.clone()];

let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());

// Interaction trace.
let (interaction_trace, total_sum) =
gen_interaction_trace(log_n_rows, initial_state, 0, &lookup_elements);
gen_interaction_trace(log_n_rows, trace_copy, 0, &lookup_elements);

let component = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
Expand Down

0 comments on commit ae9a904

Please sign in to comment.