From 339bc40e05cd0d4e0bc9628dbfa881acf151c0ab Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Thu, 17 Aug 2023 16:00:55 -0700 Subject: [PATCH] Refactor VarLenBytes Add VarLenBytesVec and FixLenBytes Fix tests --- halo2-base/Cargo.toml | 7 +- halo2-base/src/safe_types/bytes.rs | 84 +++++--- halo2-base/src/safe_types/mod.rs | 55 +++++- halo2-base/src/safe_types/tests/mod.rs | 4 +- .../src/safe_types/tests/var_byte_array.rs | 182 +++++++++++------- 5 files changed, 224 insertions(+), 108 deletions(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 93f0f21b..dc82b21d 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -14,6 +14,7 @@ rayon = "1.7" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" +getset = "0.1.2" # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } @@ -45,7 +46,11 @@ mimalloc = { version = "0.1", default-features = false, optional = true } [features] default = ["halo2-axiom", "display"] asm = ["halo2_proofs_axiom?/asm"] -dev-graph = ["halo2_proofs?/dev-graph", "halo2_proofs_axiom?/dev-graph", "plotters"] +dev-graph = [ + "halo2_proofs?/dev-graph", + "halo2_proofs_axiom?/dev-graph", + "plotters", +] halo2-pse = ["halo2_proofs/circuit-params"] halo2-axiom = ["halo2_proofs_axiom"] display = [] diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index 1c23477d..20ee3d65 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -1,51 +1,91 @@ +#![allow(clippy::len_without_is_empty)] use crate::AssignedValue; use super::{SafeByte, ScalarField}; +use getset::Getters; + /// Represents a variable length byte array in circuit. /// /// Each element is guaranteed to be a byte, given by type [`SafeByte`]. /// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is some additional context the user must provide. /// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Getters)] pub struct VarLenBytes { - /// The byte array, right padded with 0s - pub bytes: [SafeByte; MAX_LEN], + /// The byte array, right padded + #[getset(get = "pub")] + bytes: [SafeByte; MAX_LEN], /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` - pub var_len: AssignedValue, + #[getset(get = "pub")] + len: AssignedValue, } impl VarLenBytes { - fn new(bytes: [SafeByte; MAX_LEN], var_len: AssignedValue) -> Self { - Self { bytes, var_len } + // VarLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { + assert!( + len.value().le(&F::from_u128(MAX_LEN as u128)), + "Invalid length which exceeds MAX_LEN {}", + MAX_LEN + ); + Self { bytes, len } } + /// Returns the maximum length of the byte array. pub fn max_len(&self) -> usize { MAX_LEN } } -impl AsRef<[SafeByte]> for VarLenBytes { - fn as_ref(&self) -> &[SafeByte] { - &self.bytes - } -} - -impl AsMut<[SafeByte]> for VarLenBytes { - fn as_mut(&mut self) -> &mut [SafeByte] { - &mut self.bytes - } -} - -/// Represents a variable length byte array in circuit. +/// Represents a variable length byte array in circuit. Not encourged to use because `MAX_LEN` cannot be verified at compile time. /// /// Each element is guaranteed to be a byte, given by type [`SafeByte`]. /// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is some additional context the user must provide. /// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Getters)] pub struct VarLenBytesVec { - /// The byte array, right padded with 0s + /// The byte array, right padded + #[getset(get = "pub")] bytes: Vec>, /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` - pub var_len: AssignedValue, + #[getset(get = "pub")] + len: AssignedValue, +} + +impl VarLenBytesVec { + // VarLenBytesVec can be only created by SafeChip. + pub(super) fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { + assert!( + len.value().le(&F::from_u128(max_len as u128)), + "Invalid length which exceeds MAX_LEN {}", + max_len + ); + assert!(bytes.len() == max_len, "bytes is not padded correctly"); + Self { bytes, len } + } + + /// Returns the maximum length of the byte array. + pub fn max_len(&self) -> usize { + self.bytes.len() + } +} + +/// Represents a fixed length byte array in circuit. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytes { + /// The byte array + #[getset(get = "pub")] + bytes: [SafeByte; LEN], +} + +impl FixLenBytes { + // FixLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: [SafeByte; LEN]) -> Self { + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + LEN + } } diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index cd6df6b0..04f78183 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -3,6 +3,7 @@ pub use crate::{ flex_gate::GateInstructions, range::{RangeChip, RangeInstructions}, }, + safe_types::VarLenBytes, utils::ScalarField, AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness}, @@ -16,6 +17,7 @@ mod bytes; mod primitives; pub use bytes::*; +use itertools::Itertools; pub use primitives::*; #[cfg(test)] @@ -185,21 +187,54 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { SafeByte(input) } - /// Converts a vector of AssignedValue(treated as little-endian) to VariableAssignedBytes. + /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes. /// /// * ctx: Circuit [Context] to assign witnesses to. - /// * inputs: Vector of [RawAssignedValues] representing the byte array. - /// * var_len: [AssignedValue] witness representing the variable elements within the byte array from 0..=var_len. - /// * max_var_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. - pub fn raw_var_bytes_to( + /// * inputs: Slice representing the byte array. + /// * len: [AssignedValue] witness representing the variable elements within the byte array from 0..=len. + /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain. + pub fn raw_to_var_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; MAX_LEN], + len: AssignedValue, + ) -> VarLenBytes { + self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64); + VarLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input)), len) + } + + /// Converts a vector of AssignedValue(treated as little-endian) to VarLenBytesVec. Not encourged to use because `MAX_LEN` cannot be verified at compile time. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Vector representing the byte array. + /// * len: [AssignedValue] witness representing the variable elements within the byte array from 0..=len. + /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. + pub fn raw_to_var_len_bytes_vec( &self, ctx: &mut Context, inputs: RawAssignedValues, - var_len: AssignedValue, - ) -> VariableByteArray { - self.add_bytes_constraints(ctx, &inputs, BITS_PER_BYTE * MAX_VAR_LEN); - self.range_chip.check_less_than_safe(ctx, var_len, MAX_VAR_LEN as u64); - VariableByteArray::::new(inputs, var_len) + len: AssignedValue, + max_len: usize, + ) -> VarLenBytesVec { + self.range_chip.check_less_than_safe(ctx, len, max_len as u64); + VarLenBytesVec::::new( + inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(), + len, + max_len, + ) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytes. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * LEN: length of the byte array. + pub fn raw_to_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; LEN], + ) -> FixLenBytes { + FixLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input))) } fn add_bytes_constraints( diff --git a/halo2-base/src/safe_types/tests/mod.rs b/halo2-base/src/safe_types/tests/mod.rs index 5f37ca99..78ee3916 100644 --- a/halo2-base/src/safe_types/tests/mod.rs +++ b/halo2-base/src/safe_types/tests/mod.rs @@ -1,2 +1,2 @@ -pub (crate) mod var_byte_array; -pub (crate) mod safe_type; \ No newline at end of file +pub(crate) mod safe_type; +pub(crate) mod var_byte_array; diff --git a/halo2-base/src/safe_types/tests/var_byte_array.rs b/halo2-base/src/safe_types/tests/var_byte_array.rs index cc03827e..f716910e 100644 --- a/halo2-base/src/safe_types/tests/var_byte_array.rs +++ b/halo2-base/src/safe_types/tests/var_byte_array.rs @@ -6,9 +6,7 @@ use crate::{ halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::{ - create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey, - }, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, poly::commitment::ParamsProver, poly::kzg::{ commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, @@ -20,117 +18,154 @@ use crate::{ }, safe_types::SafeTypeChip, utils::testing::base_test, + Context, }; use rand::rngs::OsRng; -use std::{env::set_var, vec}; +use std::vec; + +// =========== Utilies =============== +fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: FM) { + let mut builder = GateThreadBuilder::mock(); + let range = RangeChip::default(8); + let safe = SafeTypeChip::new(&range); + let ctx = builder.main(0); + f(ctx, safe); + let mut params = builder.config(10, Some(9)); + params.lookup_bits = Some(3); + let circuit = RangeCircuitBuilder::mock(builder, params); + MockProver::run(10 as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} // =========== Mock Prover =========== // Circuit Satisfied for valid inputs #[test] -fn pos_var_assigned_bytes() { +fn pos_var_len_bytes() { base_test().k(10).lookup_bits(8).run(|ctx, range| { let safe = SafeTypeChip::new(&range); let fake_bytes = ctx.assign_witnesses( vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), ); - let var_len = ctx.load_witness(Fr::from(3u64)); - safe.raw_var_bytes_to::<4>(ctx, fake_bytes, var_len); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); }); } // Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 #[test] #[should_panic(expected = "circuit was not satisfied")] -fn witness_values_not_bytes() { - let mut builder = GateThreadBuilder::mock(); - let range = RangeChip::default(8); - let safe = SafeTypeChip::new(&range); - let ctx = builder.main(0); - let var_len = ctx.load_witness(Fr::from(3u64)); - let fake_bytes = ctx.assign_witnesses( - vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), - ); - safe.raw_var_bytes_to::<4>(ctx, fake_bytes, var_len); - builder.config(10, Some(9)); - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(10 as u32, &circuit, vec![]).unwrap().assert_satisfied(); +fn neg_var_len_bytes_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); } -//Checks assertion max_var_len == bytes.len() +//Checks assertion len < max_len #[test] -#[should_panic(expected = "len of bytes must equal max_var_len")] -fn bytes_len_not_equal_max_var_len() { - let mut builder = GateThreadBuilder::mock(); - let range = RangeChip::default(8); - let safe = SafeTypeChip::new(&range); - let ctx = builder.main(0); - let var_len = ctx.load_witness(Fr::from(3u64)); - let fake_bytes = ctx.assign_witnesses( - vec![500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), - ); - safe.raw_var_bytes_to::<4>(ctx, fake_bytes, var_len); +#[should_panic] +fn neg_var_len_bytes_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); } -//Checks assertion var_len < max_var_len +// Circuit Satisfied for valid inputs +#[test] +fn pos_var_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(&range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, 4); + }); +} + +// Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 #[test] #[should_panic(expected = "circuit was not satisfied")] -fn neg_var_len_less_than_max_var_len() { - let mut builder = GateThreadBuilder::mock(); - let range = RangeChip::default(8); - let safe = SafeTypeChip::new(&range); - let ctx = builder.main(0); - let var_len = ctx.load_witness(Fr::from(5u64)); - let fake_bytes = ctx.assign_witnesses( - vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), - ); - safe.raw_var_bytes_to::<4>(ctx, fake_bytes, var_len); - builder.config(10, Some(9)); - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(10 as u32, &circuit, vec![]).unwrap().assert_satisfied(); +fn neg_var_len_bytes_vec_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = fake_bytes.len(); + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +//Checks assertion len != max_len +#[test] +#[should_panic] +fn neg_var_len_bytes_vec_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = 5; + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(&range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap()); + }); } // =========== Prover =========== #[test] fn pos_prover_satisfied() { - const KEYGEN_MAX_VAR_LEN: usize = 4; - const PROVER_MAX_VAR_LEN: usize = 4; + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); - prover_satisfied::(keygen_inputs, proof_inputs) - .unwrap(); + prover_satisfied::(keygen_inputs, proof_inputs).unwrap(); } #[test] -fn pos_diff_var_len_same_max_len() { - const KEYGEN_MAX_VAR_LEN: usize = 4; - const PROVER_MAX_VAR_LEN: usize = 4; +fn pos_diff_len_same_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 2); - prover_satisfied::(keygen_inputs, proof_inputs) - .unwrap(); + prover_satisfied::(keygen_inputs, proof_inputs).unwrap(); } #[test] -#[should_panic(expected = "called `Result::unwrap()` on an `Err` value: ConstraintSystemFailure")] -fn neg_different_proof_max_var_len() { - const KEYGEN_MAX_VAR_LEN: usize = 4; - const PROVER_MAX_VAR_LEN: usize = 3; +#[should_panic] +fn neg_different_proof_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 3; let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 4); let proof_inputs = (vec![1u64, 2u64, 3u64], 3); - prover_satisfied::(keygen_inputs, proof_inputs) - .unwrap(); + prover_satisfied::(keygen_inputs, proof_inputs).unwrap(); } //test circuit -fn var_byte_array_circuit( +fn var_byte_array_circuit( k: usize, phase: bool, - (bytes, var_len): (Vec, usize), + (bytes, len): (Vec, usize), ) -> RangeCircuitBuilder { let lookup_bits = 3; - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let k = 11; let mut builder = match phase { true => GateThreadBuilder::prover(), false => GateThreadBuilder::keygen(), @@ -138,31 +173,32 @@ fn var_byte_array_circuit( let range = RangeChip::::default(lookup_bits); let safe = SafeTypeChip::new(&range); let ctx = builder.main(0); - let var_len = ctx.load_witness(Fr::from(var_len as u64)); + let len = ctx.load_witness(Fr::from(len as u64)); let fake_bytes = ctx.assign_witnesses(bytes.into_iter().map(Fr::from).collect::>()); - safe.raw_var_bytes_to::(ctx, fake_bytes, var_len); - builder.config(k, Some(9)); + safe.raw_to_var_len_bytes::(ctx, fake_bytes.try_into().unwrap(), len); + let mut params = builder.config(k, Some(9)); + params.lookup_bits = Some(lookup_bits); let circuit = match phase { - true => RangeCircuitBuilder::prover(builder, vec![vec![]]), - false => RangeCircuitBuilder::keygen(builder), + true => RangeCircuitBuilder::prover(builder, params, vec![vec![]]), + false => RangeCircuitBuilder::keygen(builder, params), }; circuit } //Prover test -fn prover_satisfied( +fn prover_satisfied( keygen_inputs: (Vec, usize), proof_inputs: (Vec, usize), ) -> Result<(), Box> { let k = 11; let rng = OsRng; let params = ParamsKZG::::setup(k as u32, rng); - let keygen_circuit = var_byte_array_circuit::(k, false, keygen_inputs); + let keygen_circuit = var_byte_array_circuit::(k, false, keygen_inputs); let vk = keygen_vk(¶ms, &keygen_circuit).unwrap(); let pk = keygen_pk(¶ms, vk, &keygen_circuit).unwrap(); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - let proof_circuit = var_byte_array_circuit::(k, true, proof_inputs); + let proof_circuit = var_byte_array_circuit::(k, true, proof_inputs); create_proof::< KZGCommitmentScheme, ProverSHPLONK<'_, Bn256>,