Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LessThan circuits and padding #2

Merged
merged 8 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,8 @@ This project may contain trademarks or logos for projects, products, or services
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.
Any use of third-party trademarks or logos are subject to those third-party's policies.

## Examples

Run `cargo run --example EXAMPLE_NAME` to run the corresponding example. Leave `EXAMPLE_NAME` empty for a list of available examples.
236 changes: 236 additions & 0 deletions examples/less_than.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
use bellpepper_core::{
boolean::AllocatedBit, num::AllocatedNum, Circuit, ConstraintSystem, LinearCombination,
SynthesisError,
};
use ff::{PrimeField, PrimeFieldBits};
use pasta_curves::Fq;
use spartan2::{
errors::SpartanError,
traits::{snark::RelaxedR1CSSNARKTrait, Group},
SNARK,
};

fn num_to_bits_le_bounded<F: PrimeField + PrimeFieldBits, CS: ConstraintSystem<F>>(
cs: &mut CS,
n: AllocatedNum<F>,
num_bits: u8,
) -> Result<Vec<AllocatedBit>, SynthesisError> {
let opt_bits = match n.get_value() {
Some(v) => v
.to_le_bits()
.into_iter()
.take(num_bits as usize)
.map(Some)
.collect::<Vec<Option<bool>>>(),
None => vec![None; num_bits as usize],
};

// Add one witness per input bit in little-endian bit order
let bits_circuit = opt_bits.into_iter()
.enumerate()
// AllocateBit enforces the value to be 0 or 1 at the constraint level
.map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("b_{}", i)), b).unwrap())
.collect::<Vec<AllocatedBit>>();

let mut weighted_sum_lc = LinearCombination::zero();
let mut pow2 = F::ONE;

for bit in bits_circuit.iter() {
weighted_sum_lc = weighted_sum_lc + (pow2, bit.get_variable());
pow2 = pow2.double();
}

cs.enforce(
|| "bit decomposition check",
|lc| lc + &weighted_sum_lc,
|lc| lc + CS::one(),
|lc| lc + n.get_variable(),
);

Ok(bits_circuit)
}

fn get_msb_index<F: PrimeField + PrimeFieldBits>(n: F) -> u8 {
n.to_le_bits()
.into_iter()
.enumerate()
.rev()
.find(|(_, b)| *b)
.unwrap()
.0 as u8
}

// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a
// constant. The bound must fit into `num_bits` bits (this is asserted in the
// circuit constructor).
// Important: it must be checked elsewhere (in an overarching circuit) that the
// input fits into `num_bits` bits - this is NOT constrained by this circuit
// in order to avoid duplications (hence "unsafe"). Cf. LessThanCircuitSafe for
// a safe version.
#[derive(Clone, Debug)]
struct LessThanCircuitUnsafe<F: PrimeField> {
bound: F, // Will be a constant in the constraits, not a variable
mmagician marked this conversation as resolved.
Show resolved Hide resolved
input: F, // Will be an input/output variable
num_bits: u8,
}

impl<F: PrimeField + PrimeFieldBits> LessThanCircuitUnsafe<F> {
fn new(bound: F, input: F, num_bits: u8) -> Self {
assert!(get_msb_index(bound) < num_bits);
Self {
bound,
input,
num_bits,
}
}
}

impl<F: PrimeField + PrimeFieldBits> Circuit<F> for LessThanCircuitUnsafe<F> {
fn synthesize<CS: ConstraintSystem<F>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
assert!(F::NUM_BITS > self.num_bits as u32 + 1);

let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?;

let shifted_diff = AllocatedNum::alloc(cs.namespace(|| "shifted_diff"), || {
Ok(self.input + F::from(1 << self.num_bits) - self.bound)
})?;

cs.enforce(
|| "shifted_diff_computation",
|lc| lc + input.get_variable() + (F::from(1 << self.num_bits) - self.bound, CS::one()),
|lc: LinearCombination<F>| lc + CS::one(),
|lc| lc + shifted_diff.get_variable(),
);

let shifted_diff_bits = num_to_bits_le_bounded::<F, CS>(cs, shifted_diff, self.num_bits + 1)?;

// Check that the last (i.e. most sifnificant) bit is 0
cs.enforce(
|| "bound_check",
|lc| lc + shifted_diff_bits[self.num_bits as usize].get_variable(),
|lc| lc + CS::one(),
|lc| lc + (F::ZERO, CS::one()),
);

Ok(())
}
}

// Constrains `input` < `bound`, where the LHS is a witness and the RHS is a
// constant. The bound must fit into `num_bits` bits (this is asserted in the
// circuit constructor).
// Furthermore, the input must fit into `num_bits`, which is enforced at the
// constraint level.
#[derive(Clone, Debug)]
struct LessThanCircuitSafe<F: PrimeField + PrimeFieldBits> {
bound: F,
input: F,
num_bits: u8,
}

impl<F: PrimeField + PrimeFieldBits> LessThanCircuitSafe<F> {
fn new(bound: F, input: F, num_bits: u8) -> Self {
assert!(get_msb_index(bound) < num_bits);
Self {
bound,
input,
num_bits,
}
}
}

impl<F: PrimeField + PrimeFieldBits> Circuit<F> for LessThanCircuitSafe<F> {
fn synthesize<CS: ConstraintSystem<F>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
let input = AllocatedNum::alloc(cs.namespace(|| "input"), || Ok(self.input))?;

// Perform the input bit decomposition check
num_to_bits_le_bounded::<F, CS>(cs, input, self.num_bits)?;

// Entering a new namespace to prefix variables in the
// LessThanCircuitUnsafe, thus avoiding name clashes
cs.push_namespace(|| "less_than_safe");

LessThanCircuitUnsafe {
bound: self.bound,
input: self.input,
num_bits: self.num_bits,
}
.synthesize(cs)
}
}

fn verify_circuit_unsafe<G: Group, S: RelaxedR1CSSNARKTrait<G>>(
bound: G::Scalar,
input: G::Scalar,
num_bits: u8,
) -> Result<(), SpartanError> {
let circuit = LessThanCircuitUnsafe::new(bound, input, num_bits);

// produce keys
let (pk, vk) = SNARK::<G, S, LessThanCircuitUnsafe<_>>::setup(circuit.clone()).unwrap();

// produce a SNARK
let snark = SNARK::prove(&pk, circuit).unwrap();

// verify the SNARK
snark.verify(&vk, &[])
}

fn verify_circuit_safe<G: Group, S: RelaxedR1CSSNARKTrait<G>>(
bound: G::Scalar,
input: G::Scalar,
num_bits: u8,
) -> Result<(), SpartanError> {
let circuit = LessThanCircuitSafe::new(bound, input, num_bits);

// produce keys
let (pk, vk) = SNARK::<G, S, LessThanCircuitSafe<_>>::setup(circuit.clone()).unwrap();

// produce a SNARK
let snark = SNARK::prove(&pk, circuit).unwrap();

// verify the SNARK
snark.verify(&vk, &[])
}

fn main() {
type G = pasta_curves::pallas::Point;
type EE = spartan2::provider::ipa_pc::EvaluationEngine<G>;
type S = spartan2::spartan::snark::RelaxedR1CSSNARK<G, EE>;

println!("Executing unsafe circuit...");
//Typical example, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(17), Fq::from(9), 10).is_ok());
// Typical example, err
assert!(verify_circuit_unsafe::<G, S>(Fq::from(17), Fq::from(20), 10).is_err());
// Edge case, err
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(4), 10).is_err());
// Edge case, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(3), 10).is_ok());
// Minimum number of bits for the bound, ok
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), Fq::from(3), 3).is_ok());
// Insufficient number of bits for the input, but this is not detected by the
// unsafety of the circuit (compare with the last example below)
// Note that -Fq::one() is corresponds to q - 1 > bound
assert!(verify_circuit_unsafe::<G, S>(Fq::from(4), -Fq::one(), 3).is_ok());

println!("Unsafe circuit OK");

println!("Executing safe circuit...");
// Typical example, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(17), Fq::from(9), 10).is_ok());
// Typical example, err
assert!(verify_circuit_safe::<G, S>(Fq::from(17), Fq::from(20), 10).is_err());
// Edge case, err
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(4), 10).is_err());
// Edge case, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(3), 10).is_ok());
// Minimum number of bits for the bound, ok
assert!(verify_circuit_safe::<G, S>(Fq::from(4), Fq::from(3), 3).is_ok());
// Insufficient number of bits for the input, err (compare with the last example
// above).
// Note that -Fq::one() is corresponds to q - 1 > bound
assert!(verify_circuit_safe::<G, S>(Fq::from(4), -Fq::one(), 3).is_err());

println!("Safe circuit OK");
}
11 changes: 4 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl<G: Group, S: RelaxedR1CSSNARKTrait<G>, C: Circuit<G::Scalar>> SNARK<G, S, C
/// Produces prover and verifier keys for the direct SNARK
pub fn setup(circuit: C) -> Result<(ProverKey<G, S>, VerifierKey<G, S>), SpartanError> {
let (pk, vk) = S::setup(circuit)?;

Ok((ProverKey { pk }, VerifierKey { vk }))
}

Expand Down Expand Up @@ -108,15 +109,12 @@ mod tests {
use super::*;
use crate::provider::{bn256_grumpkin::bn256, secp_secq::secp256k1};
use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError};
use core::marker::PhantomData;
use ff::PrimeField;

#[derive(Clone, Debug, Default)]
struct CubicCircuit<F: PrimeField> {
_p: PhantomData<F>,
}
struct CubicCircuit {}

impl<F> Circuit<F> for CubicCircuit<F>
impl<F> Circuit<F> for CubicCircuit
where
F: PrimeField,
{
Expand Down Expand Up @@ -178,8 +176,7 @@ mod tests {
let circuit = CubicCircuit::default();

// produce keys
let (pk, vk) =
SNARK::<G, S, CubicCircuit<<G as Group>::Scalar>>::setup(circuit.clone()).unwrap();
let (pk, vk) = SNARK::<G, S, CubicCircuit>::setup(circuit.clone()).unwrap();
mmagician marked this conversation as resolved.
Show resolved Hide resolved

// produce a SNARK
let res = SNARK::prove(&pk, circuit);
Expand Down
28 changes: 28 additions & 0 deletions src/spartan/snark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,26 @@ impl<G: Group, EE: EvaluationEngineTrait<G>> RelaxedR1CSSNARKTrait<G> for Relaxe
) -> Result<(Self::ProverKey, Self::VerifierKey), SpartanError> {
let mut cs: ShapeCS<G> = ShapeCS::new();
let _ = circuit.synthesize(&mut cs);

// Padding the ShapeCS: constraints (rows) and variables (columns)
let num_constraints = cs.num_constraints();

(num_constraints..num_constraints.next_power_of_two()).for_each(|i| {
cs.enforce(
|| format!("padding_constraint_{i}"),
|lc| lc,
|lc| lc,
|lc| lc,
)
});
mmagician marked this conversation as resolved.
Show resolved Hide resolved

let num_vars = cs.num_aux();

(num_vars..num_vars.next_power_of_two()).for_each(|i| {
cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO))
.unwrap();
});

let (S, ck) = cs.r1cs_shape();

let (pk_ee, vk_ee) = EE::setup(&ck);
Expand All @@ -121,6 +141,14 @@ impl<G: Group, EE: EvaluationEngineTrait<G>> RelaxedR1CSSNARKTrait<G> for Relaxe
let mut cs: SatisfyingAssignment<G> = SatisfyingAssignment::new();
let _ = circuit.synthesize(&mut cs);

// Padding variables
let num_vars = cs.aux_slice().len();

(num_vars..num_vars.next_power_of_two()).for_each(|i| {
cs.alloc(|| format!("padding_var_{i}"), || Ok(G::Scalar::ZERO))
.unwrap();
});

mmagician marked this conversation as resolved.
Show resolved Hide resolved
let (u, w) = cs
.r1cs_instance_and_witness(&pk.S, &pk.ck)
.map_err(|_e| SpartanError::UnSat)?;
Expand Down
Loading