From a46c994897cc81b60828813e1ac3e775e278e244 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Sun, 15 Sep 2024 19:10:52 -1000 Subject: [PATCH] Add arithmetic op counts to InfoEvaluator --- .../prover/src/constraint_framework/info.rs | 338 +++++++++++++++++- .../prover/src/constraint_framework/logup.rs | 5 +- crates/prover/src/core/circle.rs | 2 +- crates/prover/src/core/lookups/utils.rs | 14 +- crates/prover/src/examples/blake/mod.rs | 3 +- .../src/examples/blake/round/constraints.rs | 6 +- .../examples/blake/scheduler/constraints.rs | 12 +- crates/prover/src/examples/plonk/mod.rs | 4 +- crates/prover/src/examples/poseidon/mod.rs | 2 +- 9 files changed, 350 insertions(+), 36 deletions(-) diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index 05da93f6f..d2f42091c 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -1,18 +1,24 @@ -use std::ops::Mul; +use std::array; +use std::cell::{RefCell, RefMut}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; +use std::rc::Rc; -use num_traits::One; +use num_traits::{One, Zero}; use super::EvalAtRow; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; use crate::core::pcs::TreeVec; /// Collects information about the constraints. -/// This includes mask offsets and columns at each interaction, and the number of constraints. +/// This includes mask offsets and columns at each interaction, the number of constraints and number +/// of arithmetic operations. #[derive(Default)] pub struct InfoEvaluator { pub mask_offsets: TreeVec>>, pub n_constraints: usize, + pub arithmetic_counts: ArithmeticCounts, } impl InfoEvaluator { pub fn new() -> Self { @@ -20,8 +26,8 @@ impl InfoEvaluator { } } impl EvalAtRow for InfoEvaluator { - type F = BaseField; - type EF = SecureField; + type F = BaseFieldCounter; + type EF = SecureFieldCounter; fn next_interaction_mask( &mut self, interaction: usize, @@ -33,16 +39,330 @@ impl EvalAtRow for InfoEvaluator { self.mask_offsets.resize(interaction + 1, vec![]); } self.mask_offsets[interaction].push(offsets.into_iter().collect()); - [BaseField::one(); N] + array::from_fn(|_| BaseFieldCounter::one()) } - fn add_constraint(&mut self, _constraint: G) + fn add_constraint(&mut self, constraint: G) where Self::EF: Mul, { + let lin_combination = SecureFieldCounter::one() + SecureFieldCounter::one() * constraint; + self.arithmetic_counts.merge(lin_combination.drain()); self.n_constraints += 1; } - fn combine_ef(_values: [Self::F; 4]) -> Self::EF { - SecureField::one() + fn combine_ef(values: [Self::F; 4]) -> Self::EF { + let mut res = SecureFieldCounter::zero(); + values.map(|v| res.merge(v)); + res + } +} + +/// Stores a count of field operations. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct ArithmeticCounts { + /// Number of `ExtensionField * ExtensionField` operations. + pub n_ef_mul_ef: usize, + /// Number of `ExtensionField * BaseField` operations. + pub n_ef_mul_f: usize, + /// Number of `ExtensionField + ExtensionField` operations. + pub n_ef_add_ef: usize, + /// Number of `ExtensionField + BaseField` operations. + pub n_ef_add_f: usize, + /// Number of `BaseField * BaseField` operations. + pub n_f_mul_f: usize, + /// Number of `BaseField + BaseField` operations. + pub n_f_add_f: usize, +} + +impl ArithmeticCounts { + fn merge(&mut self, other: ArithmeticCounts) { + self.n_ef_mul_ef += other.n_ef_mul_ef; + self.n_ef_mul_f += other.n_ef_mul_f; + self.n_ef_add_f += other.n_ef_add_f; + self.n_ef_add_ef += other.n_ef_add_ef; + self.n_f_mul_f += other.n_f_mul_f; + self.n_f_add_f += other.n_f_add_f; + } +} + +#[derive(Debug, Default, Clone)] +pub struct ArithmeticCounter(Rc>); + +pub type BaseFieldCounter = ArithmeticCounter; + +pub type SecureFieldCounter = ArithmeticCounter; + +impl ArithmeticCounter { + fn merge( + &mut self, + other: ArithmeticCounter, + ) { + // Skip if they come from the same source. + if Rc::ptr_eq(&self.0, &other.0) { + return; + } + + self.counts().merge(other.drain()); + } + + fn drain(self) -> ArithmeticCounts { + self.0.take() + } + + fn counts(&mut self) -> RefMut<'_, ArithmeticCounts> { + self.0.borrow_mut() + } +} + +impl Zero for ArithmeticCounter { + fn zero() -> Self { + Self::default() + } + + fn is_zero(&self) -> bool { + // TODO(andrew): Consider removing Zero from EvalAtRow::F, EvalAtRow::EF since is_zero + // doesn't make sense. Creating zero elements does though. + panic!() + } +} + +impl One for ArithmeticCounter { + fn one() -> Self { + Self::default() + } +} + +impl Add for ArithmeticCounter { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self { + self.merge(rhs); + { + let mut counts = self.counts(); + match IS_EXT_FIELD { + true => counts.n_ef_add_ef += 1, + false => counts.n_f_add_f += 1, + } + } + self + } +} + +impl Sub for ArithmeticCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: Self) -> Self { + // Treat as addition. + self + rhs + } +} + +impl Add for SecureFieldCounter { + type Output = Self; + + fn add(mut self, rhs: BaseFieldCounter) -> Self { + self.merge(rhs); + self.counts().n_ef_add_f += 1; + self + } +} + +impl Mul for ArithmeticCounter { + type Output = Self; + + fn mul(mut self, rhs: Self) -> Self { + self.merge(rhs); + { + let mut counts = self.counts(); + match IS_EXT_FIELD { + true => counts.n_ef_mul_ef += 1, + false => counts.n_f_mul_f += 1, + } + } + self + } +} + +impl Mul for SecureFieldCounter { + type Output = SecureFieldCounter; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(mut self, rhs: BaseFieldCounter) -> Self { + self.merge(rhs); + self.counts().n_ef_mul_f += 1; + self + } +} + +impl MulAssign for ArithmeticCounter { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs + } +} + +impl AddAssign for ArithmeticCounter { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } +} + +impl AddAssign for BaseFieldCounter { + fn add_assign(&mut self, _rhs: BaseField) { + *self = self.clone() + BaseFieldCounter::zero() + } +} + +impl Mul for BaseFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(self, _rhs: BaseField) -> Self { + self * BaseFieldCounter::zero() + } +} + +impl Mul for SecureFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(self, _rhs: SecureField) -> Self { + self * SecureFieldCounter::zero() + } +} + +impl Add for BaseFieldCounter { + type Output = SecureFieldCounter; + + fn add(self, _rhs: SecureField) -> SecureFieldCounter { + SecureFieldCounter::zero() + self + } +} + +impl Add for SecureFieldCounter { + type Output = Self; + + fn add(self, _rhs: SecureField) -> Self { + self + SecureFieldCounter::zero() + } +} + +impl Sub for SecureFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: SecureField) -> Self { + // Tread subtraction as addition + self + rhs + } +} + +impl Mul for BaseFieldCounter { + type Output = SecureFieldCounter; + + fn mul(self, _rhs: SecureField) -> SecureFieldCounter { + SecureFieldCounter::zero() * self + } +} + +impl From for BaseFieldCounter { + fn from(_value: BaseField) -> Self { + Self::one() + } +} + +impl From for SecureFieldCounter { + fn from(_value: SecureField) -> Self { + Self::one() + } +} + +impl From for SecureFieldCounter { + fn from(value: BaseFieldCounter) -> Self { + Self(value.0) + } +} + +impl Neg for ArithmeticCounter { + type Output = Self; + + fn neg(self) -> Self { + // Treat as addition. + self + ArithmeticCounter::::zero() + } +} + +impl FieldExpOps for ArithmeticCounter { + fn inverse(&self) -> Self { + todo!() + } +} + +#[cfg(test)] +mod tests { + use num_traits::{One, Zero}; + + use super::SecureFieldCounter; + use crate::constraint_framework::info::{ArithmeticCounts, BaseFieldCounter}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + + #[test] + fn test_arithmetic_counter() { + const N_EF_MUL_EF: usize = 1; + const N_EF_MUL_F: usize = 2; + const N_EF_MUL_ASSIGN_EF: usize = 1; + const N_EF_MUL_SECURE_FIELD: usize = 3; + const N_EF_ADD_EF: usize = 4; + const N_EF_ADD_ASSIGN_EF: usize = 4; + const N_EF_ADD_F: usize = 5; + const N_EF_NEG: usize = 6; + const N_EF_SUB_EF: usize = 7; + const N_F_MUL_F: usize = 8; + const N_F_MUL_ASSIGN_F: usize = 8; + const N_F_MUL_BASE_FIELD: usize = 9; + const N_F_ADD_F: usize = 10; + const N_F_ADD_ASSIGN_F: usize = 4; + const N_F_ADD_ASSIGN_BASE_FIELD: usize = 4; + const N_F_NEG: usize = 11; + const N_F_SUB_F: usize = 12; + let mut ef = SecureFieldCounter::zero(); + let mut f = BaseFieldCounter::zero(); + + (0..N_EF_MUL_EF).for_each(|_| ef = ef.clone() * ef.clone()); + (0..N_EF_MUL_F).for_each(|_| ef = ef.clone() * f.clone()); + (0..N_EF_MUL_SECURE_FIELD).for_each(|_| ef = ef.clone() * SecureField::one()); + (0..N_EF_MUL_ASSIGN_EF).for_each(|_| ef *= ef.clone()); + (0..N_EF_ADD_EF).for_each(|_| ef = ef.clone() + ef.clone()); + (0..N_EF_ADD_ASSIGN_EF).for_each(|_| ef += ef.clone()); + (0..N_EF_ADD_F).for_each(|_| ef = ef.clone() + f.clone()); + (0..N_EF_NEG).for_each(|_| ef = -ef.clone()); + (0..N_EF_SUB_EF).for_each(|_| ef = ef.clone() - ef.clone()); + (0..N_F_MUL_F).for_each(|_| f = f.clone() * f.clone()); + (0..N_F_MUL_ASSIGN_F).for_each(|_| f *= f.clone()); + (0..N_F_MUL_BASE_FIELD).for_each(|_| f = f.clone() * BaseField::one()); + (0..N_F_ADD_F).for_each(|_| f = f.clone() + f.clone()); + (0..N_F_ADD_ASSIGN_F).for_each(|_| f += f.clone()); + (0..N_F_ADD_ASSIGN_BASE_FIELD).for_each(|_| f += BaseField::one()); + (0..N_F_NEG).for_each(|_| f = -f.clone()); + (0..N_F_SUB_F).for_each(|_| f = f.clone() - f.clone()); + let mut res = f.drain(); + res.merge(ef.drain()); + + assert_eq!( + res, + ArithmeticCounts { + n_ef_mul_ef: N_EF_MUL_EF + N_EF_MUL_SECURE_FIELD + N_EF_MUL_ASSIGN_EF, + n_ef_mul_f: N_EF_MUL_F, + n_ef_add_ef: N_EF_ADD_EF + N_EF_NEG + N_EF_SUB_EF + N_EF_ADD_ASSIGN_EF, + n_ef_add_f: N_EF_ADD_F, + n_f_mul_f: N_F_MUL_F + N_F_MUL_BASE_FIELD + N_F_MUL_ASSIGN_F, + n_f_add_f: N_F_ADD_F + + N_F_NEG + + N_F_SUB_F + + N_F_ADD_ASSIGN_BASE_FIELD + + N_F_ADD_ASSIGN_F, + } + ); } } diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 387a8600f..7fec3a737 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -308,11 +308,10 @@ mod tests { #[test] #[should_panic] fn test_logup_not_finalized_panic() { - let mut logup = - LogupAtRow::::new(1, SecureField::one(), None, BaseField::one()); + let mut logup = LogupAtRow::::new(1, SecureField::one(), None, One::one()); logup.write_frac( &mut InfoEvaluator::default(), - Fraction::new(SecureField::one(), SecureField::one()), + Fraction::new(SecureField::one().into(), SecureField::one().into()), ); } diff --git a/crates/prover/src/core/circle.rs b/crates/prover/src/core/circle.rs index 0fbdd64b4..8cfe48ab8 100644 --- a/crates/prover/src/core/circle.rs +++ b/crates/prover/src/core/circle.rs @@ -104,7 +104,7 @@ impl + FieldExpOps + Sub + Neg } } - pub fn into_ef>(&self) -> CirclePoint { + pub fn into_ef>(self) -> CirclePoint { CirclePoint { x: self.x.clone().into(), y: self.y.clone().into(), diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index 035579e5f..ed67477f7 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -195,14 +195,12 @@ where } /// Projective fraction. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct Fraction { pub numerator: N, pub denominator: D, } -impl Copy for Fraction {} - impl Fraction { pub fn new(numerator: N, denominator: D) -> Self { Self { @@ -212,17 +210,15 @@ impl Fraction { } } -impl< - N: Clone, - D: Add + Add + Mul + Mul + Clone, - > Add for Fraction +impl + Add + Mul + Mul + Clone> Add + for Fraction { type Output = Fraction; fn add(self, rhs: Self) -> Fraction { Fraction { - numerator: rhs.denominator.clone() * self.numerator.clone() - + self.denominator.clone() * rhs.numerator.clone(), + numerator: rhs.denominator.clone() * self.numerator + + self.denominator.clone() * rhs.numerator, denominator: self.denominator * rhs.denominator, } } diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index 6bc504e61..3e4ad042e 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -112,8 +112,7 @@ where + Sub + Mul, { - #[allow(clippy::wrong_self_convention)] - fn to_felts(self) -> [F; 2] { + fn into_felts(self) -> [F; 2] { [self.l, self.h] } } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index b5e415d14..0a12fea79 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -71,9 +71,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { -E::EF::one(), self.round_lookup_elements.combine( &chain![ - input_v.iter().cloned().flat_map(Fu32::to_felts), - v.iter().cloned().flat_map(Fu32::to_felts), - m.iter().cloned().flat_map(Fu32::to_felts) + input_v.iter().cloned().flat_map(Fu32::into_felts), + v.iter().cloned().flat_map(Fu32::into_felts), + m.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), ), diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index ea0ad2c34..ddc557e90 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -28,9 +28,9 @@ pub fn eval_blake_scheduler_constraints( let round_messages = SIGMA[idx].map(|k| messages[k as usize].clone()); round_lookup_elements.combine::( &chain![ - input_state.iter().cloned().flat_map(Fu32::to_felts), - output_state.iter().cloned().flat_map(Fu32::to_felts), - round_messages.iter().cloned().flat_map(Fu32::to_felts) + input_state.iter().cloned().flat_map(Fu32::into_felts), + output_state.iter().cloned().flat_map(Fu32::into_felts), + round_messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), ) @@ -48,9 +48,9 @@ pub fn eval_blake_scheduler_constraints( E::EF::zero(), blake_lookup_elements.combine( &chain![ - input_state.iter().cloned().flat_map(Fu32::to_felts), - output_state.iter().cloned().flat_map(Fu32::to_felts), - messages.iter().cloned().flat_map(Fu32::to_felts) + input_state.iter().cloned().flat_map(Fu32::into_felts), + output_state.iter().cloned().flat_map(Fu32::into_felts), + messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), ), diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index db6502fc9..2832d0cca 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -17,7 +17,7 @@ use crate::core::backend::Column; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::lookups::utils::Fraction; +use crate::core::lookups::utils::{Fraction, Reciprocal}; use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; @@ -73,7 +73,7 @@ impl FrameworkEval for PlonkEval { logup.write_frac( &mut eval, - Fraction::new(denom_a.clone() + denom_b.clone(), denom_a * denom_b), + Reciprocal::new(denom_a) + Reciprocal::new(denom_b), ); logup.write_frac( &mut eval, diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 0cbc3ebbd..269214ca1 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -27,7 +27,7 @@ use crate::core::prover::{prove, StarkProof}; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; use crate::core::ColumnVec; -const N_LOG_INSTANCES_PER_ROW: usize = 3; +const N_LOG_INSTANCES_PER_ROW: usize = 0; const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; const N_STATE: usize = 16; const N_PARTIAL_ROUNDS: usize = 14;