diff --git a/README.md b/README.md index 2986250..66b57a7 100644 --- a/README.md +++ b/README.md @@ -31,4 +31,8 @@ This project may contain trademarks or logos for projects, products, or services trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. -Any use of third-party trademarks or logos are subject to those third-party's policies. \ No newline at end of file +Any use of third-party trademarks or logos are subject to those third-party's policies. + +## Examples + +Run `cargo run --example EXAMPLE_NAME` to run the corresponding example. Leave `EXAMPLE_NAME` empty for a list of available examples. diff --git a/examples/less_than.rs b/examples/less_than.rs new file mode 100644 index 0000000..a3bac59 --- /dev/null +++ b/examples/less_than.rs @@ -0,0 +1,236 @@ +use bellpepper_core::{ + boolean::AllocatedBit, num::AllocatedNum, Circuit, ConstraintSystem, LinearCombination, + SynthesisError, +}; +use ff::{PrimeField, PrimeFieldBits}; +use pasta_curves::Fq; +use spartan2::{ + errors::SpartanError, + traits::{snark::RelaxedR1CSSNARKTrait, Group}, + SNARK, +}; + +fn num_to_bits_le_bounded>( + cs: &mut CS, + n: AllocatedNum, + num_bits: u8, +) -> Result, SynthesisError> { + let opt_bits = match n.get_value() { + Some(v) => v + .to_le_bits() + .into_iter() + .take(num_bits as usize) + .map(Some) + .collect::>>(), + None => vec![None; num_bits as usize], + }; + + // Add one witness per input bit in little-endian bit order + let bits_circuit = opt_bits.into_iter() + .enumerate() + // AllocateBit enforces the value to be 0 or 1 at the constraint level + .map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("b_{}", i)), b).unwrap()) + .collect::>(); + + let mut weighted_sum_lc = LinearCombination::zero(); + let mut pow2 = F::ONE; + + for bit in bits_circuit.iter() { + weighted_sum_lc = weighted_sum_lc + (pow2, bit.get_variable()); + pow2 = pow2.double(); + } + + cs.enforce( + || "bit decomposition check", + |lc| lc + &weighted_sum_lc, + |lc| lc + CS::one(), + |lc| lc + n.get_variable(), + ); + + Ok(bits_circuit) +} + +fn get_msb_index(n: F) -> u8 { + n.to_le_bits() + .into_iter() + .enumerate() + .rev() + .find(|(_, b)| *b) + .unwrap() + .0 as u8 +} + +// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a +// constant. The bound must fit into `num_bits` bits (this is asserted in the +// circuit constructor). +// Important: it must be checked elsewhere (in an overarching circuit) that the +// input fits into `num_bits` bits - this is NOT constrained by this circuit +// in order to avoid duplications (hence "unsafe"). Cf. LessThanCircuitSafe for +// a safe version. +#[derive(Clone, Debug)] +struct LessThanCircuitUnsafe { + bound: F, // Will be a constant in the constraits, not a variable + input: F, // Will be an input/output variable + num_bits: u8, +} + +impl LessThanCircuitUnsafe { + fn new(bound: F, input: F, num_bits: u8) -> Self { + assert!(get_msb_index(bound) < num_bits); + Self { + bound, + input, + num_bits, + } + } +} + +impl Circuit for LessThanCircuitUnsafe { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + assert!(F::NUM_BITS > self.num_bits as u32 + 1); + + let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?; + + let shifted_diff = AllocatedNum::alloc(cs.namespace(|| "shifted_diff"), || { + Ok(self.input + F::from(1 << self.num_bits) - self.bound) + })?; + + cs.enforce( + || "shifted_diff_computation", + |lc| lc + input.get_variable() + (F::from(1 << self.num_bits) - self.bound, CS::one()), + |lc: LinearCombination| lc + CS::one(), + |lc| lc + shifted_diff.get_variable(), + ); + + let shifted_diff_bits = num_to_bits_le_bounded::(cs, shifted_diff, self.num_bits + 1)?; + + // Check that the last (i.e. most sifnificant) bit is 0 + cs.enforce( + || "bound_check", + |lc| lc + shifted_diff_bits[self.num_bits as usize].get_variable(), + |lc| lc + CS::one(), + |lc| lc + (F::ZERO, CS::one()), + ); + + Ok(()) + } +} + +// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a +// constant. The bound must fit into `num_bits` bits (this is asserted in the +// circuit constructor). +// Furthermore, the input must fit into `num_bits`, which is enforced at the +// constraint level. +#[derive(Clone, Debug)] +struct LessThanCircuitSafe { + bound: F, + input: F, + num_bits: u8, +} + +impl LessThanCircuitSafe { + fn new(bound: F, input: F, num_bits: u8) -> Self { + assert!(get_msb_index(bound) < num_bits); + Self { + bound, + input, + num_bits, + } + } +} + +impl Circuit for LessThanCircuitSafe { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?; + + // Perform the input bit decomposition check + num_to_bits_le_bounded::(cs, input, self.num_bits)?; + + // Entering a new namespace to prefix variables in the + // LessThanCircuitUnsafe, thus avoiding name clashes + cs.push_namespace(|| "less_than_safe"); + + LessThanCircuitUnsafe { + bound: self.bound, + input: self.input, + num_bits: self.num_bits, + } + .synthesize(cs) + } +} + +fn verify_circuit_unsafe>( + bound: G::Scalar, + input: G::Scalar, + num_bits: u8, +) -> Result<(), SpartanError> { + let circuit = LessThanCircuitUnsafe::new(bound, input, num_bits); + + // produce keys + let (pk, vk) = SNARK::>::setup(circuit.clone()).unwrap(); + + // produce a SNARK + let snark = SNARK::prove(&pk, circuit).unwrap(); + + // verify the SNARK + snark.verify(&vk, &[]) +} + +fn verify_circuit_safe>( + bound: G::Scalar, + input: G::Scalar, + num_bits: u8, +) -> Result<(), SpartanError> { + let circuit = LessThanCircuitSafe::new(bound, input, num_bits); + + // produce keys + let (pk, vk) = SNARK::>::setup(circuit.clone()).unwrap(); + + // produce a SNARK + let snark = SNARK::prove(&pk, circuit).unwrap(); + + // verify the SNARK + snark.verify(&vk, &[]) +} + +fn main() { + type G = pasta_curves::pallas::Point; + type EE = spartan2::provider::ipa_pc::EvaluationEngine; + type S = spartan2::spartan::snark::RelaxedR1CSSNARK; + + println!("Executing unsafe circuit..."); + //Typical example, ok + assert!(verify_circuit_unsafe::(Fq::from(17), Fq::from(9), 10).is_ok()); + // Typical example, err + assert!(verify_circuit_unsafe::(Fq::from(17), Fq::from(20), 10).is_err()); + // Edge case, err + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(4), 10).is_err()); + // Edge case, ok + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(3), 10).is_ok()); + // Minimum number of bits for the bound, ok + assert!(verify_circuit_unsafe::(Fq::from(4), Fq::from(3), 3).is_ok()); + // Insufficient number of bits for the input, but this is not detected by the + // unsafety of the circuit (compare with the last example below) + // Note that -Fq::one() is corresponds to q - 1 > bound + assert!(verify_circuit_unsafe::(Fq::from(4), -Fq::one(), 3).is_ok()); + + println!("Unsafe circuit OK"); + + println!("Executing safe circuit..."); + // Typical example, ok + assert!(verify_circuit_safe::(Fq::from(17), Fq::from(9), 10).is_ok()); + // Typical example, err + assert!(verify_circuit_safe::(Fq::from(17), Fq::from(20), 10).is_err()); + // Edge case, err + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(4), 10).is_err()); + // Edge case, ok + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(3), 10).is_ok()); + // Minimum number of bits for the bound, ok + assert!(verify_circuit_safe::(Fq::from(4), Fq::from(3), 3).is_ok()); + // Insufficient number of bits for the input, err (compare with the last example + // above). + // Note that -Fq::one() is corresponds to q - 1 > bound + assert!(verify_circuit_safe::(Fq::from(4), -Fq::one(), 3).is_err()); + + println!("Safe circuit OK"); +} diff --git a/src/lib.rs b/src/lib.rs index 5a60502..b55f0f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ impl, C: Circuit> SNARK Result<(ProverKey, VerifierKey), SpartanError> { let (pk, vk) = S::setup(circuit)?; + Ok((ProverKey { pk }, VerifierKey { vk })) } @@ -108,15 +109,12 @@ mod tests { use super::*; use crate::provider::{bn256_grumpkin::bn256, secp_secq::secp256k1}; use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; - use core::marker::PhantomData; use ff::PrimeField; #[derive(Clone, Debug, Default)] - struct CubicCircuit { - _p: PhantomData, - } + struct CubicCircuit {} - impl Circuit for CubicCircuit + impl Circuit for CubicCircuit where F: PrimeField, { @@ -178,8 +176,7 @@ mod tests { let circuit = CubicCircuit::default(); // produce keys - let (pk, vk) = - SNARK::::Scalar>>::setup(circuit.clone()).unwrap(); + let (pk, vk) = SNARK::::setup(circuit.clone()).unwrap(); // produce a SNARK let res = SNARK::prove(&pk, circuit); diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index a1ad696..8520650 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -101,6 +101,26 @@ impl> RelaxedR1CSSNARKTrait for Relaxe ) -> Result<(Self::ProverKey, Self::VerifierKey), SpartanError> { let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit.synthesize(&mut cs); + + // Padding the ShapeCS: constraints (rows) and variables (columns) + let num_constraints = cs.num_constraints(); + + (num_constraints..num_constraints.next_power_of_two()).for_each(|i| { + cs.enforce( + || format!("padding_constraint_{i}"), + |lc| lc, + |lc| lc, + |lc| lc, + ) + }); + + let num_vars = cs.num_aux(); + + (num_vars..num_vars.next_power_of_two()).for_each(|i| { + cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO)) + .unwrap(); + }); + let (S, ck) = cs.r1cs_shape(); let (pk_ee, vk_ee) = EE::setup(&ck); @@ -121,6 +141,14 @@ impl> RelaxedR1CSSNARKTrait for Relaxe let mut cs: SatisfyingAssignment = SatisfyingAssignment::new(); let _ = circuit.synthesize(&mut cs); + // Padding variables + let num_vars = cs.aux_slice().len(); + + (num_vars..num_vars.next_power_of_two()).for_each(|i| { + cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO)) + .unwrap(); + }); + let (u, w) = cs .r1cs_instance_and_witness(&pk.S, &pk.ck) .map_err(|_e| SpartanError::UnSat)?;