diff --git a/Cargo.lock b/Cargo.lock index fd3e1315..3f06768c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1922,6 +1922,12 @@ dependencies = [ "spki", ] +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plotters" version = "0.3.7" @@ -2525,7 +2531,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "sphinx-core" version = "1.0.0" -source = "git+https://github.com/argumentcomputer/sphinx.git?branch=dev#68a48fb36cf521097e18220afb78de24ce67187a" +source = "git+https://github.com/argumentcomputer/sphinx.git?branch=forward_ports_48#2aacea1f40ca3b8e17f4b198fd06ef9a6f9990ae" dependencies = [ "anyhow", "arrayref", @@ -2565,6 +2571,7 @@ dependencies = [ "p3-symmetric", "p3-uni-stark", "p3-util", + "plain", "rand", "rayon-scan", "rrs-lib", @@ -2587,7 +2594,7 @@ dependencies = [ [[package]] name = "sphinx-derive" version = "1.0.0" -source = "git+https://github.com/argumentcomputer/sphinx.git?branch=dev#68a48fb36cf521097e18220afb78de24ce67187a" +source = "git+https://github.com/argumentcomputer/sphinx.git?branch=forward_ports_48#2aacea1f40ca3b8e17f4b198fd06ef9a6f9990ae" dependencies = [ "proc-macro2", "quote", @@ -2610,7 +2617,7 @@ dependencies = [ [[package]] name = "sphinx-primitives" version = "1.0.0" -source = "git+https://github.com/argumentcomputer/sphinx.git?branch=dev#68a48fb36cf521097e18220afb78de24ce67187a" +source = "git+https://github.com/argumentcomputer/sphinx.git?branch=forward_ports_48#2aacea1f40ca3b8e17f4b198fd06ef9a6f9990ae" dependencies = [ "itertools 0.12.1", "lazy_static", diff --git a/Cargo.toml b/Cargo.toml index 1a3f560f..a3cd22b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,8 +32,8 @@ p3-poseidon2 = { git = "https://github.com/argumentcomputer/Plonky3.git", branch p3-symmetric = { git = "https://github.com/argumentcomputer/Plonky3.git", branch = "sp1" } p3-uni-stark = { git = "https://github.com/argumentcomputer/Plonky3.git", branch = "sp1" } p3-util = { git = "https://github.com/argumentcomputer/Plonky3.git", branch = "sp1" } -sphinx-core = { git = "https://github.com/argumentcomputer/sphinx.git", branch = "dev"} -sphinx-derive = { git = "https://github.com/argumentcomputer/sphinx.git", branch = "dev" } +sphinx-core = { git = "https://github.com/argumentcomputer/sphinx.git", branch = "forward_ports_48"} +sphinx-derive = { git = "https://github.com/argumentcomputer/sphinx.git", branch = "forward_ports_48" } anyhow = "1.0.72" ascent = { git = "https://github.com/argumentcomputer/ascent.git" } arc-swap = "1.7.1" diff --git a/benches/fib.rs b/benches/fib.rs index ca7d0970..dda583fa 100644 --- a/benches/fib.rs +++ b/benches/fib.rs @@ -4,7 +4,7 @@ use p3_field::AbstractField; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sphinx_core::{ air::MachineAir, - stark::{LocalProver, StarkGenericConfig, StarkMachine}, + stark::{DefaultProver, MachineProver, StarkGenericConfig, StarkMachine}, utils::{BabyBearPoseidon2, SphinxCoreOpts}, }; use std::time::Duration; @@ -48,7 +48,7 @@ fn setup>( toplevel: &Toplevel, ) -> ( List, - FuncChip<'_, BabyBear, C, NoChip>, + FuncChip, QueryRecord, ) { let code = build_lurk_expr(arg); @@ -125,7 +125,10 @@ fn e2e(c: &mut Criterion) { let mut challenger_p = machine.config().challenger(); let opts = SphinxCoreOpts::default(); let shard = Shard::new(&record); - machine.prove::>(&pk, shard, &mut challenger_p, opts); + let prover = DefaultProver::new(machine); + prover + .prove(&pk, vec![shard], &mut challenger_p, opts) + .unwrap(); }, BatchSize::SmallInput, ) diff --git a/benches/lcs.rs b/benches/lcs.rs index 92e05c86..b8082352 100644 --- a/benches/lcs.rs +++ b/benches/lcs.rs @@ -4,7 +4,7 @@ use p3_field::AbstractField; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sphinx_core::{ air::MachineAir, - stark::{LocalProver, StarkGenericConfig, StarkMachine}, + stark::{DefaultProver, MachineProver, StarkGenericConfig, StarkMachine}, utils::{BabyBearPoseidon2, SphinxCoreOpts}, }; use std::time::Duration; @@ -52,7 +52,7 @@ fn setup<'a, C: Chipset>( toplevel: &'a Toplevel, ) -> ( List, - FuncChip<'a, BabyBear, C, NoChip>, + FuncChip, QueryRecord, ) { let code = build_lurk_expr(a, b); @@ -129,7 +129,10 @@ fn e2e(c: &mut Criterion) { let mut challenger_p = machine.config().challenger(); let opts = SphinxCoreOpts::default(); let shard = Shard::new(&record); - machine.prove::>(&pk, shard, &mut challenger_p, opts); + let prover = DefaultProver::new(machine); + prover + .prove(&pk, vec![shard], &mut challenger_p, opts) + .unwrap(); }, BatchSize::SmallInput, ) diff --git a/benches/sum.rs b/benches/sum.rs index 8eae7a30..795c27ab 100644 --- a/benches/sum.rs +++ b/benches/sum.rs @@ -4,7 +4,7 @@ use p3_field::AbstractField; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sphinx_core::{ air::MachineAir, - stark::{LocalProver, StarkGenericConfig, StarkMachine}, + stark::{DefaultProver, MachineProver, StarkGenericConfig, StarkMachine}, utils::{BabyBearPoseidon2, SphinxCoreOpts}, }; use std::time::Duration; @@ -52,7 +52,7 @@ fn setup>( toplevel: &Toplevel, ) -> ( List, - FuncChip<'_, BabyBear, C, NoChip>, + FuncChip, QueryRecord, ) { let code = build_lurk_expr(n); @@ -130,7 +130,10 @@ fn e2e(c: &mut Criterion) { let mut challenger_p = machine.config().challenger(); let opts = SphinxCoreOpts::default(); let shard = Shard::new(&record); - machine.prove::>(&pk, shard, &mut challenger_p, opts); + let prover = DefaultProver::new(machine); + prover + .prove(&pk, vec![shard], &mut challenger_p, opts) + .unwrap(); }, BatchSize::SmallInput, ) diff --git a/src/air/debug.rs b/src/air/debug.rs index b431e2e0..35fee0f8 100644 --- a/src/air/debug.rs +++ b/src/air/debug.rs @@ -9,7 +9,6 @@ use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::stack::VerticalPair; use p3_matrix::Matrix; use sphinx_core::air::MachineAir; -use sphinx_core::stark::MachineRecord; use std::collections::BTreeMap; type LocalRowView<'a, F> = VerticalPair, RowMajorMatrixView<'a, F>>; @@ -122,14 +121,13 @@ pub fn debug_chip_constraints_and_queries_with_sharding< C2: Chipset, >( record: &QueryRecord, - chips: &[LairChip<'_, F, C1, C2>], + chips: &[LairChip], config: Option, ) { - let full_shard = Shard::new(record); let shards = if let Some(config) = config { - full_shard.shard(&config) + config.shard(record) } else { - vec![full_shard] + vec![Shard::new(&record.clone())] }; let lookup_queries: Vec<_> = shards diff --git a/src/lair/air.rs b/src/lair/air.rs index f9d890ce..ef43eb43 100644 --- a/src/lair/air.rs +++ b/src/lair/air.rs @@ -130,13 +130,13 @@ fn eval_depth( out.extend(dep_depth.iter().cloned()); } -impl<'a, AB, C1: Chipset, C2: Chipset> Air for FuncChip<'a, AB::F, C1, C2> +impl, C2: Chipset> Air for FuncChip where AB: AirBuilder + LookupBuilder, ::Var: Debug, { fn eval(&self, builder: &mut AB) { - self.func.eval(builder, self.toplevel, self.layout_sizes) + self.func.eval(builder, &self.toplevel, self.layout_sizes) } } diff --git a/src/lair/chipset.rs b/src/lair/chipset.rs index 435e1125..5a029d1e 100644 --- a/src/lair/chipset.rs +++ b/src/lair/chipset.rs @@ -6,7 +6,7 @@ use crate::air::builder::{LookupBuilder, Record, RequireRecord}; use super::execute::QueryRecord; -pub trait Chipset: Sync { +pub trait Chipset: Send + Sync + 'static + Clone { fn input_size(&self) -> usize; fn output_size(&self) -> usize; @@ -46,7 +46,7 @@ pub trait Chipset: Sync { ) -> Vec; } -impl, C2: Chipset> Chipset for &Either { +impl, C2: Chipset> Chipset for Either { fn input_size(&self) -> usize { match self { Either::Left(c) => c.input_size(), @@ -121,7 +121,7 @@ impl, C2: Chipset> Chipset for &Either { } } -#[derive(Default)] +#[derive(Clone, Default)] pub struct NoChip; impl Chipset for NoChip { diff --git a/src/lair/execute.rs b/src/lair/execute.rs index 63f530b6..09c2d1c2 100644 --- a/src/lair/execute.rs +++ b/src/lair/execute.rs @@ -3,7 +3,7 @@ use hashbrown::HashMap; use itertools::Itertools; use p3_field::{AbstractField, PrimeField32}; use rustc_hash::FxHashMap; -use sphinx_core::stark::{Indexed, MachineRecord}; +use sphinx_core::{stark::MachineRecord, utils::SphinxCoreOpts}; use std::ops::Range; use crate::{ @@ -75,14 +75,14 @@ pub struct QueryRecord { } #[derive(Default, Clone, Debug, Eq, PartialEq)] -pub struct Shard<'a, F: PrimeField32> { +pub struct Shard { pub(crate) index: u32, // TODO: remove this `Option` once Sphinx no longer requires `Default` - pub(crate) queries: Option<&'a QueryRecord>, + pub(crate) queries: Option>, pub(crate) shard_config: ShardingConfig, } -impl<'a, F: PrimeField32> Shard<'a, F> { +impl Shard { /// Creates a new initial shard from the given `QueryRecord`. /// /// # Note @@ -90,17 +90,19 @@ impl<'a, F: PrimeField32> Shard<'a, F> { /// Make sure to call `.shard()` on a `Shard` created by `new` when generating /// the traces, otherwise you will only get the first shard's trace. #[inline] - pub fn new(queries: &'a QueryRecord) -> Self { + pub fn new(queries: &QueryRecord) -> Self { Shard { index: 0, - queries: queries.into(), + queries: Some(queries.clone()), shard_config: ShardingConfig::default(), } } #[inline] pub fn queries(&self) -> &QueryRecord { - self.queries.expect("Missing query record reference") + self.queries + .as_ref() + .expect("Missing query record reference") } pub fn get_func_range(&self, func_index: usize) -> Range { @@ -123,18 +125,9 @@ impl<'a, F: PrimeField32> Shard<'a, F> { } } -impl<'a, F: PrimeField32> Indexed for Shard<'a, F> { - fn index(&self) -> u32 { - self.index - } -} - -impl<'a, F: PrimeField32> MachineRecord for Shard<'a, F> { - type Config = ShardingConfig; - - fn set_index(&mut self, index: u32) { - self.index = index - } +impl MachineRecord for Shard { + // type Config = ShardingConfig; // FIXME + type Config = SphinxCoreOpts; fn stats(&self) -> HashMap { // TODO: use `IndexMap` instead so the original insertion order is kept @@ -183,9 +176,34 @@ impl<'a, F: PrimeField32> MachineRecord for Shard<'a, F> { // just a no-op because `generate_dependencies` is a no-op } - fn shard(self, config: &Self::Config) -> Vec { - let queries = self.queries(); - let shard_size = config.max_shard_size as usize; + fn public_values(&self) -> Vec { + self.expect_public_values() + .iter() + .map(|f| F2::from_canonical_u32(f.as_canonical_u32())) + .collect() + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct ShardingConfig { + pub(crate) max_shard_size: u32, +} + +impl Default for ShardingConfig { + fn default() -> Self { + const DEFAULT_SHARD_SIZE: u32 = 1 << 22; + Self { + max_shard_size: std::env::var("SHARD_SIZE").map_or_else( + |_| DEFAULT_SHARD_SIZE, + |s| s.parse::().unwrap_or(DEFAULT_SHARD_SIZE), + ), + } + } +} + +impl ShardingConfig { + pub fn shard(&self, queries: &QueryRecord) -> Vec> { + let shard_size = self.max_shard_size as usize; let max_num_func_rows: usize = queries .func_queries .iter() @@ -208,36 +226,12 @@ impl<'a, F: PrimeField32> MachineRecord for Shard<'a, F> { for shard_index in 0..num_shards { shards.push(Shard { index: shard_index as u32, - queries: self.queries, - shard_config: *config, + queries: Some(queries.clone()), + shard_config: *self, }); } shards } - - fn public_values(&self) -> Vec { - self.expect_public_values() - .iter() - .map(|f| F2::from_canonical_u32(f.as_canonical_u32())) - .collect() - } -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct ShardingConfig { - pub(crate) max_shard_size: u32, -} - -impl Default for ShardingConfig { - fn default() -> Self { - const DEFAULT_SHARD_SIZE: u32 = 1 << 22; - Self { - max_shard_size: std::env::var("SHARD_SIZE").map_or_else( - |_| DEFAULT_SHARD_SIZE, - |s| s.parse::().unwrap_or(DEFAULT_SHARD_SIZE), - ), - } - } } const NUM_MEM_TABLES: usize = 5; diff --git a/src/lair/func_chip.rs b/src/lair/func_chip.rs index 9b000215..86781db5 100644 --- a/src/lair/func_chip.rs +++ b/src/lair/func_chip.rs @@ -25,37 +25,37 @@ impl LayoutSizes { } } -pub struct FuncChip<'a, F, C1: Chipset, C2: Chipset> { - pub(crate) func: &'a Func, - pub(crate) toplevel: &'a Toplevel, +pub struct FuncChip, C2: Chipset> { + pub(crate) func: Func, + pub(crate) toplevel: Toplevel, pub(crate) layout_sizes: LayoutSizes, } -impl<'a, F, C1: Chipset, C2: Chipset> FuncChip<'a, F, C1, C2> { +impl, C2: Chipset> FuncChip { #[inline] - pub fn from_name(name: &'static str, toplevel: &'a Toplevel) -> Self { + pub fn from_name(name: &'static str, toplevel: &Toplevel) -> Self { let func = toplevel.func_by_name(name); Self::from_func(func, toplevel) } #[inline] - pub fn from_index(idx: usize, toplevel: &'a Toplevel) -> Self { + pub fn from_index(idx: usize, toplevel: &Toplevel) -> Self { let func = toplevel.func_by_index(idx); Self::from_func(func, toplevel) } #[inline] - pub fn from_func(func: &'a Func, toplevel: &'a Toplevel) -> Self { + pub fn from_func(func: &Func, toplevel: &Toplevel) -> Self { let layout_sizes = func.compute_layout_sizes(toplevel); Self { - func, - toplevel, + func: func.clone(), // FIXME + toplevel: toplevel.clone(), // FIXME layout_sizes, } } #[inline] - pub fn from_toplevel(toplevel: &'a Toplevel) -> Vec { + pub fn from_toplevel(toplevel: &Toplevel) -> Vec { toplevel .func_map .values() @@ -70,18 +70,18 @@ impl<'a, F, C1: Chipset, C2: Chipset> FuncChip<'a, F, C1, C2> { #[inline] pub fn func(&self) -> &Func { - self.func + &self.func } #[inline] pub fn toplevel(&self) -> &Toplevel { - self.toplevel + &self.toplevel } } -impl<'a, F: Sync, C1: Chipset, C2: Chipset> BaseAir for FuncChip<'a, F, C1, C2> { +impl, C2: Chipset> BaseAir for FuncChip { fn width(&self) -> usize { - self.width() + self.layout_sizes.total() } } diff --git a/src/lair/lair_chip.rs b/src/lair/lair_chip.rs index 9f6b79b6..7c7466b4 100644 --- a/src/lair/lair_chip.rs +++ b/src/lair/lair_chip.rs @@ -19,8 +19,8 @@ use super::{ relations::OuterCallRelation, }; -pub enum LairChip<'a, F, C1: Chipset, C2: Chipset> { - Func(FuncChip<'a, F, C1, C2>), +pub enum LairChip, C2: Chipset> { + Func(FuncChip), Mem(MemChip), Bytes(BytesChip), Entrypoint { @@ -29,7 +29,7 @@ pub enum LairChip<'a, F, C1: Chipset, C2: Chipset> { }, } -impl<'a, F, C1: Chipset, C2: Chipset> LairChip<'a, F, C1, C2> { +impl, C2: Chipset> LairChip { #[inline] pub fn entrypoint(func: &Func) -> Self { let partial = if func.partial { DEPTH_W } else { 0 }; @@ -41,21 +41,17 @@ impl<'a, F, C1: Chipset, C2: Chipset> LairChip<'a, F, C1, C2> { } } -impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> WithEvents<'a> - for LairChip<'_, F, C1, C2> -{ - type Events = &'a Shard<'a, F>; +impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> WithEvents<'a> for LairChip { + type Events = &'a Shard; } -impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> EventLens> - for Shard<'a, F> -{ - fn events(&self) -> as WithEvents<'_>>::Events { +impl, C2: Chipset> EventLens> for Shard { + fn events(&self) -> as WithEvents<'_>>::Events { self } } -impl<'a, F: Field + Sync, C1: Chipset, C2: Chipset> BaseAir for LairChip<'a, F, C1, C2> { +impl, C2: Chipset> BaseAir for LairChip { fn width(&self) -> usize { match self { Self::Func(func_chip) => func_chip.width(), @@ -76,10 +72,8 @@ impl MachineProgram for LairMachineProgram { } } -impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> MachineAir - for LairChip<'a, F, C1, C2> -{ - type Record = Shard<'a, F>; +impl, C2: Chipset> MachineAir for LairChip { + type Record = Shard; type Program = LairMachineProgram; fn name(&self) -> String { @@ -103,7 +97,7 @@ impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> MachineAir Self::Mem(mem_chip) => mem_chip.generate_trace(shard.events()), Self::Bytes(bytes_chip) => { // TODO: Shard the byte events differently? - if shard.index() == 0 { + if shard.events().index == 0 { bytes_chip.generate_trace(&shard.events().queries().bytes) } else { bytes_chip.generate_trace(&Default::default()) @@ -153,7 +147,7 @@ impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> MachineAir } } -impl<'a, AB, C1: Chipset, C2: Chipset> Air for LairChip<'a, AB::F, C1, C2> +impl, C2: Chipset> Air for LairChip where AB: AirBuilderWithPublicValues + LookupBuilder + PairBuilder, ::Var: std::fmt::Debug, @@ -193,9 +187,9 @@ where } } -pub fn build_lair_chip_vector<'a, F: PrimeField32, C1: Chipset, C2: Chipset>( - entry_func_chip: &FuncChip<'a, F, C1, C2>, -) -> Vec> { +pub fn build_lair_chip_vector, C2: Chipset>( + entry_func_chip: &FuncChip, +) -> Vec> { let toplevel = &entry_func_chip.toplevel; let func = &entry_func_chip.func; let mut chip_vector = Vec::with_capacity(2 + toplevel.num_funcs() + MEM_TABLE_SIZES.len()); @@ -212,21 +206,20 @@ pub fn build_lair_chip_vector<'a, F: PrimeField32, C1: Chipset, C2: Chipset, C2: Chipset, - I: IntoIterator>, + I: IntoIterator>, >( lair_chips: I, -) -> Vec>> { +) -> Vec>> { lair_chips.into_iter().map(Chip::new).collect() } #[inline] -pub fn build_chip_vector<'a, F: PrimeField32, C1: Chipset, C2: Chipset>( - entry_func_chip: &FuncChip<'a, F, C1, C2>, -) -> Vec>> { +pub fn build_chip_vector, C2: Chipset>( + entry_func_chip: &FuncChip, +) -> Vec>> { build_chip_vector_from_lair_chips(build_lair_chip_vector(entry_func_chip)) } @@ -237,9 +230,10 @@ mod tests { use super::*; use p3_baby_bear::BabyBear; + use sphinx_core::stark::MachineProver; use sphinx_core::utils::BabyBearPoseidon2; use sphinx_core::{ - stark::{LocalProver, StarkGenericConfig, StarkMachine}, + stark::{DefaultProver, StarkGenericConfig, StarkMachine}, utils::SphinxCoreOpts, }; @@ -267,10 +261,14 @@ mod tests { let mut challenger_v = machine.config().challenger(); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard.clone()], &mut machine.config().challenger()); let opts = SphinxCoreOpts::default(); - let proof = machine.prove::>(&pk, shard, &mut challenger_p, opts); - machine + let prover = DefaultProver::new(machine); + let proof = prover + .prove(&pk, vec![shard], &mut challenger_p, opts) + .unwrap(); + prover + .machine() .verify(&vk, &proof, &mut challenger_v) .expect("proof verifies"); } diff --git a/src/lair/memory.rs b/src/lair/memory.rs index 2100205f..c4eac79e 100644 --- a/src/lair/memory.rs +++ b/src/lair/memory.rs @@ -27,7 +27,7 @@ impl MemChip { } } - pub fn generate_trace(&self, shard: &Shard<'_, F>) -> RowMajorMatrix { + pub fn generate_trace(&self, shard: &Shard) -> RowMajorMatrix { let record = &shard.queries().mem_queries; let mem_idx = mem_index_from_len(self.len); let mem = &record[mem_idx]; diff --git a/src/lair/trace.rs b/src/lair/trace.rs index 7038a28a..3029378a 100644 --- a/src/lair/trace.rs +++ b/src/lair/trace.rs @@ -69,9 +69,9 @@ impl<'a, T> ColumnMutSlice<'a, T> { } } -impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> FuncChip<'a, F, C1, C2> { +impl, C2: Chipset> FuncChip { /// Per-row parallel trace generation - pub fn generate_trace(&self, shard: &Shard<'_, F>) -> RowMajorMatrix { + pub fn generate_trace(&self, shard: &Shard) -> RowMajorMatrix { let func_queries = &shard.queries().func_queries()[self.func.index]; let range = shard.get_func_range(self.func.index); let width = self.width(); @@ -125,7 +125,7 @@ impl<'a, F: PrimeField32, C1: Chipset, C2: Chipset> FuncChip<'a, F, C1, C2 slice, queries, requires, - self.toplevel, + &self.toplevel, result.depth, depth_requires, ); @@ -436,7 +436,7 @@ mod tests { use p3_baby_bear::BabyBear as F; use p3_field::AbstractField; use sphinx_core::{ - stark::{LocalProver, MachineRecord, StarkGenericConfig, StarkMachine}, + stark::{DefaultProver, MachineProver, StarkGenericConfig, StarkMachine}, utils::{BabyBearPoseidon2, SphinxCoreOpts}, }; @@ -693,8 +693,7 @@ mod tests { let lair_chips = build_lair_chip_vector(&ack_chip); - let shard = Shard::new(&queries); - let shards = shard.clone().shard(&ShardingConfig::default()); + let shards = ShardingConfig::default().shard(&queries); assert!( shards.len() > 1, "lair_shard_test must have more than one shard" @@ -716,12 +715,15 @@ mod tests { let (pk, vk) = machine.setup(&LairMachineProgram); let mut challenger_p = machine.config().challenger(); let mut challenger_v = machine.config().challenger(); - let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, shards.clone(), &mut machine.config().challenger()); let opts = SphinxCoreOpts::default(); - let proof = machine.prove::>(&pk, shard, &mut challenger_p, opts); - machine + let prover = DefaultProver::new(machine); + let proof = prover + .prove(&pk, shards, &mut challenger_p, opts) + .expect("proof generates"); + prover + .machine() .verify(&vk, &proof, &mut challenger_v) .expect("proof verifies"); } diff --git a/src/lurk/big_num.rs b/src/lurk/big_num.rs index 14a6355c..9253b533 100644 --- a/src/lurk/big_num.rs +++ b/src/lurk/big_num.rs @@ -111,7 +111,10 @@ pub fn field_elts_to_biguint(elts: &[F]) -> BigUint { mod test { use p3_baby_bear::BabyBear as F; use p3_field::AbstractField; - use sphinx_core::{stark::StarkMachine, utils::BabyBearPoseidon2}; + use sphinx_core::{ + stark::{StarkGenericConfig, StarkMachine}, + utils::BabyBearPoseidon2, + }; use crate::{ air::debug::debug_chip_constraints_and_queries_with_sharding, @@ -177,6 +180,6 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); } } diff --git a/src/lurk/cli/repl.rs b/src/lurk/cli/repl.rs index 76e48857..d26ad6af 100644 --- a/src/lurk/cli/repl.rs +++ b/src/lurk/cli/repl.rs @@ -12,7 +12,7 @@ use rustyline::{ Completer, Editor, Helper, Highlighter, Hinter, }; use sphinx_core::{ - stark::{LocalProver, StarkGenericConfig}, + stark::{DefaultProver, MachineProver, StarkGenericConfig}, utils::SphinxCoreOpts, }; use std::{fmt::Debug, io::Write, marker::PhantomData}; @@ -20,7 +20,7 @@ use std::{fmt::Debug, io::Write, marker::PhantomData}; use crate::{ lair::{ chipset::{Chipset, NoChip}, - execute::{DebugEntry, DebugEntryKind, QueryRecord, QueryResult, Shard}, + execute::{DebugEntry, DebugEntryKind, QueryRecord, QueryResult, ShardingConfig}, lair_chip::LairMachineProgram, toplevel::Toplevel, }, @@ -169,6 +169,7 @@ impl, C2: Chipset> Repl { let machine = new_machine(&self.toplevel); let (pk, vk) = machine.setup(&LairMachineProgram); let challenger_p = &mut machine.config().challenger(); + let prover = DefaultProver::new(machine); let must_prove = if !proof_path.exists() { true } else { @@ -177,18 +178,22 @@ impl, C2: Chipset> Repl { let machine_proof = cached_proof.into_machine_proof(); let challenger_v = &mut challenger_p.clone(); // force an overwrite if verification goes wrong - machine.verify(&vk, &machine_proof, challenger_v).is_err() + prover + .machine() + .verify(&vk, &machine_proof, challenger_v) + .is_err() } else { // force an overwrite if deserialization goes wrong true } }; if must_prove { - let challenger_v = &mut challenger_p.clone(); - let shard = Shard::new(&self.queries); let opts = SphinxCoreOpts::default(); - let machine_proof = machine.prove::>(&pk, shard, challenger_p, opts); - machine + let challenger_v = &mut challenger_p.clone(); + let sharded = ShardingConfig::default().shard(&self.queries); + let machine_proof = prover.prove(&pk, sharded, challenger_p, opts)?; + prover + .machine() .verify(&vk, &machine_proof, challenger_v) .expect("Proof verification failed"); let crypto_proof: CryptoProof = machine_proof.into(); diff --git a/src/lurk/poseidon.rs b/src/lurk/poseidon.rs index 88c3b2c5..3819b117 100644 --- a/src/lurk/poseidon.rs +++ b/src/lurk/poseidon.rs @@ -38,7 +38,8 @@ impl, const WIDTH: usize> PoseidonChipset { } } -impl, const WIDTH: usize> Chipset for PoseidonChipset +impl + 'static, const WIDTH: usize> Chipset + for PoseidonChipset where Sub1: ArraySize, { diff --git a/src/lurk/stark_machine.rs b/src/lurk/stark_machine.rs index d77a08a2..b02cf5cf 100644 --- a/src/lurk/stark_machine.rs +++ b/src/lurk/stark_machine.rs @@ -19,7 +19,7 @@ pub(crate) const NUM_PUBLIC_VALUES: usize = INPUT_SIZE + ZPTR_SIZE; /// Returns a `StarkMachine` for the Lurk toplevel, with `lurk_main` as entrypoint pub(crate) fn new_machine, C2: Chipset>( lurk_toplevel: &Toplevel, -) -> StarkMachine> { +) -> StarkMachine> { let lurk_main_idx = lurk_toplevel.func_by_name("lurk_main").index; let lurk_main_chip = FuncChip::from_index(lurk_main_idx, lurk_toplevel); StarkMachine::new( diff --git a/src/lurk/tests/lang.rs b/src/lurk/tests/lang.rs index 1406e1d5..bd0ba47d 100644 --- a/src/lurk/tests/lang.rs +++ b/src/lurk/tests/lang.rs @@ -23,6 +23,7 @@ use crate::{ use super::run_tests; +#[derive(Clone)] struct SquareGadget; impl Chipset for SquareGadget { diff --git a/src/lurk/tests/mod.rs b/src/lurk/tests/mod.rs index 947d616c..323cefc5 100644 --- a/src/lurk/tests/mod.rs +++ b/src/lurk/tests/mod.rs @@ -3,7 +3,10 @@ mod lang; use p3_baby_bear::BabyBear as F; use p3_field::AbstractField; -use sphinx_core::{stark::StarkMachine, utils::BabyBearPoseidon2}; +use sphinx_core::{ + stark::{StarkGenericConfig, StarkMachine}, + utils::BabyBearPoseidon2, +}; use crate::{ air::debug::debug_chip_constraints_and_queries_with_sharding, @@ -44,7 +47,7 @@ fn run_tests>( let lurk_main = FuncChip::from_name("lurk_main", toplevel); let result = toplevel - .execute(lurk_main.func, &input, &mut record, None) + .execute(&lurk_main.func, &input, &mut record, None) .unwrap(); assert_eq!(result.as_ref(), &expected_cloj(zstore).flatten()); @@ -67,5 +70,6 @@ fn run_tests>( record.expect_public_values().len(), ); let (pk, _) = machine.setup(&LairMachineProgram); - machine.debug_constraints(&pk, full_shard); + let mut challenger = machine.config().challenger(); + machine.debug_constraints(&pk, vec![full_shard], &mut challenger); } diff --git a/src/lurk/u64.rs b/src/lurk/u64.rs index f1e53964..ba270edd 100644 --- a/src/lurk/u64.rs +++ b/src/lurk/u64.rs @@ -225,7 +225,10 @@ impl Chipset for U64 { mod test { use p3_baby_bear::BabyBear as F; use p3_field::AbstractField; - use sphinx_core::{stark::StarkMachine, utils::BabyBearPoseidon2}; + use sphinx_core::{ + stark::{StarkGenericConfig, StarkMachine}, + utils::BabyBearPoseidon2, + }; use crate::{ air::debug::debug_chip_constraints_and_queries_with_sharding, @@ -294,7 +297,7 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); } #[test] @@ -352,7 +355,7 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); } #[test] @@ -414,7 +417,7 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); } #[test] @@ -490,7 +493,7 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); } #[test] @@ -545,7 +548,7 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); } #[test] @@ -582,7 +585,7 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); let mut queries = QueryRecord::new(&toplevel); let args = &[f(0), f(0), f(0), f(123), f(0), f(0), f(0), f(0)]; @@ -603,6 +606,6 @@ mod test { let (pk, _vk) = machine.setup(&LairMachineProgram); let shard = Shard::new(&queries); - machine.debug_constraints(&pk, shard.clone()); + machine.debug_constraints(&pk, vec![shard], &mut machine.config().challenger()); } } diff --git a/src/poseidon/config.rs b/src/poseidon/config.rs index 79bb1387..ceeae223 100644 --- a/src/poseidon/config.rs +++ b/src/poseidon/config.rs @@ -17,7 +17,9 @@ trait ConstantsProvided {} /// The Poseidon configuration trait storing the data needed for #[allow(non_camel_case_types, private_bounds)] -pub trait PoseidonConfig: Clone + Copy + Sync + ConstantsProvided { +pub trait PoseidonConfig: + Clone + Copy + Send + Sync + ConstantsProvided +{ type F: PrimeField; type R_P: ArraySize + Sub; type R_F: ArraySize; diff --git a/tests/fib.rs b/tests/fib.rs index 58b1aaa9..2c8adf7a 100644 --- a/tests/fib.rs +++ b/tests/fib.rs @@ -6,7 +6,7 @@ use p3_baby_bear::BabyBear; use p3_field::AbstractField; use sphinx_core::{ - stark::{LocalProver, StarkGenericConfig, StarkMachine}, + stark::{DefaultProver, MachineProver, StarkGenericConfig, StarkMachine}, utils::{BabyBearPoseidon2, SphinxCoreOpts}, }; use std::time::Instant; @@ -50,7 +50,7 @@ fn setup>( toplevel: &Toplevel, ) -> ( List, - FuncChip<'_, BabyBear, C, NoChip>, + FuncChip, QueryRecord, ) { let code = build_lurk_expr(arg); @@ -91,7 +91,10 @@ fn fib_e2e() { let mut challenger_p = machine.config().challenger(); let opts = SphinxCoreOpts::default(); let shard = Shard::new(&record); - machine.prove::>(&pk, shard, &mut challenger_p, opts); + let prover = DefaultProver::new(machine); + prover + .prove(&pk, vec![shard], &mut challenger_p, opts) + .unwrap(); let elapsed_time = start_time.elapsed().as_secs_f32(); println!("Total time for e2e-{arg} = {:.2} s", elapsed_time);