diff --git a/Makefile b/Makefile index 34a64837d6..204f51b278 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ BENCH_OP_FLAVOR?=DEFAULT NODE_VERSION=22.6 FORWARD_COMPAT?=OFF BACKWARD_COMPAT_DATA_URL=https://github.com/zama-ai/tfhe-backward-compat-data.git -BACKWARD_COMPAT_DATA_BRANCH?=v0.1 +BACKWARD_COMPAT_DATA_BRANCH?=v0.2 BACKWARD_COMPAT_DATA_PROJECT=tfhe-backward-compat-data BACKWARD_COMPAT_DATA_DIR=$(BACKWARD_COMPAT_DATA_PROJECT) TFHE_SPEC:=tfhe @@ -391,7 +391,7 @@ clippy_cuda_backend: install_rs_check_toolchain .PHONY: tfhe_lints # Run custom tfhe-rs lints tfhe_lints: install_tfhe_lints cd tfhe && RUSTFLAGS="$(RUSTFLAGS)" cargo tfhe-lints \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer -- -D warnings + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,zk-pok -- -D warnings .PHONY: build_core # Build core_crypto without experimental features build_core: install_rs_build_toolchain install_rs_check_toolchain @@ -810,7 +810,7 @@ test_versionable: install_rs_build_toolchain test_backward_compatibility_ci: install_rs_build_toolchain TFHE_BACKWARD_COMPAT_DATA_DIR="$(BACKWARD_COMPAT_DATA_DIR)" RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ --config "patch.'$(BACKWARD_COMPAT_DATA_URL)'.$(BACKWARD_COMPAT_DATA_PROJECT).path=\"tfhe/$(BACKWARD_COMPAT_DATA_DIR)\"" \ - --features=$(TARGET_ARCH_FEATURE),shortint,integer -p $(TFHE_SPEC) test_backward_compatibility -- --nocapture + --features=$(TARGET_ARCH_FEATURE),shortint,integer,zk-pok -p $(TFHE_SPEC) test_backward_compatibility -- --nocapture .PHONY: test_backward_compatibility # Same as test_backward_compatibility_ci but tries to clone the data repo first if needed test_backward_compatibility: tfhe/$(BACKWARD_COMPAT_DATA_DIR) test_backward_compatibility_ci diff --git a/scripts/clone_backward_compat_data.sh b/scripts/clone_backward_compat_data.sh index 35e458f1fe..2917bab92f 100755 --- a/scripts/clone_backward_compat_data.sh +++ b/scripts/clone_backward_compat_data.sh @@ -14,7 +14,7 @@ if ! git lfs env 2>/dev/null >/dev/null; then fi if [ -d $3 ]; then - cd $3 && git fetch --depth 1 && git reset --hard origin/$2 && git clean -dfx + cd $3 && git remote set-branches origin '*' && git fetch --depth 1 && git reset --hard origin/$2 && git clean -dfx else git clone $1 -b $2 --depth 1 $3 diff --git a/tfhe-zk-pok/Cargo.toml b/tfhe-zk-pok/Cargo.toml index 6933143a7e..d8eb6728e7 100644 --- a/tfhe-zk-pok/Cargo.toml +++ b/tfhe-zk-pok/Cargo.toml @@ -18,14 +18,15 @@ ark-ff = { package = "tfhe-ark-ff", version = "0.4.3", features = ["parallel"] } ark-poly = { package = "tfhe-ark-poly", version = "0.4.2", features = [ "parallel", ] } -ark-serialize = { version = "0.4.2" } rand = "0.8.5" rayon = "1.8.0" sha3 = "0.10.8" serde = { version = "~1.0", features = ["derive"] } zeroize = "1.7.0" num-bigint = "0.4.5" +tfhe-versionable = { version = "0.3.0", path = "../utils/tfhe-versionable" } [dev-dependencies] serde_json = "~1.0" itertools = "0.11.0" +bincode = "1.3.3" diff --git a/tfhe-zk-pok/src/backward_compatibility/mod.rs b/tfhe-zk-pok/src/backward_compatibility/mod.rs new file mode 100644 index 0000000000..91753ea161 --- /dev/null +++ b/tfhe-zk-pok/src/backward_compatibility/mod.rs @@ -0,0 +1,85 @@ +use tfhe_versionable::VersionsDispatch; + +use crate::curve_api::{Compressible, Curve}; +use crate::proofs::pke::{CompressedProof as PKEv1CompressedProof, Proof as PKEv1Proof}; +use crate::proofs::pke_v2::{CompressedProof as PKEv2CompressedProof, Proof as PKEv2Proof}; +use crate::proofs::GroupElements; +use crate::serialization::{ + SerializableAffine, SerializableCubicExtField, SerializableFp, SerializableFp2, + SerializableFp6, SerializableGroupElements, SerializablePKEv1PublicParams, + SerializablePKEv2PublicParams, SerializableQuadExtField, +}; + +#[derive(VersionsDispatch)] +pub enum SerializableAffineVersions { + V0(SerializableAffine), +} + +#[derive(VersionsDispatch)] +pub enum SerializableFpVersions { + V0(SerializableFp), +} + +#[derive(VersionsDispatch)] +pub enum SerializableQuadExtFieldVersions { + V0(SerializableQuadExtField), +} + +#[derive(VersionsDispatch)] +pub enum SerializableCubicExtFieldVersions { + V0(SerializableCubicExtField), +} + +pub type SerializableG1AffineVersions = SerializableAffineVersions; +pub type SerializableG2AffineVersions = SerializableAffineVersions; +pub type SerializableFp12Versions = SerializableQuadExtFieldVersions; + +#[derive(VersionsDispatch)] +pub enum PKEv1ProofVersions { + V0(PKEv1Proof), +} + +#[derive(VersionsDispatch)] +pub enum PKEv2ProofVersions { + V0(PKEv2Proof), +} + +#[derive(VersionsDispatch)] +pub enum PKEv1CompressedProofVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(PKEv1CompressedProof), +} + +#[derive(VersionsDispatch)] +pub enum PKEv2CompressedProofVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(PKEv2CompressedProof), +} + +#[derive(VersionsDispatch)] +#[allow(dead_code)] +pub(crate) enum GroupElementsVersions { + V0(GroupElements), +} + +#[derive(VersionsDispatch)] +#[allow(dead_code)] +pub(crate) enum SerializableGroupElementsVersions { + V0(SerializableGroupElements), +} + +#[derive(VersionsDispatch)] +pub enum SerializablePKEv2PublicParamsVersions { + V0(SerializablePKEv2PublicParams), +} + +#[derive(VersionsDispatch)] +pub enum SerializablePKEv1PublicParamsVersions { + V0(SerializablePKEv1PublicParams), +} diff --git a/tfhe-zk-pok/src/curve_446/mod.rs b/tfhe-zk-pok/src/curve_446/mod.rs index ef6446ed04..0cd5f9c2db 100644 --- a/tfhe-zk-pok/src/curve_446/mod.rs +++ b/tfhe-zk-pok/src/curve_446/mod.rs @@ -221,8 +221,6 @@ impl Fp12Config for Fq12Config { } pub type Bls12_446 = Bls12; -use g1::G1Affine; -use g2::G2Affine; pub struct Config; @@ -239,236 +237,17 @@ impl Bls12Config for Config { } pub mod util { - use ark_ec::short_weierstrass::Affine; - use ark_ec::AffineRepr; - use ark_ff::{BigInteger448, PrimeField}; - use ark_serialize::SerializationError; - - use super::g1::Config as G1Config; - use super::g2::Config as G2Config; - use super::{Fq, Fq2, G1Affine, G2Affine}; - pub const G1_SERIALIZED_SIZE: usize = 57; pub const G2_SERIALIZED_SIZE: usize = 114; - - pub struct EncodingFlags { - pub is_compressed: bool, - pub is_infinity: bool, - pub is_lexographically_largest: bool, - } - - impl EncodingFlags { - pub fn get_flags(bytes: &[u8]) -> Self { - let compression_flag_set = (bytes[0] >> 7) & 1; - let infinity_flag_set = (bytes[0] >> 6) & 1; - let sort_flag_set = (bytes[0] >> 5) & 1; - - Self { - is_compressed: compression_flag_set == 1, - is_infinity: infinity_flag_set == 1, - is_lexographically_largest: sort_flag_set == 1, - } - } - pub fn encode_flags(&self, bytes: &mut [u8]) { - if self.is_compressed { - bytes[0] |= 1 << 7; - } - - if self.is_infinity { - bytes[0] |= 1 << 6; - } - - if self.is_compressed && !self.is_infinity && self.is_lexographically_largest { - bytes[0] |= 1 << 5; - } - } - } - - pub(crate) fn deserialize_fq(bytes: [u8; 56]) -> Option { - let mut tmp = BigInteger448::new([0, 0, 0, 0, 0, 0, 0]); - - // Note: The following unwraps are if the compiler cannot convert - // the byte slice into [u8;8], we know this is infallible since we - // are providing the indices at compile time and bytes has a fixed size - tmp.0[6] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()); - tmp.0[5] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()); - tmp.0[4] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()); - tmp.0[3] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()); - tmp.0[2] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[32..40]).unwrap()); - tmp.0[1] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[40..48]).unwrap()); - tmp.0[0] = u64::from_be_bytes(<[u8; 8]>::try_from(&bytes[48..56]).unwrap()); - - Fq::from_bigint(tmp) - } - - pub(crate) fn serialize_fq(field: Fq) -> [u8; 56] { - let mut result = [0u8; 56]; - - let rep = field.into_bigint(); - - result[0..8].copy_from_slice(&rep.0[6].to_be_bytes()); - result[8..16].copy_from_slice(&rep.0[5].to_be_bytes()); - result[16..24].copy_from_slice(&rep.0[4].to_be_bytes()); - result[24..32].copy_from_slice(&rep.0[3].to_be_bytes()); - result[32..40].copy_from_slice(&rep.0[2].to_be_bytes()); - result[40..48].copy_from_slice(&rep.0[1].to_be_bytes()); - result[48..56].copy_from_slice(&rep.0[0].to_be_bytes()); - - result - } - - pub(crate) fn read_fq_with_offset( - bytes: &[u8], - offset: usize, - ) -> Result { - let mut tmp = [0; G1_SERIALIZED_SIZE - 1]; - // read `G1_SERIALIZED_SIZE` bytes - tmp.copy_from_slice( - &bytes[offset * G1_SERIALIZED_SIZE + 1..G1_SERIALIZED_SIZE * (offset + 1)], - ); - - deserialize_fq(tmp).ok_or(SerializationError::InvalidData) - } - - pub(crate) fn read_g1_compressed( - mut reader: R, - ) -> Result, ark_serialize::SerializationError> { - let mut bytes = [0u8; G1_SERIALIZED_SIZE]; - reader - .read_exact(&mut bytes) - .ok() - .ok_or(SerializationError::InvalidData)?; - - // Obtain the three flags from the start of the byte sequence - let flags = EncodingFlags::get_flags(&bytes[..]); - - // we expect to be deserializing a compressed point - if !flags.is_compressed { - return Err(SerializationError::UnexpectedFlags); - } - - if flags.is_infinity { - return Ok(G1Affine::zero()); - } - - // Attempt to obtain the x-coordinate - let x = read_fq_with_offset(&bytes, 0)?; - - let p = G1Affine::get_point_from_x_unchecked(x, flags.is_lexographically_largest) - .ok_or(SerializationError::InvalidData)?; - - Ok(p) - } - - pub(crate) fn read_g1_uncompressed( - mut reader: R, - ) -> Result, ark_serialize::SerializationError> { - let mut bytes = [0u8; 2 * G1_SERIALIZED_SIZE]; - reader - .read_exact(&mut bytes) - .map_err(|_| SerializationError::InvalidData)?; - - // Obtain the three flags from the start of the byte sequence - let flags = EncodingFlags::get_flags(&bytes[..]); - - // we expect to be deserializing an uncompressed point - if flags.is_compressed { - return Err(SerializationError::UnexpectedFlags); - } - - if flags.is_infinity { - return Ok(G1Affine::zero()); - } - - // Attempt to obtain the x-coordinate - let x = read_fq_with_offset(&bytes, 0)?; - // Attempt to obtain the y-coordinate - let y = read_fq_with_offset(&bytes, 1)?; - - let p = G1Affine::new_unchecked(x, y); - - Ok(p) - } - - pub(crate) fn read_g2_compressed( - mut reader: R, - ) -> Result, ark_serialize::SerializationError> { - let mut bytes = [0u8; G2_SERIALIZED_SIZE]; - reader - .read_exact(&mut bytes) - .map_err(|_| SerializationError::InvalidData)?; - - // Obtain the three flags from the start of the byte sequence - let flags = EncodingFlags::get_flags(&bytes); - - // we expect to be deserializing a compressed point - if !flags.is_compressed { - return Err(SerializationError::UnexpectedFlags); - } - - if flags.is_infinity { - return Ok(G2Affine::zero()); - } - - // Attempt to obtain the x-coordinate - let xc1 = read_fq_with_offset(&bytes, 0)?; - let xc0 = read_fq_with_offset(&bytes, 1)?; - - let x = Fq2::new(xc0, xc1); - - let p = G2Affine::get_point_from_x_unchecked(x, flags.is_lexographically_largest) - .ok_or(SerializationError::InvalidData)?; - - Ok(p) - } - - pub(crate) fn read_g2_uncompressed( - mut reader: R, - ) -> Result, ark_serialize::SerializationError> { - let mut bytes = [0u8; 2 * G2_SERIALIZED_SIZE]; - reader - .read_exact(&mut bytes) - .map_err(|_| SerializationError::InvalidData)?; - - // Obtain the three flags from the start of the byte sequence - let flags = EncodingFlags::get_flags(&bytes); - - // we expect to be deserializing an uncompressed point - if flags.is_compressed { - return Err(SerializationError::UnexpectedFlags); - } - - if flags.is_infinity { - return Ok(G2Affine::zero()); - } - - // Attempt to obtain the x-coordinate - let xc1 = read_fq_with_offset(&bytes, 0)?; - let xc0 = read_fq_with_offset(&bytes, 1)?; - let x = Fq2::new(xc0, xc1); - - // Attempt to obtain the y-coordinate - let yc1 = read_fq_with_offset(&bytes, 2)?; - let yc0 = read_fq_with_offset(&bytes, 3)?; - let y = Fq2::new(yc0, yc1); - - let p = G2Affine::new_unchecked(x, y); - - Ok(p) - } } pub mod g1 { - use super::util::{ - read_g1_compressed, read_g1_uncompressed, serialize_fq, EncodingFlags, G1_SERIALIZED_SIZE, - }; use super::{Fq, Fr}; use ark_ec::bls12::Bls12Config; use ark_ec::models::CurveConfig; use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; use ark_ec::{bls12, AdditiveGroup, AffineRepr, PrimeGroup}; use ark_ff::{MontFp, One, PrimeField, Zero}; - use ark_serialize::{Compress, SerializationError}; use core::ops::Neg; #[derive(Clone, Default, PartialEq, Eq)] @@ -533,68 +312,6 @@ pub mod g1 { let h_eff = one_minus_x().into_bigint(); Config::mul_affine(p, h_eff.as_ref()).into() } - - fn deserialize_with_mode( - mut reader: R, - compress: ark_serialize::Compress, - validate: ark_serialize::Validate, - ) -> Result, ark_serialize::SerializationError> { - let p = if compress == ark_serialize::Compress::Yes { - read_g1_compressed(&mut reader)? - } else { - read_g1_uncompressed(&mut reader)? - }; - - if validate == ark_serialize::Validate::Yes - && !p.is_in_correct_subgroup_assuming_on_curve() - { - return Err(SerializationError::InvalidData); - } - Ok(p) - } - - fn serialize_with_mode( - item: &Affine, - mut writer: W, - compress: ark_serialize::Compress, - ) -> Result<(), SerializationError> { - let encoding = EncodingFlags { - is_compressed: compress == ark_serialize::Compress::Yes, - is_infinity: item.is_zero(), - is_lexographically_largest: item.y > -item.y, - }; - let mut p = *item; - if encoding.is_infinity { - p = G1Affine::zero(); - } - // need to access the field struct `x` directly, otherwise we get None from xy() - // method - let x_bytes = serialize_fq(p.x); - if encoding.is_compressed { - let mut bytes = [0u8; G1_SERIALIZED_SIZE]; - bytes[1..].copy_from_slice(&x_bytes); - - encoding.encode_flags(&mut bytes); - writer.write_all(&bytes)?; - } else { - let mut bytes = [0u8; 2 * G1_SERIALIZED_SIZE]; - bytes[1..G1_SERIALIZED_SIZE].copy_from_slice(&x_bytes[..]); - bytes[1 + G1_SERIALIZED_SIZE..].copy_from_slice(&serialize_fq(p.y)[..]); - - encoding.encode_flags(&mut bytes); - writer.write_all(&bytes)?; - }; - - Ok(()) - } - - fn serialized_size(compress: Compress) -> usize { - if compress == Compress::Yes { - G1_SERIALIZED_SIZE - } else { - G1_SERIALIZED_SIZE * 2 - } - } } fn one_minus_x() -> Fr { @@ -624,15 +341,11 @@ pub mod g1 { } pub mod g2 { - use super::util::{ - read_g2_compressed, read_g2_uncompressed, serialize_fq, EncodingFlags, G2_SERIALIZED_SIZE, - }; use super::*; + use ark_ec::bls12; use ark_ec::models::CurveConfig; - use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; - use ark_ec::{bls12, AffineRepr}; + use ark_ec::short_weierstrass::SWCurveConfig; use ark_ff::MontFp; - use ark_serialize::{Compress, SerializationError}; pub type G2Affine = bls12::G2Affine; pub type G2Projective = bls12::G2Projective; @@ -681,76 +394,6 @@ pub mod g2 { fn mul_by_a(_: Self::BaseField) -> Self::BaseField { Self::BaseField::zero() } - - fn deserialize_with_mode( - mut reader: R, - compress: ark_serialize::Compress, - validate: ark_serialize::Validate, - ) -> Result, ark_serialize::SerializationError> { - let p = if compress == ark_serialize::Compress::Yes { - read_g2_compressed(&mut reader)? - } else { - read_g2_uncompressed(&mut reader)? - }; - - if validate == ark_serialize::Validate::Yes - && !p.is_in_correct_subgroup_assuming_on_curve() - { - return Err(SerializationError::InvalidData); - } - Ok(p) - } - - fn serialize_with_mode( - item: &Affine, - mut writer: W, - compress: ark_serialize::Compress, - ) -> Result<(), SerializationError> { - let encoding = EncodingFlags { - is_compressed: compress == ark_serialize::Compress::Yes, - is_infinity: item.is_zero(), - is_lexographically_largest: item.y > -item.y, - }; - let mut p = *item; - if encoding.is_infinity { - p = G2Affine::zero(); - } - - let mut x_bytes = [0u8; G2_SERIALIZED_SIZE]; - let c1_bytes = serialize_fq(p.x.c1); - let c0_bytes = serialize_fq(p.x.c0); - x_bytes[1..56 + 1].copy_from_slice(&c1_bytes[..]); - x_bytes[56 + 2..114].copy_from_slice(&c0_bytes[..]); - if encoding.is_compressed { - let mut bytes: [u8; G2_SERIALIZED_SIZE] = x_bytes; - - encoding.encode_flags(&mut bytes); - writer.write_all(&bytes)?; - } else { - let mut bytes = [0u8; 2 * G2_SERIALIZED_SIZE]; - - let mut y_bytes = [0u8; G2_SERIALIZED_SIZE]; - let c1_bytes = serialize_fq(p.y.c1); - let c0_bytes = serialize_fq(p.y.c0); - y_bytes[1..56 + 1].copy_from_slice(&c1_bytes[..]); - y_bytes[56 + 2..114].copy_from_slice(&c0_bytes[..]); - bytes[0..G2_SERIALIZED_SIZE].copy_from_slice(&x_bytes); - bytes[G2_SERIALIZED_SIZE..].copy_from_slice(&y_bytes); - - encoding.encode_flags(&mut bytes); - writer.write_all(&bytes)?; - }; - - Ok(()) - } - - fn serialized_size(compress: ark_serialize::Compress) -> usize { - if compress == Compress::Yes { - G2_SERIALIZED_SIZE - } else { - 2 * G2_SERIALIZED_SIZE - } - } } pub const G2_GENERATOR_X: Fq2 = Fq2::new(G2_GENERATOR_X_C0, G2_GENERATOR_X_C1); diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs index 75f3b3ccfc..8bf60f37ea 100644 --- a/tfhe-zk-pok/src/curve_api.rs +++ b/tfhe-zk-pok/src/curve_api.rs @@ -1,29 +1,14 @@ +use ark_ec::pairing::PairingOutput; +use ark_ec::short_weierstrass::Affine; use ark_ec::{AdditiveGroup as Group, CurveGroup, VariableBaseMSM}; use ark_ff::{BigInt, Field, MontFp, Zero}; use ark_poly::univariate::DensePolynomial; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; use core::fmt; use core::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; use serde::{Deserialize, Serialize}; +use tfhe_versionable::NotVersioned; -fn ark_se(a: &A, s: S) -> Result -where - S: serde::Serializer, -{ - let mut bytes = vec![]; - a.serialize_with_mode(&mut bytes, Compress::Yes) - .map_err(serde::ser::Error::custom)?; - s.serialize_bytes(&bytes) -} - -fn ark_de<'de, D, A: CanonicalDeserialize>(data: D) -> Result -where - D: serde::de::Deserializer<'de>, -{ - let s: Vec = serde::de::Deserialize::deserialize(data)?; - let a = A::deserialize_with_mode(s.as_slice(), Compress::Yes, Validate::Yes); - a.map_err(serde::de::Error::custom) -} +use crate::serialization::{SerializableAffine, SerializableFp, SerializableFp12, SerializableFp2}; struct MontIntDisplay<'a, T>(&'a T); @@ -62,7 +47,7 @@ pub trait FieldOps: fn from_u128(n: u128) -> Self; fn from_u64(n: u64) -> Self; fn from_i64(n: i64) -> Self; - fn to_bytes(self) -> impl AsRef<[u8]>; + fn to_le_bytes(self) -> impl AsRef<[u8]>; fn rand(rng: &mut dyn rand::RngCore) -> Self; fn hash(values: &mut [Self], data: &[&[u8]]); fn hash_128bit(values: &mut [Self], data: &[&[u8]]); @@ -122,19 +107,27 @@ pub trait CurveGroupOps: + Sync + core::fmt::Debug + serde::Serialize - + for<'de> serde::Deserialize<'de> - + CanonicalSerialize - + CanonicalDeserialize; + + for<'de> serde::Deserialize<'de>; fn projective(affine: Self::Affine) -> Self; fn mul_scalar(self, scalar: Zp) -> Self; fn multi_mul_scalar(bases: &[Self::Affine], scalars: &[Zp]) -> Self; - fn to_bytes(self) -> impl AsRef<[u8]>; + fn to_le_bytes(self) -> impl AsRef<[u8]>; fn double(self) -> Self; fn normalize(self) -> Self::Affine; } +/// Mark that an element can be compressed, by storing only the 'x' coordinates of the affine +/// representation and getting the 'y' from the curve. +pub trait Compressible: Sized { + type Compressed; + type UncompressError; + + fn compress(&self) -> Self::Compressed; + fn uncompress(compressed: Self::Compressed) -> Result; +} + pub trait PairingGroupOps: Copy + Send @@ -151,10 +144,10 @@ pub trait PairingGroupOps: fn pairing(x: G1, y: G2) -> Self; } -pub trait Curve { +pub trait Curve: Clone { type Zp: FieldOps; - type G1: CurveGroupOps + CanonicalSerialize + CanonicalDeserialize; - type G2: CurveGroupOps + CanonicalSerialize + CanonicalDeserialize; + type G1: CurveGroupOps; + type G2: CurveGroupOps; type Gt: PairingGroupOps; } @@ -171,8 +164,8 @@ impl FieldOps for bls12_381::Zp { fn from_i64(n: i64) -> Self { Self::from_i64(n) } - fn to_bytes(self) -> impl AsRef<[u8]> { - self.to_bytes() + fn to_le_bytes(self) -> impl AsRef<[u8]> { + self.to_le_bytes() } fn rand(rng: &mut dyn rand::RngCore) -> Self { Self::rand(rng) @@ -222,8 +215,8 @@ impl CurveGroupOps for bls12_381::G1 { Self::Affine::multi_mul_scalar(bases, scalars) } - fn to_bytes(self) -> impl AsRef<[u8]> { - self.to_bytes() + fn to_le_bytes(self) -> impl AsRef<[u8]> { + self.to_le_bytes() } fn double(self) -> Self { @@ -262,8 +255,8 @@ impl CurveGroupOps for bls12_381::G2 { Self::Affine::multi_mul_scalar(bases, scalars) } - fn to_bytes(self) -> impl AsRef<[u8]> { - self.to_bytes() + fn to_le_bytes(self) -> impl AsRef<[u8]> { + self.to_le_bytes() } fn double(self) -> Self { @@ -303,8 +296,8 @@ impl FieldOps for bls12_446::Zp { fn from_i64(n: i64) -> Self { Self::from_i64(n) } - fn to_bytes(self) -> impl AsRef<[u8]> { - self.to_bytes() + fn to_le_bytes(self) -> impl AsRef<[u8]> { + self.to_le_bytes() } fn rand(rng: &mut dyn rand::RngCore) -> Self { Self::rand(rng) @@ -359,8 +352,8 @@ impl CurveGroupOps for bls12_446::G1 { } } - fn to_bytes(self) -> impl AsRef<[u8]> { - self.to_bytes() + fn to_le_bytes(self) -> impl AsRef<[u8]> { + self.to_le_bytes() } fn double(self) -> Self { @@ -399,8 +392,8 @@ impl CurveGroupOps for bls12_446::G2 { Self::Affine::multi_mul_scalar(bases, scalars) } - fn to_bytes(self) -> impl AsRef<[u8]> { - self.to_bytes() + fn to_le_bytes(self) -> impl AsRef<[u8]> { + self.to_le_bytes() } fn double(self) -> Self { @@ -427,9 +420,11 @@ impl PairingGroupOps for bls12_446: } } -#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)] +// These are just ZSTs that are not actually produced and are only used for their +// associated types. So it's ok to derive "NotVersioned" for them. +#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize, NotVersioned)] pub struct Bls12_381; -#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize, NotVersioned)] pub struct Bls12_446; impl Curve for Bls12_381 { diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs index db579165b1..ccd8f1da18 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_381.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -25,7 +25,7 @@ fn mul_zp + Group>(x: T, scalar: Zp) -> T { y } -fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] { +fn bigint_to_le_bytes(x: [u64; 6]) -> [u8; 6 * 8] { let mut buf = [0u8; 6 * 8]; for (i, &xi) in x.iter().enumerate() { buf[i * 8..][..8].copy_from_slice(&xi.to_le_bytes()); @@ -34,26 +34,54 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] { } mod g1 { + use tfhe_versionable::Versionize; + + use crate::backward_compatibility::SerializableG1AffineVersions; + use crate::serialization::{InvalidSerializedAffineError, SerializableG1Affine}; + use super::*; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] + #[versionize( + SerializableG1AffineVersions, + try_from = "SerializableG1Affine", + into = "SerializableG1Affine" )] #[repr(transparent)] pub struct G1Affine { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: ark_bls12_381::g1::G1Affine, } + impl From for SerializableAffine { + fn from(value: G1Affine) -> Self { + SerializableAffine::uncompressed(value.inner) + } + } + + impl TryFrom> for G1Affine { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableAffine) -> Result { + Ok(Self { + inner: value.try_into()?, + }) + } + } + + impl Compressible for G1Affine { + type Compressed = SerializableG1Affine; + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableG1Affine { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G1Affine { pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 { // SAFETY: interpreting a `repr(transparent)` pointer as its contents. @@ -69,23 +97,47 @@ mod g1 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] + #[versionize( + SerializableG1AffineVersions, + try_from = "SerializableG1Affine", + into = "SerializableG1Affine" )] #[repr(transparent)] pub struct G1 { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: ark_bls12_381::G1Projective, } + impl From for SerializableAffine { + fn from(value: G1) -> Self { + SerializableAffine::uncompressed(value.inner.into_affine()) + } + } + + impl TryFrom for G1 { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableAffine) -> Result { + Ok(Self { + inner: Affine::try_from(value)?.into(), + }) + } + } + + impl Compressible for G1 { + type Compressed = SerializableG1Affine; + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableG1Affine { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G1 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("G1") @@ -114,7 +166,7 @@ mod g1 { }, }; - // Size in number of bytes when the [to_bytes] + // Size in number of bytes when the [to_le_bytes] // function is called. // This is not the size after serialization! pub const BYTE_SIZE: usize = 2 * 6 * 8 + 1; @@ -140,10 +192,10 @@ mod g1 { .sum::() } - pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + pub fn to_le_bytes(self) -> [u8; Self::BYTE_SIZE] { let g = self.inner.into_affine(); - let x = bigint_to_bytes(g.x.0 .0); - let y = bigint_to_bytes(g.y.0 .0); + let x = bigint_to_le_bytes(g.x.0 .0); + let y = bigint_to_le_bytes(g.y.0 .0); let mut buf = [0u8; 2 * 6 * 8 + 1]; buf[..6 * 8].copy_from_slice(&x); buf[6 * 8..][..6 * 8].copy_from_slice(&y); @@ -210,26 +262,55 @@ mod g1 { } mod g2 { + use tfhe_versionable::Versionize; + + use crate::backward_compatibility::SerializableG2AffineVersions; + use crate::serialization::{InvalidSerializedAffineError, SerializableG2Affine}; + use super::*; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] + #[versionize( + SerializableG2AffineVersions, + try_from = "SerializableG2Affine", + into = "SerializableG2Affine" )] #[repr(transparent)] pub struct G2Affine { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: ark_bls12_381::g2::G2Affine, } + impl From for SerializableAffine { + fn from(value: G2Affine) -> Self { + SerializableAffine::uncompressed(value.inner) + } + } + + impl TryFrom> for G2Affine { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableAffine) -> Result { + Ok(Self { + inner: value.try_into()?, + }) + } + } + + impl Compressible for G2Affine { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableAffine { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G2Affine { pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 { // SAFETY: interpreting a `repr(transparent)` pointer as its contents. @@ -245,23 +326,48 @@ mod g2 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] + #[versionize( + SerializableG2AffineVersions, + try_from = "SerializableG2Affine", + into = "SerializableG2Affine" )] #[repr(transparent)] pub struct G2 { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: ark_bls12_381::G2Projective, } + impl From for SerializableG2Affine { + fn from(value: G2) -> Self { + SerializableAffine::uncompressed(value.inner.into_affine()) + } + } + + impl TryFrom for G2 { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableAffine) -> Result { + Ok(Self { + inner: Affine::try_from(value)?.into(), + }) + } + } + + impl Compressible for G2 { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableAffine { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G2 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { #[allow(dead_code)] @@ -333,7 +439,7 @@ mod g2 { }, }; - // Size in number of bytes when the [to_bytes] + // Size in number of bytes when the [to_le_bytes] // function is called. // This is not the size after serialization! pub const BYTE_SIZE: usize = 4 * 6 * 8 + 1; @@ -359,12 +465,12 @@ mod g2 { .sum::() } - pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + pub fn to_le_bytes(self) -> [u8; Self::BYTE_SIZE] { let g = self.inner.into_affine(); - let xc0 = bigint_to_bytes(g.x.c0.0 .0); - let xc1 = bigint_to_bytes(g.x.c1.0 .0); - let yc0 = bigint_to_bytes(g.y.c0.0 .0); - let yc1 = bigint_to_bytes(g.y.c1.0 .0); + let xc0 = bigint_to_le_bytes(g.x.c0.0 .0); + let xc1 = bigint_to_le_bytes(g.x.c1.0 .0); + let yc0 = bigint_to_le_bytes(g.y.c0.0 .0); + let yc1 = bigint_to_le_bytes(g.y.c1.0 .0); let mut buf = [0u8; 4 * 6 * 8 + 1]; buf[..6 * 8].copy_from_slice(&xc0); buf[6 * 8..][..6 * 8].copy_from_slice(&xc1); @@ -433,16 +539,41 @@ mod g2 { } mod gt { + use crate::backward_compatibility::SerializableFp12Versions; + use crate::serialization::InvalidArraySizeError; + use super::*; use ark_ec::pairing::Pairing; - - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + use tfhe_versionable::Versionize; + + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash)] + #[serde(try_from = "SerializableFp12", into = "SerializableFp12")] + #[versionize( + SerializableFp12Versions, + try_from = "SerializableFp12", + into = "SerializableFp12" + )] #[repr(transparent)] pub struct Gt { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] inner: ark_ec::pairing::PairingOutput, } + impl From for SerializableFp12 { + fn from(value: Gt) -> Self { + value.inner.0.into() + } + } + + impl TryFrom for Gt { + type Error = InvalidArraySizeError; + + fn try_from(value: SerializableFp12) -> Result { + Ok(Self { + inner: PairingOutput(value.try_into()?), + }) + } + } + impl fmt::Debug for Gt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { #[allow(dead_code)] @@ -566,8 +697,12 @@ mod gt { } mod zp { + use crate::backward_compatibility::SerializableFpVersions; + use crate::serialization::InvalidArraySizeError; + use super::*; use ark_ff::Fp; + use tfhe_versionable::Versionize; use zeroize::Zeroize; fn redc(n: [u64; 4], nprime: u64, mut t: [u64; 6]) -> [u64; 4] { @@ -604,13 +739,33 @@ mod zp { t } - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Zeroize)] + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)] + #[serde(try_from = "SerializableFp", into = "SerializableFp")] + #[versionize( + SerializableFpVersions, + try_from = "SerializableFp", + into = "SerializableFp" + )] #[repr(transparent)] pub struct Zp { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: ark_bls12_381::Fr, } + impl From for SerializableFp { + fn from(value: Zp) -> Self { + value.inner.into() + } + } + impl TryFrom for Zp { + type Error = InvalidArraySizeError; + + fn try_from(value: SerializableFp) -> Result { + Ok(Self { + inner: value.try_into()?, + }) + } + } + impl fmt::Debug for Zp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("Zp") @@ -649,7 +804,7 @@ mod zp { } } - pub fn to_bytes(self) -> [u8; 4 * 8] { + pub fn to_le_bytes(self) -> [u8; 4 * 8] { let buf = [ self.inner.0 .0[0].to_le_bytes(), self.inner.0 .0[1].to_le_bytes(), @@ -851,6 +1006,26 @@ mod tests { assert_eq!(g_hat_cur, g_hat_cur2); } + #[test] + fn test_compressed_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let g_cur2 = G1::uncompress( + serde_json::from_str(&serde_json::to_string(&g_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2 = G2::uncompress( + serde_json::from_str(&serde_json::to_string(&g_hat_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + #[test] fn test_hasher_and_eq() { // we need to make sure if the points are the same diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index a19bcd31de..342ef68f09 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -25,7 +25,7 @@ fn mul_zp + Group>(x: T, scalar: Zp) -> T { y } -fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] { +fn bigint_to_le_bytes(x: [u64; 7]) -> [u8; 7 * 8] { let mut buf = [0u8; 7 * 8]; for (i, &xi) in x.iter().enumerate() { buf[i * 8..][..8].copy_from_slice(&xi.to_le_bytes()); @@ -34,26 +34,55 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] { } mod g1 { + use tfhe_versionable::Versionize; + + use crate::backward_compatibility::SerializableG1AffineVersions; + use crate::serialization::{InvalidSerializedAffineError, SerializableG1Affine}; + use super::*; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] + #[versionize( + SerializableG1AffineVersions, + try_from = "SerializableG1Affine", + into = "SerializableG1Affine" )] #[repr(transparent)] pub struct G1Affine { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: crate::curve_446::g1::G1Affine, } + impl From for SerializableAffine { + fn from(value: G1Affine) -> Self { + SerializableAffine::uncompressed(value.inner) + } + } + + impl TryFrom> for G1Affine { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableAffine) -> Result { + Ok(Self { + inner: value.try_into()?, + }) + } + } + + impl Compressible for G1Affine { + type Compressed = SerializableG1Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G1Affine { #[track_caller] pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 { @@ -70,23 +99,48 @@ mod g1 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] + #[versionize( + SerializableG1AffineVersions, + try_from = "SerializableG1Affine", + into = "SerializableG1Affine" )] #[repr(transparent)] pub struct G1 { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: crate::curve_446::g1::G1Projective, } + impl From for SerializableG1Affine { + fn from(value: G1) -> Self { + SerializableAffine::uncompressed(value.inner.into_affine()) + } + } + + impl TryFrom for G1 { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableG1Affine) -> Result { + Ok(Self { + inner: Affine::try_from(value)?.into(), + }) + } + } + + impl Compressible for G1 { + type Compressed = SerializableG1Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G1 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("G1") @@ -114,7 +168,7 @@ mod g1 { }, }; - // Size in number of bytes when the [to_bytes] + // Size in number of bytes when the [to_le_bytes] // function is called. // This is not the size after serialization! pub const BYTE_SIZE: usize = 2 * 7 * 8 + 1; @@ -141,10 +195,10 @@ mod g1 { } } - pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + pub fn to_le_bytes(self) -> [u8; Self::BYTE_SIZE] { let g = self.inner.into_affine(); - let x = bigint_to_bytes(g.x.0 .0); - let y = bigint_to_bytes(g.y.0 .0); + let x = bigint_to_le_bytes(g.x.0 .0); + let y = bigint_to_le_bytes(g.y.0 .0); let mut buf = [0u8; 2 * 7 * 8 + 1]; buf[..7 * 8].copy_from_slice(&x); buf[7 * 8..][..7 * 8].copy_from_slice(&y); @@ -211,26 +265,56 @@ mod g1 { } mod g2 { - use super::*; + use tfhe_versionable::Versionize; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + use crate::backward_compatibility::SerializableG2AffineVersions; + use crate::serialization::SerializableG2Affine; + + use super::*; + use crate::serialization::InvalidSerializedAffineError; + + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] + #[versionize( + SerializableG2AffineVersions, + try_from = "SerializableG2Affine", + into = "SerializableG2Affine" )] #[repr(transparent)] pub struct G2Affine { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: crate::curve_446::g2::G2Affine, } + impl From for SerializableG2Affine { + fn from(value: G2Affine) -> Self { + SerializableAffine::uncompressed(value.inner) + } + } + + impl TryFrom for G2Affine { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableG2Affine) -> Result { + Ok(Self { + inner: value.try_into()?, + }) + } + } + + impl Compressible for G2Affine { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G2Affine { #[track_caller] pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 { @@ -337,23 +421,48 @@ mod g2 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] + #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] + #[versionize( + SerializableG2AffineVersions, + try_from = "SerializableG2Affine", + into = "SerializableG2Affine" )] #[repr(transparent)] pub struct G2 { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: crate::curve_446::g2::G2Projective, } + impl From for SerializableG2Affine { + fn from(value: G2) -> Self { + SerializableAffine::uncompressed(value.inner.into_affine()) + } + } + + impl TryFrom for G2 { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableG2Affine) -> Result { + Ok(Self { + inner: Affine::try_from(value)?.into(), + }) + } + } + + impl Compressible for G2 { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G2 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { #[allow(dead_code)] @@ -424,7 +533,7 @@ mod g2 { }, }; - // Size in number of bytes when the [to_bytes] + // Size in number of bytes when the [to_le_bytes] // function is called. // This is not the size after serialization! pub const BYTE_SIZE: usize = 4 * 7 * 8 + 1; @@ -450,12 +559,12 @@ mod g2 { .sum::() } - pub fn to_bytes(self) -> [u8; Self::BYTE_SIZE] { + pub fn to_le_bytes(self) -> [u8; Self::BYTE_SIZE] { let g = self.inner.into_affine(); - let xc0 = bigint_to_bytes(g.x.c0.0 .0); - let xc1 = bigint_to_bytes(g.x.c1.0 .0); - let yc0 = bigint_to_bytes(g.y.c0.0 .0); - let yc1 = bigint_to_bytes(g.y.c1.0 .0); + let xc0 = bigint_to_le_bytes(g.x.c0.0 .0); + let xc1 = bigint_to_le_bytes(g.x.c1.0 .0); + let yc0 = bigint_to_le_bytes(g.y.c0.0 .0); + let yc1 = bigint_to_le_bytes(g.y.c1.0 .0); let mut buf = [0u8; 4 * 7 * 8 + 1]; buf[..7 * 8].copy_from_slice(&xc0); buf[7 * 8..][..7 * 8].copy_from_slice(&xc1); @@ -524,11 +633,14 @@ mod g2 { } mod gt { + use crate::backward_compatibility::SerializableFp12Versions; use crate::curve_446::{Fq, Fq12, Fq2}; + use crate::serialization::InvalidSerializedAffineError; use super::*; use ark_ec::pairing::{MillerLoopOutput, Pairing}; use ark_ff::{CubicExtField, QuadExtField}; + use tfhe_versionable::Versionize; type Bls = crate::curve_446::Bls12_446; @@ -698,13 +810,34 @@ mod gt { } } - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash)] + #[serde(try_from = "SerializableFp12", into = "SerializableFp12")] + #[versionize( + SerializableFp12Versions, + try_from = "SerializableFp12", + into = "SerializableFp12" + )] #[repr(transparent)] pub struct Gt { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub(crate) inner: ark_ec::pairing::PairingOutput, } + impl From for SerializableFp12 { + fn from(value: Gt) -> Self { + value.inner.0.into() + } + } + + impl TryFrom for Gt { + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableFp12) -> Result { + Ok(Self { + inner: PairingOutput(value.try_into()?), + }) + } + } + impl fmt::Debug for Gt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { #[allow(dead_code)] @@ -826,8 +959,12 @@ mod gt { } mod zp { + use crate::backward_compatibility::SerializableFpVersions; + use super::*; + use crate::serialization::InvalidArraySizeError; use ark_ff::Fp; + use tfhe_versionable::Versionize; use zeroize::Zeroize; fn redc(n: [u64; 5], nprime: u64, mut t: [u64; 7]) -> [u64; 5] { @@ -864,13 +1001,33 @@ mod zp { t } - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Zeroize)] + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)] + #[serde(try_from = "SerializableFp", into = "SerializableFp")] + #[versionize( + SerializableFpVersions, + try_from = "SerializableFp", + into = "SerializableFp" + )] #[repr(transparent)] pub struct Zp { - #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] pub inner: crate::curve_446::Fr, } + impl From for SerializableFp { + fn from(value: Zp) -> Self { + value.inner.into() + } + } + impl TryFrom for Zp { + type Error = InvalidArraySizeError; + + fn try_from(value: SerializableFp) -> Result { + Ok(Self { + inner: value.try_into()?, + }) + } + } + impl fmt::Debug for Zp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("Zp") @@ -909,7 +1066,7 @@ mod zp { } } - pub fn to_bytes(self) -> [u8; 5 * 8] { + pub fn to_le_bytes(self) -> [u8; 5 * 8] { let buf = [ self.inner.0 .0[0].to_le_bytes(), self.inner.0 .0[1].to_le_bytes(), @@ -1175,6 +1332,26 @@ mod tests { assert_eq!(g_hat_cur, g_hat_cur2); } + #[test] + fn test_compressed_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let g_cur2 = G1::uncompress( + serde_json::from_str(&serde_json::to_string(&g_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2 = G2::uncompress( + serde_json::from_str(&serde_json::to_string(&g_hat_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + #[test] fn test_hasher_and_eq() { // we need to make sure if the points are the same diff --git a/tfhe-zk-pok/src/lib.rs b/tfhe-zk-pok/src/lib.rs index 350888e6ba..5b1c45bc87 100644 --- a/tfhe-zk-pok/src/lib.rs +++ b/tfhe-zk-pok/src/lib.rs @@ -1,7 +1,7 @@ -pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; - pub mod curve_446; pub mod curve_api; pub mod proofs; +pub mod serialization; +pub mod backward_compatibility; mod four_squares; diff --git a/tfhe-zk-pok/src/proofs/binary.rs b/tfhe-zk-pok/src/proofs/binary.rs index 9ba8c3a867..5fc684b6ae 100644 --- a/tfhe-zk-pok/src/proofs/binary.rs +++ b/tfhe-zk-pok/src/proofs/binary.rs @@ -102,7 +102,7 @@ pub fn prove( let g_list = &public.0.g_lists.g_list; let mut y = OneBased(vec![G::Zp::ZERO; n]); - G::Zp::hash(&mut y.0, &[&public.0.hash, c_hat.to_bytes().as_ref()]); + G::Zp::hash(&mut y.0, &[&public.0.hash, c_hat.to_le_bytes().as_ref()]); let mut c_y = g.mul_scalar(gamma_y); for j in 1..n + 1 { @@ -110,7 +110,7 @@ pub fn prove( } let y_bytes = &*(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(); let mut t = OneBased(vec![G::Zp::ZERO; n]); G::Zp::hash( @@ -118,8 +118,8 @@ pub fn prove( &[ &public.0.hash_t, y_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); @@ -128,8 +128,8 @@ pub fn prove( &mut delta, &[ &public.0.hash_agg, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_eq, delta_y] = delta; @@ -191,10 +191,10 @@ pub fn verify( let c_y = proof.c_y; let mut y = OneBased(vec![G::Zp::ZERO; n]); - G::Zp::hash(&mut y.0, &[&public.0.hash, c_hat.to_bytes().as_ref()]); + G::Zp::hash(&mut y.0, &[&public.0.hash, c_hat.to_le_bytes().as_ref()]); let y_bytes = &*(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(); let mut t = OneBased(vec![G::Zp::ZERO; n]); G::Zp::hash( @@ -202,8 +202,8 @@ pub fn verify( &[ &public.0.hash_t, y_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); @@ -212,8 +212,8 @@ pub fn verify( &mut delta, &[ &public.0.hash_agg, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_eq, delta_y] = delta; diff --git a/tfhe-zk-pok/src/proofs/index.rs b/tfhe-zk-pok/src/proofs/index.rs index 29061b09b5..8f77575722 100644 --- a/tfhe-zk-pok/src/proofs/index.rs +++ b/tfhe-zk-pok/src/proofs/index.rs @@ -1,3 +1,5 @@ +use rand::RngCore; + use super::*; #[derive(Clone, Debug)] diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index d713c2fb32..2976ab1feb 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -1,44 +1,47 @@ -use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps}; -use ark_serialize::{ - CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +use crate::backward_compatibility::GroupElementsVersions; +use crate::curve_api::{Compressible, Curve, CurveGroupOps, FieldOps, PairingGroupOps}; +use crate::serialization::{ + InvalidSerializedGroupElementsError, SerializableG1Affine, SerializableG2Affine, + SerializableGroupElements, }; use core::ops::{Index, IndexMut}; use rand::{Rng, RngCore}; +use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned}; #[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize)] #[repr(transparent)] -struct OneBased(T); +pub(crate) struct OneBased(T); -impl Valid for OneBased { - fn check(&self) -> Result<(), SerializationError> { - self.0.check() - } -} +// TODO: these impl could be removed by adding support for `repr(transparent)` in tfhe-versionable +impl Versionize for OneBased { + type Versioned<'vers> = T::Versioned<'vers> + where + T: 'vers, + ; -impl CanonicalDeserialize for OneBased { - fn deserialize_with_mode( - reader: R, - compress: Compress, - validate: Validate, - ) -> Result { - T::deserialize_with_mode(reader, compress, validate).map(Self) + fn versionize(&self) -> Self::Versioned<'_> { + self.0.versionize() } } -impl CanonicalSerialize for OneBased { - fn serialize_with_mode( - &self, - writer: W, - compress: Compress, - ) -> Result<(), SerializationError> { - self.0.serialize_with_mode(writer, compress) +impl VersionizeOwned for OneBased { + type VersionedOwned = T::VersionedOwned; + + fn versionize_owned(self) -> Self::VersionedOwned { + self.0.versionize_owned() } +} - fn serialized_size(&self, compress: Compress) -> usize { - self.0.serialized_size(compress) +impl Unversionize for OneBased { + fn unversionize( + versioned: Self::VersionedOwned, + ) -> Result { + Ok(Self(T::unversionize(versioned)?)) } } +/// The proving scheme is available in 2 versions, one that puts more load on the prover and one +/// that puts more load on the verifier #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ComputeLoad { Proof, @@ -76,17 +79,16 @@ impl> IndexMut for OneBased { pub type Affine = >::Affine; -#[derive( - Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize, -)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] #[serde(bound( deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" ))] -struct GroupElements { - g_list: OneBased>>, - g_hat_list: OneBased>>, - message_len: usize, +#[versionize(GroupElementsVersions)] +pub(crate) struct GroupElements { + pub(crate) g_list: OneBased>>, + pub(crate) g_hat_list: OneBased>>, + pub(crate) message_len: usize, } impl GroupElements { @@ -136,6 +138,34 @@ impl GroupElements { } } +impl Compressible for GroupElements +where + GroupElements: + TryFrom, + >::Affine: Compressible, + >::Affine: Compressible, +{ + type Compressed = SerializableGroupElements; + + type UncompressError = InvalidSerializedGroupElementsError; + + fn compress(&self) -> Self::Compressed { + let mut g_list = Vec::new(); + let mut g_hat_list = Vec::new(); + for idx in 0..self.message_len { + g_list.push(self.g_list[(idx * 2) + 1].compress()); + g_list.push(self.g_list[(idx * 2) + 2].compress()); + g_hat_list.push(self.g_hat_list[idx + 1].compress()) + } + + SerializableGroupElements { g_list, g_hat_list } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + Self::try_from(compressed) + } +} + pub const HASH_METADATA_LEN_BYTES: usize = 256; pub mod binary; diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index 7cb6de3906..9463ecbab7 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -1,17 +1,37 @@ // TODO: refactor copy-pasted code in proof/verify +use crate::backward_compatibility::{ + PKEv1CompressedProofVersions, PKEv1ProofVersions, SerializablePKEv1PublicParamsVersions, +}; +use crate::serialization::{ + try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError, + SerializableGroupElements, SerializablePKEv1PublicParams, +}; + use super::*; use core::marker::PhantomData; + use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use tfhe_versionable::{UnversionizeError, VersionsDispatch}; fn bit_iter(x: u64, nbits: u32) -> impl Iterator { (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) } -#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde( + try_from = "SerializablePKEv1PublicParams", + into = "SerializablePKEv1PublicParams", + bound( + deserialize = "PublicParams: TryFrom", + serialize = "PublicParams: Into" + ) +)] pub struct PublicParams { - g_lists: GroupElements, - big_d: usize, + pub(crate) g_lists: GroupElements, + pub(crate) big_d: usize, pub n: usize, pub d: usize, pub k: usize, @@ -20,12 +40,148 @@ pub struct PublicParams { pub q: u64, pub t: u64, pub msbs_zero_padding_bit_count: u64, - hash: [u8; HASH_METADATA_LEN_BYTES], - hash_t: [u8; HASH_METADATA_LEN_BYTES], - hash_agg: [u8; HASH_METADATA_LEN_BYTES], - hash_lmap: [u8; HASH_METADATA_LEN_BYTES], - hash_z: [u8; HASH_METADATA_LEN_BYTES], - hash_w: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_t: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_agg: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_lmap: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_z: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_w: [u8; HASH_METADATA_LEN_BYTES], +} + +// Manual impl of Versionize because TryFrom + generics is currently badly handled by the proc macro +impl Versionize for PublicParams +where + Self: Clone, + SerializablePKEv1PublicParamsVersions: VersionsDispatch, + GroupElements: Into, +{ + type Versioned<'vers> = + >::Owned where G:'vers; + fn versionize(&self) -> Self::Versioned<'_> { + VersionizeOwned::versionize_owned(SerializablePKEv1PublicParams::from(self.to_owned())) + } +} + +impl VersionizeOwned for PublicParams +where + Self: Clone, + SerializablePKEv1PublicParamsVersions: VersionsDispatch, + GroupElements: Into, +{ + type VersionedOwned = >::Owned; + fn versionize_owned(self) -> Self::VersionedOwned { + VersionizeOwned::versionize_owned(SerializablePKEv1PublicParams::from(self.to_owned())) + } +} + +impl Unversionize for PublicParams +where + Self: Clone, + SerializablePKEv1PublicParamsVersions: VersionsDispatch, + GroupElements: Into, + Self: TryFrom, + E: Error + Send + Sync + 'static, +{ + fn unversionize(versioned: Self::VersionedOwned) -> Result { + SerializablePKEv1PublicParams::unversionize(versioned).and_then(|value| { + TryInto::::try_into(value) + .map_err(|e| UnversionizeError::conversion("SerializablePublicParams", e)) + }) + } +} + +impl Compressible for PublicParams +where + GroupElements: Compressible< + Compressed = SerializableGroupElements, + UncompressError = InvalidSerializedGroupElementsError, + >, +{ + type Compressed = SerializablePKEv1PublicParams; + + type UncompressError = InvalidSerializedPublicParamsError; + + fn compress(&self) -> Self::Compressed { + let PublicParams { + g_lists, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_t, + hash_agg, + hash_lmap, + hash_z, + hash_w, + } = self; + SerializablePKEv1PublicParams { + g_lists: g_lists.compress(), + big_d: *big_d, + n: *n, + d: *d, + k: *k, + b: *b, + b_r: *b_r, + q: *q, + t: *t, + msbs_zero_padding_bit_count: *msbs_zero_padding_bit_count, + hash: hash.to_vec(), + hash_t: hash_t.to_vec(), + hash_agg: hash_agg.to_vec(), + hash_lmap: hash_lmap.to_vec(), + hash_z: hash_z.to_vec(), + hash_w: hash_w.to_vec(), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let SerializablePKEv1PublicParams { + g_lists, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_t, + hash_agg, + hash_lmap, + hash_z, + hash_w, + } = compressed; + Ok(Self { + g_lists: GroupElements::uncompress(g_lists)?, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash: try_vec_to_array(hash)?, + hash_t: try_vec_to_array(hash_t)?, + hash_agg: try_vec_to_array(hash_agg)?, + hash_lmap: try_vec_to_array(hash_lmap)?, + hash_z: try_vec_to_array(hash_z)?, + hash_w: try_vec_to_array(hash_w)?, + }) + } } impl PublicParams { @@ -74,11 +230,12 @@ impl PublicParams { } } -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] #[serde(bound( deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" ))] +#[versionize(PKEv1ProofVersions)] pub struct Proof { c_hat: G::G2, c_y: G::G1, @@ -88,6 +245,78 @@ pub struct Proof { pi_kzg: Option, } +type CompressedG2 = <::G2 as Compressible>::Compressed; +type CompressedG1 = <::G1 as Compressible>::Compressed; + +#[derive(Serialize, Deserialize, Versionize)] +#[serde(bound( + deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", + serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" +))] +#[versionize(PKEv1CompressedProofVersions)] +pub struct CompressedProof +where + G::G1: Compressible, + G::G2: Compressible, +{ + c_hat: CompressedG2, + c_y: CompressedG1, + pi: CompressedG1, + c_hat_t: Option>, + c_h: Option>, + pi_kzg: Option>, +} + +impl Compressible for Proof +where + G::G1: Compressible, + G::G2: Compressible, +{ + type Compressed = CompressedProof; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + let Proof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = self; + + CompressedProof { + c_hat: c_hat.compress(), + c_y: c_y.compress(), + pi: pi.compress(), + c_hat_t: c_hat_t.map(|val| val.compress()), + c_h: c_h.map(|val| val.compress()), + pi_kzg: pi_kzg.map(|val| val.compress()), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let CompressedProof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = compressed; + + Ok(Proof { + c_hat: G::G2::uncompress(c_hat)?, + c_y: G::G1::uncompress(c_y)?, + pi: G::G1::uncompress(pi)?, + c_hat_t: c_hat_t.map(G::G2::uncompress).transpose()?, + c_h: c_h.map(G::G1::uncompress).transpose()?, + pi_kzg: pi_kzg.map(G::G1::uncompress).transpose()?, + }) + } +} + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct PublicCommit { a: Vec, @@ -375,7 +604,7 @@ pub fn prove( let mut y = vec![G::Zp::ZERO; n]; G::Zp::hash( &mut y, - &[hash, metadata, x_bytes, c_hat.to_bytes().as_ref()], + &[hash, metadata, x_bytes, c_hat.to_le_bytes().as_ref()], ); let y = OneBased(y); @@ -391,8 +620,8 @@ pub fn prove( hash_lmap, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); @@ -423,11 +652,11 @@ pub fn prove( hash_t, metadata, &(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(), x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let t = OneBased(t); @@ -439,8 +668,8 @@ pub fn prove( hash_agg, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_eq, delta_y] = delta; @@ -518,20 +747,20 @@ pub fn prove( hash_z, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), ], ); @@ -559,24 +788,24 @@ pub fn prove( hash_w, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), - z.to_bytes().as_ref(), - p_h.to_bytes().as_ref(), - p_t.to_bytes().as_ref(), + z.to_le_bytes().as_ref(), + p_h.to_le_bytes().as_ref(), + p_t.to_le_bytes().as_ref(), ], ); @@ -821,7 +1050,7 @@ pub fn verify( let mut y = vec![G::Zp::ZERO; n]; G::Zp::hash( &mut y, - &[hash, metadata, x_bytes, c_hat.to_bytes().as_ref()], + &[hash, metadata, x_bytes, c_hat.to_le_bytes().as_ref()], ); let y = OneBased(y); @@ -832,8 +1061,8 @@ pub fn verify( hash_lmap, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let theta0 = &theta[..d + k]; @@ -869,11 +1098,11 @@ pub fn verify( hash_t, metadata, &(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(), x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let t = OneBased(t); @@ -885,8 +1114,8 @@ pub fn verify( hash_agg, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_eq, delta_y] = delta; @@ -900,20 +1129,20 @@ pub fn verify( hash_z, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), ], ); @@ -953,24 +1182,24 @@ pub fn verify( hash_w, metadata, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), - z.to_bytes().as_ref(), - p_h.to_bytes().as_ref(), - p_t.to_bytes().as_ref(), + z.to_le_bytes().as_ref(), + p_h.to_le_bytes().as_ref(), + p_t.to_le_bytes().as_ref(), ], ); @@ -1032,6 +1261,7 @@ pub fn verify( #[cfg(test)] mod tests { use super::*; + use bincode::ErrorKind; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -1168,15 +1398,17 @@ mod tests { type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1184,9 +1416,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, b_i, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for ( public_param, @@ -1382,21 +1614,23 @@ mod tests { let div = val.div_euclid(q); let rem = val.rem_euclid(q); let result = div as i64 + (rem > (q / 2)) as i64; - let result = result.rem_euclid(t as i64); + let result = result.rem_euclid(effective_cleartext_t as i64); m_roundtrip[i] = result; } type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1404,9 +1638,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, b_i, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for public_param in [ original_public_param, @@ -1439,4 +1673,156 @@ mod tests { } } } + + #[test] + fn test_proof_compression() { + let d = 2048; + let k = 320; + let big_b = 1048576; + let q = 0; + let t = 1024; + let msbs_zero_padding_bit_count = 1; + let effective_cleartext_t = t >> msbs_zero_padding_bit_count; + + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let rng = &mut StdRng::seed_from_u64(0); + + let polymul_rev = |a: &[i64], b: &[i64]| -> Vec { + assert_eq!(a.len(), b.len()); + let d = a.len(); + let mut c = vec![0i64; d]; + + for i in 0..d { + for j in 0..d { + if i + j < d { + c[i + j] = c[i + j].wrapping_add(a[i].wrapping_mul(b[d - j - 1])); + } else { + c[i + j - d] = c[i + j - d].wrapping_sub(a[i].wrapping_mul(b[d - j - 1])); + } + } + } + + c + }; + + let a = (0..d).map(|_| rng.gen::()).collect::>(); + let s = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let e = (0..d) + .map(|_| (rng.gen::() % (2 * big_b)) as i64 - big_b as i64) + .collect::>(); + let e1 = (0..d) + .map(|_| (rng.gen::() % (2 * big_b)) as i64 - big_b as i64) + .collect::>(); + let e2 = (0..k) + .map(|_| (rng.gen::() % (2 * big_b)) as i64 - big_b as i64) + .collect::>(); + + let r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + + let m = (0..k) + .map(|_| (rng.gen::() % effective_cleartext_t) as i64) + .collect::>(); + + let b = polymul_rev(&a, &s) + .into_iter() + .zip(e.iter()) + .map(|(x, e)| x.wrapping_add(*e)) + .collect::>(); + let c1 = polymul_rev(&a, &r) + .into_iter() + .zip(e1.iter()) + .map(|(x, e1)| x.wrapping_add(*e1)) + .collect::>(); + + let mut c2 = vec![0i64; k]; + + for i in 0..k { + let mut dot = 0i64; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot = dot.wrapping_add(r[d - j - 1].wrapping_mul(b)); + } + + c2[i] = dot + .wrapping_add(e2[i]) + .wrapping_add((delta * m[i] as u64) as i64); + } + + // One of our usecases uses 320 bits of additional metadata + const METADATA_LEN: usize = (320 / u8::BITS) as usize; + + let mut metadata = [0u8; METADATA_LEN]; + metadata.fill_with(|| rng.gen::()); + + let mut m_roundtrip = vec![0i64; k]; + for i in 0..k { + let mut dot = 0i128; + for j in 0..d { + let c = if i + j < d { + c1[d - j - i - 1] + } else { + c1[2 * d - j - i - 1].wrapping_neg() + }; + + dot += s[d - j - 1] as i128 * c as i128; + } + + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + let val = ((c2[i] as i128).wrapping_sub(dot)) * t as i128; + let div = val.div_euclid(q); + let rem = val.rem_euclid(q); + let result = div as i64 + (rem > (q / 2)) as i64; + let result = result.rem_euclid(effective_cleartext_t as i64); + m_roundtrip[i] = result; + } + + type Curve = crate::curve_api::Bls12_446; + + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let public_param = + crs_gen::(d, crs_k, big_b, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + a.clone(), + b.clone(), + c1.clone(), + c2.clone(), + r.clone(), + e1.clone(), + m.clone(), + e2.clone(), + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = prove( + (&public_param, &public_commit), + &private_commit, + &metadata, + load, + rng, + ); + + let compressed_proof = bincode::serialize(&proof.clone().compress()).unwrap(); + let proof = + Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap(); + + verify(&proof, (&public_param, &public_commit), &metadata).unwrap() + } + } } diff --git a/tfhe-zk-pok/src/proofs/pke_v2.rs b/tfhe-zk-pok/src/proofs/pke_v2.rs index cc0b5544aa..b9312a0a9f 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2.rs @@ -2,19 +2,37 @@ #![allow(non_snake_case)] use super::*; +use crate::backward_compatibility::{ + PKEv2CompressedProofVersions, PKEv2ProofVersions, SerializablePKEv2PublicParamsVersions, +}; use crate::four_squares::*; +use crate::serialization::{ + try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError, + SerializableGroupElements, SerializablePKEv2PublicParams, +}; use core::marker::PhantomData; use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use tfhe_versionable::{UnversionizeError, VersionsDispatch}; fn bit_iter(x: u64, nbits: u32) -> impl Iterator { (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) } /// The CRS of the zk scheme -#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde( + try_from = "SerializablePKEv2PublicParams", + into = "SerializablePKEv2PublicParams", + bound( + deserialize = "PublicParams: TryFrom", + serialize = "PublicParams: Into" + ) +)] pub struct PublicParams { - g_lists: GroupElements, - D: usize, + pub(crate) g_lists: GroupElements, + pub(crate) D: usize, pub n: usize, pub d: usize, pub k: usize, @@ -25,16 +43,176 @@ pub struct PublicParams { pub q: u64, pub t: u64, pub msbs_zero_padding_bit_count: u64, - hash: [u8; HASH_METADATA_LEN_BYTES], - hash_R: [u8; HASH_METADATA_LEN_BYTES], - hash_t: [u8; HASH_METADATA_LEN_BYTES], - hash_w: [u8; HASH_METADATA_LEN_BYTES], - hash_agg: [u8; HASH_METADATA_LEN_BYTES], - hash_lmap: [u8; HASH_METADATA_LEN_BYTES], - hash_phi: [u8; HASH_METADATA_LEN_BYTES], - hash_xi: [u8; HASH_METADATA_LEN_BYTES], - hash_z: [u8; HASH_METADATA_LEN_BYTES], - hash_chi: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_R: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_t: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_w: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_agg: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_lmap: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_phi: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_xi: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_z: [u8; HASH_METADATA_LEN_BYTES], + pub(crate) hash_chi: [u8; HASH_METADATA_LEN_BYTES], +} + +// Manual impl of Versionize because TryFrom + generics is currently badly handled by the proc macro +impl Versionize for PublicParams +where + Self: Clone, + SerializablePKEv2PublicParamsVersions: VersionsDispatch, + GroupElements: Into, +{ + type Versioned<'vers> = + >::Owned where G:'vers; + fn versionize(&self) -> Self::Versioned<'_> { + VersionizeOwned::versionize_owned(SerializablePKEv2PublicParams::from(self.to_owned())) + } +} + +impl VersionizeOwned for PublicParams +where + Self: Clone, + SerializablePKEv2PublicParamsVersions: VersionsDispatch, + GroupElements: Into, +{ + type VersionedOwned = >::Owned; + fn versionize_owned(self) -> Self::VersionedOwned { + VersionizeOwned::versionize_owned(SerializablePKEv2PublicParams::from(self.to_owned())) + } +} + +impl Unversionize for PublicParams +where + Self: Clone, + SerializablePKEv2PublicParamsVersions: VersionsDispatch, + GroupElements: Into, + Self: TryFrom, + E: Error + Send + Sync + 'static, +{ + fn unversionize(versioned: Self::VersionedOwned) -> Result { + SerializablePKEv2PublicParams::unversionize(versioned).and_then(|value| { + TryInto::::try_into(value) + .map_err(|e| UnversionizeError::conversion("SerializablePublicParams", e)) + }) + } +} + +impl Compressible for PublicParams +where + GroupElements: Compressible< + Compressed = SerializableGroupElements, + UncompressError = InvalidSerializedGroupElementsError, + >, +{ + type Compressed = SerializablePKEv2PublicParams; + + type UncompressError = InvalidSerializedPublicParamsError; + + fn compress(&self) -> Self::Compressed { + let PublicParams { + g_lists, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_R, + hash_t, + hash_w, + hash_agg, + hash_lmap, + hash_phi, + hash_xi, + hash_z, + hash_chi, + } = self; + SerializablePKEv2PublicParams { + g_lists: g_lists.compress(), + D: *D, + n: *n, + d: *d, + k: *k, + B: *B, + B_r: *B_r, + B_bound: *B_bound, + m_bound: *m_bound, + q: *q, + t: *t, + msbs_zero_padding_bit_count: *msbs_zero_padding_bit_count, + hash: hash.to_vec(), + hash_R: hash_R.to_vec(), + hash_t: hash_t.to_vec(), + hash_w: hash_w.to_vec(), + hash_agg: hash_agg.to_vec(), + hash_lmap: hash_lmap.to_vec(), + hash_phi: hash_phi.to_vec(), + hash_xi: hash_xi.to_vec(), + hash_z: hash_z.to_vec(), + hash_chi: hash_chi.to_vec(), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let SerializablePKEv2PublicParams { + g_lists, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_R, + hash_t, + hash_w, + hash_agg, + hash_lmap, + hash_phi, + hash_xi, + hash_z, + hash_chi, + } = compressed; + Ok(Self { + g_lists: GroupElements::uncompress(g_lists)?, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash: try_vec_to_array(hash)?, + hash_R: try_vec_to_array(hash_R)?, + hash_t: try_vec_to_array(hash_t)?, + hash_w: try_vec_to_array(hash_w)?, + hash_agg: try_vec_to_array(hash_agg)?, + hash_lmap: try_vec_to_array(hash_lmap)?, + hash_phi: try_vec_to_array(hash_phi)?, + hash_xi: try_vec_to_array(hash_xi)?, + hash_z: try_vec_to_array(hash_z)?, + hash_chi: try_vec_to_array(hash_chi)?, + }) + } } impl PublicParams { @@ -95,11 +273,12 @@ impl PublicParams { /// This represents a proof that the given ciphertext is a valid encryptions of the input messages /// with the provided public key. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] #[serde(bound( deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" ))] +#[versionize(PKEv2ProofVersions)] pub struct Proof { C_hat_e: G::G2, C_e: G::G1, @@ -117,6 +296,114 @@ pub struct Proof { C_hat_w: Option, } +type CompressedG2 = <::G2 as Compressible>::Compressed; +type CompressedG1 = <::G1 as Compressible>::Compressed; + +#[derive(Serialize, Deserialize, Versionize)] +#[serde(bound( + deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", + serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" +))] +#[versionize(PKEv2CompressedProofVersions)] +pub struct CompressedProof +where + G::G1: Compressible, + G::G2: Compressible, +{ + C_hat_e: CompressedG2, + C_e: CompressedG1, + C_r_tilde: CompressedG1, + C_R: CompressedG1, + C_hat_bin: CompressedG2, + C_y: CompressedG1, + C_h1: CompressedG1, + C_h2: CompressedG1, + C_hat_t: CompressedG2, + pi: CompressedG1, + pi_kzg: CompressedG1, + + C_hat_h3: Option>, + C_hat_w: Option>, +} + +impl Compressible for Proof +where + G::G1: Compressible, + G::G2: Compressible, +{ + type Compressed = CompressedProof; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + let Proof { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + C_hat_h3, + C_hat_w, + } = self; + + CompressedProof { + C_hat_e: C_hat_e.compress(), + C_e: C_e.compress(), + C_r_tilde: C_r_tilde.compress(), + C_R: C_R.compress(), + C_hat_bin: C_hat_bin.compress(), + C_y: C_y.compress(), + C_h1: C_h1.compress(), + C_h2: C_h2.compress(), + C_hat_t: C_hat_t.compress(), + pi: pi.compress(), + pi_kzg: pi_kzg.compress(), + C_hat_h3: C_hat_h3.map(|val| val.compress()), + C_hat_w: C_hat_w.map(|val| val.compress()), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let CompressedProof { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + C_hat_h3, + C_hat_w, + } = compressed; + + Ok(Proof { + C_hat_e: G::G2::uncompress(C_hat_e)?, + C_e: G::G1::uncompress(C_e)?, + C_r_tilde: G::G1::uncompress(C_r_tilde)?, + C_R: G::G1::uncompress(C_R)?, + C_hat_bin: G::G2::uncompress(C_hat_bin)?, + C_y: G::G1::uncompress(C_y)?, + C_h1: G::G1::uncompress(C_h1)?, + C_h2: G::G1::uncompress(C_h2)?, + C_hat_t: G::G2::uncompress(C_hat_t)?, + pi: G::G1::uncompress(pi)?, + pi_kzg: G::G1::uncompress(pi_kzg)?, + C_hat_h3: C_hat_h3.map(G::G2::uncompress).transpose()?, + C_hat_w: C_hat_w.map(G::G2::uncompress).transpose()?, + }) + } +} + /// This is the public part of the commitment. `a` and `b` are the mask and body of the public key, /// `c1` and `c2` are the mask and body of the ciphertext. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] @@ -561,9 +848,9 @@ pub fn prove( hash_R, metadata, x_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ] { hasher.update(data); } @@ -631,15 +918,15 @@ pub fn prove( metadata, x_bytes, R_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), - C_R.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ], ); let phi_bytes = &*phi .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let m = m_bound; @@ -670,19 +957,19 @@ pub fn prove( hash_xi, metadata, x_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, phi_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ], ); let xi_bytes = &*xi .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut y = vec![G::Zp::ZERO; D + 128 * m]; @@ -695,16 +982,16 @@ pub fn prove( R_bytes, phi_bytes, xi_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ], ); let y_bytes = &*y .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); assert_eq!(y.len(), w_bin.len()); @@ -727,18 +1014,18 @@ pub fn prove( y_bytes, phi_bytes, xi_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let t_bytes = &*t .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut theta = vec![G::Zp::ZERO; d + k]; @@ -752,18 +1039,18 @@ pub fn prove( t_bytes, phi_bytes, xi_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let theta_bytes = &*theta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut a_theta = vec![G::Zp::ZERO; D]; @@ -788,18 +1075,18 @@ pub fn prove( phi_bytes, xi_bytes, theta_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let w_bytes = &*w .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut delta = [G::Zp::ZERO; 7]; @@ -815,19 +1102,19 @@ pub fn prove( xi_bytes, theta_bytes, w_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let [delta_r, delta_dec, delta_eq, delta_y, delta_theta, delta_e, delta_l] = delta; let delta_bytes = &*delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut poly_0_lhs = vec![G::Zp::ZERO; 1 + n]; @@ -1170,8 +1457,8 @@ pub fn prove( ComputeLoad::Verify => (None, None), }; - let C_hat_h3_bytes = C_hat_h3.map(G::G2::to_bytes); - let C_hat_w_bytes = C_hat_w.map(G::G2::to_bytes); + let C_hat_h3_bytes = C_hat_h3.map(G::G2::to_le_bytes); + let C_hat_w_bytes = C_hat_w.map(G::G2::to_le_bytes); let C_hat_h3_bytes = C_hat_h3_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); let C_hat_w_bytes = C_hat_w_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); @@ -1190,16 +1477,16 @@ pub fn prove( x_bytes, theta_bytes, delta_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), - C_h1.to_bytes().as_ref(), - C_h2.to_bytes().as_ref(), - C_hat_t.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), + C_h1.to_le_bytes().as_ref(), + C_h2.to_le_bytes().as_ref(), + C_hat_t.to_le_bytes().as_ref(), C_hat_h3_bytes, C_hat_w_bytes, ], @@ -1323,22 +1610,22 @@ pub fn prove( xi_bytes, theta_bytes, delta_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), - C_h1.to_bytes().as_ref(), - C_h2.to_bytes().as_ref(), - C_hat_t.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), + C_h1.to_le_bytes().as_ref(), + C_h2.to_le_bytes().as_ref(), + C_hat_t.to_le_bytes().as_ref(), C_hat_h3_bytes, C_hat_w_bytes, - z.to_bytes().as_ref(), - p_h1.to_bytes().as_ref(), - p_h2.to_bytes().as_ref(), - p_t.to_bytes().as_ref(), + z.to_le_bytes().as_ref(), + p_h1.to_le_bytes().as_ref(), + p_h2.to_le_bytes().as_ref(), + p_t.to_le_bytes().as_ref(), ], ); @@ -1558,8 +1845,8 @@ pub fn verify( return Err(()); } - let C_hat_h3_bytes = C_hat_h3.map(G::G2::to_bytes); - let C_hat_w_bytes = C_hat_w.map(G::G2::to_bytes); + let C_hat_h3_bytes = C_hat_h3.map(G::G2::to_le_bytes); + let C_hat_w_bytes = C_hat_w.map(G::G2::to_le_bytes); let C_hat_h3_bytes = C_hat_h3_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); let C_hat_w_bytes = C_hat_w_bytes.as_ref().map(|x| x.as_ref()).unwrap_or(&[]); @@ -1596,9 +1883,9 @@ pub fn verify( hash_R, metadata, x_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ] { hasher.update(data); } @@ -1634,15 +1921,15 @@ pub fn verify( metadata, x_bytes, R_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), - C_R.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ], ); let phi_bytes = &*phi .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut xi = vec![G::Zp::ZERO; 128]; @@ -1652,18 +1939,18 @@ pub fn verify( hash_xi, metadata, x_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, phi_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ], ); let xi_bytes = &*xi .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut y = vec![G::Zp::ZERO; D + 128 * m]; @@ -1676,16 +1963,16 @@ pub fn verify( R_bytes, phi_bytes, xi_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), ], ); let y_bytes = &*y .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut t = vec![G::Zp::ZERO; n]; @@ -1698,18 +1985,18 @@ pub fn verify( y_bytes, phi_bytes, xi_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let t_bytes = &*t .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut theta = vec![G::Zp::ZERO; d + k]; @@ -1723,18 +2010,18 @@ pub fn verify( t_bytes, phi_bytes, xi_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let theta_bytes = &*theta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut w = vec![G::Zp::ZERO; n]; @@ -1749,18 +2036,18 @@ pub fn verify( phi_bytes, xi_bytes, theta_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let w_bytes = &*w .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let mut a_theta = vec![G::Zp::ZERO; D]; @@ -1786,19 +2073,19 @@ pub fn verify( xi_bytes, theta_bytes, w_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), ], ); let [delta_r, delta_dec, delta_eq, delta_y, delta_theta, delta_e, delta_l] = delta; let delta_bytes = &*delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(); let g = G::G1::GENERATOR; @@ -1880,16 +2167,16 @@ pub fn verify( x_bytes, theta_bytes, delta_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), - C_h1.to_bytes().as_ref(), - C_h2.to_bytes().as_ref(), - C_hat_t.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), + C_h1.to_le_bytes().as_ref(), + C_h2.to_le_bytes().as_ref(), + C_hat_t.to_le_bytes().as_ref(), C_hat_h3_bytes, C_hat_w_bytes, ], @@ -2019,22 +2306,22 @@ pub fn verify( xi_bytes, theta_bytes, delta_bytes, - C_hat_e.to_bytes().as_ref(), - C_e.to_bytes().as_ref(), + C_hat_e.to_le_bytes().as_ref(), + C_e.to_le_bytes().as_ref(), R_bytes, - C_R.to_bytes().as_ref(), - C_hat_bin.to_bytes().as_ref(), - C_r_tilde.to_bytes().as_ref(), - C_y.to_bytes().as_ref(), - C_h1.to_bytes().as_ref(), - C_h2.to_bytes().as_ref(), - C_hat_t.to_bytes().as_ref(), + C_R.to_le_bytes().as_ref(), + C_hat_bin.to_le_bytes().as_ref(), + C_r_tilde.to_le_bytes().as_ref(), + C_y.to_le_bytes().as_ref(), + C_h1.to_le_bytes().as_ref(), + C_h2.to_le_bytes().as_ref(), + C_hat_t.to_le_bytes().as_ref(), C_hat_h3_bytes, C_hat_w_bytes, - z.to_bytes().as_ref(), - p_h1.to_bytes().as_ref(), - p_h2.to_bytes().as_ref(), - p_t.to_bytes().as_ref(), + z.to_le_bytes().as_ref(), + p_h1.to_le_bytes().as_ref(), + p_h2.to_le_bytes().as_ref(), + p_t.to_le_bytes().as_ref(), ], ); let chi2 = chi * chi; @@ -2071,6 +2358,7 @@ pub fn verify( #[cfg(test)] mod tests { use super::*; + use bincode::ErrorKind; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -2207,15 +2495,17 @@ mod tests { type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2223,9 +2513,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for ( public_param, @@ -2309,12 +2599,14 @@ mod tests { let B = 1048576; let q = 0; let t = 1024; + let msbs_zero_padding_bit_count = 1; let effective_cleartext_t = t >> msbs_zero_padding_bit_count; let delta = { let q = if q == 0 { 1i128 << 64 } else { q as i128 }; // delta takes the encoding with the padding bit + (q / t as i128) as u64 }; @@ -2410,15 +2702,17 @@ mod tests { type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2426,9 +2720,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for public_param in [ original_public_param, @@ -2461,4 +2755,156 @@ mod tests { } } } + + #[test] + fn test_proof_compression() { + let d = 2048; + let k = 320; + let B = 1048576; + let q = 0; + let t = 1024; + + let msbs_zero_padding_bit_count = 1; + let effective_cleartext_t = t >> msbs_zero_padding_bit_count; + + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let rng = &mut StdRng::seed_from_u64(0); + + let polymul_rev = |a: &[i64], b: &[i64]| -> Vec { + assert_eq!(a.len(), b.len()); + let d = a.len(); + let mut c = vec![0i64; d]; + + for i in 0..d { + for j in 0..d { + if i + j < d { + c[i + j] = c[i + j].wrapping_add(a[i].wrapping_mul(b[d - j - 1])); + } else { + c[i + j - d] = c[i + j - d].wrapping_sub(a[i].wrapping_mul(b[d - j - 1])); + } + } + } + + c + }; + + let a = (0..d).map(|_| rng.gen::()).collect::>(); + let s = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let e = (0..d) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + let e1 = (0..d) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + let e2 = (0..k) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + + let r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + + let m = (0..k) + .map(|_| (rng.gen::() % effective_cleartext_t) as i64) + .collect::>(); + + let b = polymul_rev(&a, &s) + .into_iter() + .zip(e.iter()) + .map(|(x, e)| x.wrapping_add(*e)) + .collect::>(); + let c1 = polymul_rev(&a, &r) + .into_iter() + .zip(e1.iter()) + .map(|(x, e1)| x.wrapping_add(*e1)) + .collect::>(); + + let mut c2 = vec![0i64; k]; + + for i in 0..k { + let mut dot = 0i64; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot = dot.wrapping_add(r[d - j - 1].wrapping_mul(b)); + } + + c2[i] = dot + .wrapping_add(e2[i]) + .wrapping_add((delta * m[i] as u64) as i64); + } + + // One of our usecases uses 320 bits of additional metadata + const METADATA_LEN: usize = (320 / u8::BITS) as usize; + + let mut metadata = [0u8; METADATA_LEN]; + metadata.fill_with(|| rng.gen::()); + + let mut m_roundtrip = vec![0i64; k]; + for i in 0..k { + let mut dot = 0i128; + for j in 0..d { + let c = if i + j < d { + c1[d - j - i - 1] + } else { + c1[2 * d - j - i - 1].wrapping_neg() + }; + + dot += s[d - j - 1] as i128 * c as i128; + } + + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + let val = ((c2[i] as i128).wrapping_sub(dot)) * t as i128; + let div = val.div_euclid(q); + let rem = val.rem_euclid(q); + let result = div as i64 + (rem > (q / 2)) as i64; + let result = result.rem_euclid(effective_cleartext_t as i64); + m_roundtrip[i] = result; + } + + type Curve = crate::curve_api::Bls12_446; + + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + a.clone(), + b.clone(), + c1.clone(), + c2.clone(), + r.clone(), + e1.clone(), + m.clone(), + e2.clone(), + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = prove( + (&public_param, &public_commit), + &private_commit, + &metadata, + load, + rng, + ); + + let compressed_proof = bincode::serialize(&proof.clone().compress()).unwrap(); + let proof = + Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap(); + + verify(&proof, (&public_param, &public_commit), &metadata).unwrap() + } + } } diff --git a/tfhe-zk-pok/src/proofs/range.rs b/tfhe-zk-pok/src/proofs/range.rs index 565cf41a42..d6fd4d43af 100644 --- a/tfhe-zk-pok/src/proofs/range.rs +++ b/tfhe-zk-pok/src/proofs/range.rs @@ -145,7 +145,11 @@ pub fn prove( let mut y = vec![G::Zp::ZERO; n]; G::Zp::hash( &mut y, - &[hash, v_hat.to_bytes().as_ref(), c_hat.to_bytes().as_ref()], + &[ + hash, + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + ], ); let y = OneBased(y); let mut c_y = g.mul_scalar(gamma_y); @@ -154,7 +158,7 @@ pub fn prove( } let y_bytes = &*(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(); let mut t = vec![G::Zp::ZERO; n]; @@ -163,9 +167,9 @@ pub fn prove( &[ hash_t, y_bytes, - v_hat.to_bytes().as_ref(), - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let t = OneBased(t); @@ -222,9 +226,9 @@ pub fn prove( &[ hash_s, &i.to_le_bytes(), - v_hat.to_bytes().as_ref(), - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); } @@ -244,9 +248,9 @@ pub fn prove( &mut delta, &[ hash_agg, - v_hat.to_bytes().as_ref(), - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_x, delta_eq, delta_y, delta_v] = delta; @@ -289,12 +293,16 @@ pub fn verify( let mut y = vec![G::Zp::ZERO; n]; G::Zp::hash( &mut y, - &[hash, v_hat.to_bytes().as_ref(), c_hat.to_bytes().as_ref()], + &[ + hash, + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + ], ); let y = OneBased(y); let y_bytes = &*(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(); let mut t = vec![G::Zp::ZERO; n]; @@ -303,9 +311,9 @@ pub fn verify( &[ hash_t, y_bytes, - v_hat.to_bytes().as_ref(), - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let t = OneBased(t); @@ -315,9 +323,9 @@ pub fn verify( &mut delta, &[ hash_agg, - v_hat.to_bytes().as_ref(), - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_x, delta_eq, delta_y, delta_v] = delta; @@ -329,9 +337,9 @@ pub fn verify( &[ hash_s, &i.to_le_bytes(), - v_hat.to_bytes().as_ref(), - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + v_hat.to_le_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); } diff --git a/tfhe-zk-pok/src/proofs/rlwe.rs b/tfhe-zk-pok/src/proofs/rlwe.rs index d81274c9f2..8ca9b90462 100644 --- a/tfhe-zk-pok/src/proofs/rlwe.rs +++ b/tfhe-zk-pok/src/proofs/rlwe.rs @@ -329,7 +329,7 @@ pub fn prove( .collect::>(); let mut y = vec![G::Zp::ZERO; n]; - G::Zp::hash(&mut y, &[hash, x_bytes, c_hat.to_bytes().as_ref()]); + G::Zp::hash(&mut y, &[hash, x_bytes, c_hat.to_le_bytes().as_ref()]); let y = OneBased(y); let scalars = (n + 1 - big_d..n + 1) @@ -343,11 +343,11 @@ pub fn prove( &[ hash_t, &(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(), x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let t = OneBased(t); @@ -358,8 +358,8 @@ pub fn prove( &[ hash_lmap, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let theta = (0..big_n * d + 1).map(|k| theta_bar[k]).collect::>(); @@ -442,8 +442,8 @@ pub fn prove( &[ hash_agg, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_eq, delta_y] = delta; @@ -510,20 +510,20 @@ pub fn prove( &[ hash_z, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), ], ); @@ -550,24 +550,24 @@ pub fn prove( &[ hash_w, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), - z.to_bytes().as_ref(), - p_h.to_bytes().as_ref(), - p_t.to_bytes().as_ref(), + z.to_le_bytes().as_ref(), + p_h.to_le_bytes().as_ref(), + p_t.to_le_bytes().as_ref(), ], ); @@ -676,14 +676,14 @@ pub fn verify( &[ hash_agg, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let [delta_eq, delta_y] = delta; let mut y = vec![G::Zp::ZERO; n]; - G::Zp::hash(&mut y, &[hash, x_bytes, c_hat.to_bytes().as_ref()]); + G::Zp::hash(&mut y, &[hash, x_bytes, c_hat.to_le_bytes().as_ref()]); let y = OneBased(y); let mut t = vec![G::Zp::ZERO; n]; @@ -692,11 +692,11 @@ pub fn verify( &[ hash_t, &(1..n + 1) - .flat_map(|i| y[i].to_bytes().as_ref().to_vec()) + .flat_map(|i| y[i].to_le_bytes().as_ref().to_vec()) .collect::>(), x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let t = OneBased(t); @@ -707,8 +707,8 @@ pub fn verify( &[ hash_lmap, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), ], ); let theta = (0..big_n * d + 1).map(|k| theta_bar[k]).collect::>(); @@ -792,20 +792,20 @@ pub fn verify( &[ hash_z, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), ], ); @@ -844,24 +844,24 @@ pub fn verify( &[ hash_w, x_bytes, - c_hat.to_bytes().as_ref(), - c_y.to_bytes().as_ref(), - pi.to_bytes().as_ref(), - c_h.to_bytes().as_ref(), - c_hat_t.to_bytes().as_ref(), + c_hat.to_le_bytes().as_ref(), + c_y.to_le_bytes().as_ref(), + pi.to_le_bytes().as_ref(), + c_h.to_le_bytes().as_ref(), + c_hat_t.to_le_bytes().as_ref(), &y.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &t.0.iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), &delta .iter() - .flat_map(|x| x.to_bytes().as_ref().to_vec()) + .flat_map(|x| x.to_le_bytes().as_ref().to_vec()) .collect::>(), - z.to_bytes().as_ref(), - p_h.to_bytes().as_ref(), - p_t.to_bytes().as_ref(), + z.to_le_bytes().as_ref(), + p_h.to_le_bytes().as_ref(), + p_t.to_le_bytes().as_ref(), ], ); diff --git a/tfhe-zk-pok/src/serialization.rs b/tfhe-zk-pok/src/serialization.rs new file mode 100644 index 0000000000..9699bd0b2f --- /dev/null +++ b/tfhe-zk-pok/src/serialization.rs @@ -0,0 +1,646 @@ +#![allow(non_snake_case)] + +use std::error::Error; +use std::fmt::Display; +use std::marker::PhantomData; + +use crate::backward_compatibility::{ + SerializableAffineVersions, SerializableCubicExtFieldVersions, SerializableFpVersions, + SerializableGroupElementsVersions, SerializablePKEv1PublicParamsVersions, + SerializablePKEv2PublicParamsVersions, SerializableQuadExtFieldVersions, +}; +use ark_ec::short_weierstrass::{Affine, SWCurveConfig}; +use ark_ec::AffineRepr; +use ark_ff::{BigInt, Field, Fp, Fp2, Fp6, Fp6Config, FpConfig, QuadExtConfig, QuadExtField}; +use serde::{Deserialize, Serialize}; +use tfhe_versionable::Versionize; + +use crate::curve_api::{Curve, CurveGroupOps}; +use crate::proofs::pke::PublicParams as PKEv1PublicParams; +use crate::proofs::pke_v2::PublicParams as PKEv2PublicParams; +use crate::proofs::GroupElements; + +/// Error returned when a conversion from a vec to a fixed size array failed because the vec size is +/// incorrect +#[derive(Debug)] +pub struct InvalidArraySizeError { + expected_len: usize, + found_len: usize, +} + +impl Display for InvalidArraySizeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Invalid serialized array: found array of size {}, expected {}", + self.found_len, self.expected_len + ) + } +} + +impl Error for InvalidArraySizeError {} + +/// Tries to convert a Vec into a constant size array, and returns an [`InvalidArraySizeError`] if +/// the size does not match +pub(crate) fn try_vec_to_array( + vec: Vec, +) -> Result<[T; N], InvalidArraySizeError> { + let len = vec.len(); + + vec.try_into().map_err(|_| InvalidArraySizeError { + expected_len: len, + found_len: N, + }) +} + +/// Serialization equivalent of the [`Fp`] struct, where the bigint is split into +/// multiple u64. +#[derive(Serialize, Deserialize, Versionize)] +#[versionize(SerializableFpVersions)] +pub struct SerializableFp { + val: Vec, // Use a Vec since serde does not support fixed size arrays with a generic +} + +impl, const N: usize> From> for SerializableFp { + fn from(value: Fp) -> Self { + Self { + val: value.0 .0.to_vec(), + } + } +} + +impl, const N: usize> TryFrom for Fp { + type Error = InvalidArraySizeError; + + fn try_from(value: SerializableFp) -> Result { + Ok(Fp(BigInt(try_vec_to_array(value.val)?), PhantomData)) + } +} + +#[derive(Debug)] +pub enum InvalidSerializedAffineError { + InvalidFp(InvalidArraySizeError), + InvalidCompressedXCoordinate, +} + +impl Display for InvalidSerializedAffineError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InvalidSerializedAffineError::InvalidFp(fp_error) => { + write!(f, "Invalid fp element in affine: {}", fp_error) + } + InvalidSerializedAffineError::InvalidCompressedXCoordinate => { + write!( + f, + "Cannot uncompress affine: X coordinate does not belong to the curve" + ) + } + } + } +} + +impl Error for InvalidSerializedAffineError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + InvalidSerializedAffineError::InvalidFp(fp_error) => Some(fp_error), + InvalidSerializedAffineError::InvalidCompressedXCoordinate => None, + } + } +} + +impl From for InvalidSerializedAffineError { + fn from(value: InvalidArraySizeError) -> Self { + Self::InvalidFp(value) + } +} + +/// Serialization equivalent to the [`Affine`], which support an optional compression mode +/// where only the `x` coordinate is stored, and the `y` is computed on load. +#[derive(Serialize, Deserialize, Versionize)] +#[versionize(SerializableAffineVersions)] +pub enum SerializableAffine { + Infinity, + Compressed { x: F, take_largest_y: bool }, + Uncompressed { x: F, y: F }, +} + +impl SerializableAffine { + #[allow(unused)] + pub fn uncompressed + Field, C: SWCurveConfig>( + value: Affine, + ) -> Self { + if value.is_zero() { + Self::Infinity + } else { + Self::Uncompressed { + x: value.x.into(), + y: value.y.into(), + } + } + } + + pub fn compressed + Field, C: SWCurveConfig>( + value: Affine, + ) -> Self { + if value.is_zero() { + Self::Infinity + } else { + let take_largest_y = value.y > -value.y; + Self::Compressed { + x: value.x.into(), + take_largest_y, + } + } + } +} + +impl TryFrom> for Affine +where + F: TryInto, +{ + type Error = InvalidSerializedAffineError; + + fn try_from(value: SerializableAffine) -> Result { + match value { + SerializableAffine::Infinity => Ok(Self::zero()), + SerializableAffine::Compressed { x, take_largest_y } => { + Self::get_point_from_x_unchecked(x.try_into()?, take_largest_y) + .ok_or(InvalidSerializedAffineError::InvalidCompressedXCoordinate) + } + SerializableAffine::Uncompressed { x, y } => { + Ok(Self::new_unchecked(x.try_into()?, y.try_into()?)) + } + } + } +} + +pub(crate) type SerializableG1Affine = SerializableAffine; + +#[derive(Serialize, Deserialize, Versionize)] +#[versionize(SerializableQuadExtFieldVersions)] +pub struct SerializableQuadExtField { + c0: F, + c1: F, +} + +pub(crate) type SerializableFp2 = SerializableQuadExtField; +pub type SerializableG2Affine = SerializableAffine; + +impl From> for SerializableQuadExtField +where + F: From, +{ + fn from(value: QuadExtField

) -> Self { + Self { + c0: value.c0.into(), + c1: value.c1.into(), + } + } +} + +impl TryFrom> for QuadExtField

+where + F: TryInto, +{ + type Error = InvalidArraySizeError; + + fn try_from(value: SerializableQuadExtField) -> Result { + Ok(QuadExtField { + c0: value.c0.try_into()?, + c1: value.c1.try_into()?, + }) + } +} + +#[derive(Serialize, Deserialize, Versionize)] +#[versionize(SerializableCubicExtFieldVersions)] +pub struct SerializableCubicExtField { + c0: F, + c1: F, + c2: F, +} + +pub(crate) type SerializableFp6 = SerializableCubicExtField; + +impl From> for SerializableCubicExtField +where + F: From>, +{ + fn from(value: Fp6) -> Self { + Self { + c0: value.c0.into(), + c1: value.c1.into(), + c2: value.c2.into(), + } + } +} + +impl TryFrom> for Fp6 +where + F: TryInto, Error = InvalidArraySizeError>, +{ + type Error = InvalidArraySizeError; + + fn try_from(value: SerializableCubicExtField) -> Result { + Ok(Fp6 { + c0: value.c0.try_into()?, + c1: value.c1.try_into()?, + c2: value.c2.try_into()?, + }) + } +} + +pub(crate) type SerializableFp12 = SerializableQuadExtField; + +#[derive(Debug)] +pub enum InvalidSerializedGroupElementsError { + InvalidAffine(InvalidSerializedAffineError), + InvalidGlistDimension(InvalidArraySizeError), +} + +impl Display for InvalidSerializedGroupElementsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InvalidSerializedGroupElementsError::InvalidAffine(affine_error) => { + write!(f, "Invalid Affine in GroupElement: {}", affine_error) + } + InvalidSerializedGroupElementsError::InvalidGlistDimension(arr_error) => { + write!(f, "invalid number of elements in g_list: {}", arr_error) + } + } + } +} + +impl Error for InvalidSerializedGroupElementsError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + InvalidSerializedGroupElementsError::InvalidAffine(affine_error) => Some(affine_error), + InvalidSerializedGroupElementsError::InvalidGlistDimension(arr_error) => { + Some(arr_error) + } + } + } +} + +impl From for InvalidSerializedGroupElementsError { + fn from(value: InvalidSerializedAffineError) -> Self { + Self::InvalidAffine(value) + } +} + +#[derive(Serialize, Deserialize, Versionize)] +#[versionize(SerializableGroupElementsVersions)] +pub(crate) struct SerializableGroupElements { + pub(crate) g_list: Vec, + pub(crate) g_hat_list: Vec, +} + +impl From> for SerializableGroupElements +where + >::Affine: Into, + >::Affine: Into, +{ + fn from(value: GroupElements) -> Self { + let mut g_list = Vec::new(); + let mut g_hat_list = Vec::new(); + for idx in 0..value.message_len { + g_list.push(value.g_list[(idx * 2) + 1].into()); + g_list.push(value.g_list[(idx * 2) + 2].into()); + g_hat_list.push(value.g_hat_list[idx + 1].into()) + } + + Self { g_list, g_hat_list } + } +} + +impl TryFrom for GroupElements +where + >::Affine: + TryFrom, + >::Affine: + TryFrom, +{ + type Error = InvalidSerializedGroupElementsError; + + fn try_from(value: SerializableGroupElements) -> Result { + if value.g_list.len() != value.g_hat_list.len() * 2 { + return Err(InvalidSerializedGroupElementsError::InvalidGlistDimension( + InvalidArraySizeError { + expected_len: value.g_hat_list.len() * 2, + found_len: value.g_list.len(), + }, + )); + } + + let g_list = value + .g_list + .into_iter() + .map(>::Affine::try_from) + .collect::>()?; + let g_hat_list = value + .g_hat_list + .into_iter() + .map(>::Affine::try_from) + .collect::>()?; + + Ok(Self::from_vec(g_list, g_hat_list)) + } +} + +#[derive(Debug)] +pub enum InvalidSerializedPublicParamsError { + InvalidGroupElements(InvalidSerializedGroupElementsError), + InvalidHashDimension(InvalidArraySizeError), +} + +impl Display for InvalidSerializedPublicParamsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + InvalidSerializedPublicParamsError::InvalidGroupElements(group_error) => { + write!(f, "Invalid PublicParams: {}", group_error) + } + InvalidSerializedPublicParamsError::InvalidHashDimension(arr_error) => { + write!(f, "invalid size of hash: {}", arr_error) + } + } + } +} + +impl Error for InvalidSerializedPublicParamsError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + InvalidSerializedPublicParamsError::InvalidGroupElements(group_error) => { + Some(group_error) + } + InvalidSerializedPublicParamsError::InvalidHashDimension(arr_error) => Some(arr_error), + } + } +} + +impl From for InvalidSerializedPublicParamsError { + fn from(value: InvalidSerializedGroupElementsError) -> Self { + Self::InvalidGroupElements(value) + } +} + +impl From for InvalidSerializedPublicParamsError { + fn from(value: InvalidArraySizeError) -> Self { + Self::InvalidHashDimension(value) + } +} + +#[derive(serde::Serialize, serde::Deserialize, Versionize)] +#[versionize(SerializablePKEv2PublicParamsVersions)] +pub struct SerializablePKEv2PublicParams { + pub(crate) g_lists: SerializableGroupElements, + pub(crate) D: usize, + pub n: usize, + pub d: usize, + pub k: usize, + pub B: u64, + pub B_r: u64, + pub B_bound: u64, + pub m_bound: usize, + pub q: u64, + pub t: u64, + pub msbs_zero_padding_bit_count: u64, + // We use Vec since serde does not support fixed size arrays of 256 elements + pub(crate) hash: Vec, + pub(crate) hash_R: Vec, + pub(crate) hash_t: Vec, + pub(crate) hash_w: Vec, + pub(crate) hash_agg: Vec, + pub(crate) hash_lmap: Vec, + pub(crate) hash_phi: Vec, + pub(crate) hash_xi: Vec, + pub(crate) hash_z: Vec, + pub(crate) hash_chi: Vec, +} + +impl From> for SerializablePKEv2PublicParams +where + GroupElements: Into, +{ + fn from(value: PKEv2PublicParams) -> Self { + let PKEv2PublicParams { + g_lists, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_R, + hash_t, + hash_w, + hash_agg, + hash_lmap, + hash_phi, + hash_xi, + hash_z, + hash_chi, + } = value; + Self { + g_lists: g_lists.into(), + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash: hash.to_vec(), + hash_R: hash_R.to_vec(), + hash_t: hash_t.to_vec(), + hash_w: hash_w.to_vec(), + hash_agg: hash_agg.to_vec(), + hash_lmap: hash_lmap.to_vec(), + hash_phi: hash_phi.to_vec(), + hash_xi: hash_xi.to_vec(), + hash_z: hash_z.to_vec(), + hash_chi: hash_chi.to_vec(), + } + } +} + +impl TryFrom for PKEv2PublicParams +where + GroupElements: + TryFrom, +{ + type Error = InvalidSerializedPublicParamsError; + + fn try_from(value: SerializablePKEv2PublicParams) -> Result { + let SerializablePKEv2PublicParams { + g_lists, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_R, + hash_t, + hash_w, + hash_agg, + hash_lmap, + hash_phi, + hash_xi, + hash_z, + hash_chi, + } = value; + Ok(Self { + g_lists: g_lists.try_into()?, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash: try_vec_to_array(hash)?, + hash_R: try_vec_to_array(hash_R)?, + hash_t: try_vec_to_array(hash_t)?, + hash_w: try_vec_to_array(hash_w)?, + hash_agg: try_vec_to_array(hash_agg)?, + hash_lmap: try_vec_to_array(hash_lmap)?, + hash_phi: try_vec_to_array(hash_phi)?, + hash_xi: try_vec_to_array(hash_xi)?, + hash_z: try_vec_to_array(hash_z)?, + hash_chi: try_vec_to_array(hash_chi)?, + }) + } +} + +#[derive(serde::Serialize, serde::Deserialize, Versionize)] +#[versionize(SerializablePKEv1PublicParamsVersions)] +pub struct SerializablePKEv1PublicParams { + pub(crate) g_lists: SerializableGroupElements, + pub(crate) big_d: usize, + pub n: usize, + pub d: usize, + pub k: usize, + pub b: u64, + pub b_r: u64, + pub q: u64, + pub t: u64, + pub msbs_zero_padding_bit_count: u64, + // We use Vec since serde does not support fixed size arrays of 256 elements + pub(crate) hash: Vec, + pub(crate) hash_t: Vec, + pub(crate) hash_agg: Vec, + pub(crate) hash_lmap: Vec, + pub(crate) hash_z: Vec, + pub(crate) hash_w: Vec, +} + +impl From> for SerializablePKEv1PublicParams +where + GroupElements: Into, +{ + fn from(value: PKEv1PublicParams) -> Self { + let PKEv1PublicParams { + g_lists, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_t, + hash_agg, + hash_lmap, + hash_z, + hash_w, + } = value; + Self { + g_lists: g_lists.into(), + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash: hash.to_vec(), + hash_t: hash_t.to_vec(), + hash_agg: hash_agg.to_vec(), + hash_lmap: hash_lmap.to_vec(), + hash_z: hash_z.to_vec(), + hash_w: hash_w.to_vec(), + } + } +} + +impl TryFrom for PKEv1PublicParams +where + GroupElements: + TryFrom, +{ + type Error = InvalidSerializedPublicParamsError; + + fn try_from(value: SerializablePKEv1PublicParams) -> Result { + let SerializablePKEv1PublicParams { + g_lists, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_t, + hash_agg, + hash_lmap, + hash_z, + hash_w, + } = value; + Ok(Self { + g_lists: g_lists.try_into()?, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash: try_vec_to_array(hash)?, + hash_t: try_vec_to_array(hash_t)?, + hash_agg: try_vec_to_array(hash_agg)?, + hash_lmap: try_vec_to_array(hash_lmap)?, + hash_z: try_vec_to_array(hash_z)?, + hash_w: try_vec_to_array(hash_w)?, + }) + } +} diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index b60e3f4bc3..c64a8f3406 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -46,7 +46,7 @@ hex = "0.4.3" # End regex-engine deps # Used for backward compatibility test metadata ron = "0.8" -tfhe-backward-compat-data = { git = "https://github.com/zama-ai/tfhe-backward-compat-data.git", branch = "v0.1", default-features = false, features = [ +tfhe-backward-compat-data = { git = "https://github.com/zama-ai/tfhe-backward-compat-data.git", branch = "v0.2", default-features = false, features = [ "load", ] } diff --git a/tfhe/benches/integer/zk_pke.rs b/tfhe/benches/integer/zk_pke.rs index bcfd9cb936..03b13b5681 100644 --- a/tfhe/benches/integer/zk_pke.rs +++ b/tfhe/benches/integer/zk_pke.rs @@ -6,7 +6,6 @@ use rand::prelude::*; use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::Path; -use tfhe::core_crypto::prelude::*; use tfhe::integer::key_switching_key::KeySwitchingKey; use tfhe::integer::parameters::{ IntegerCompactCiphertextListCastingMode, IntegerCompactCiphertextListUnpackingMode, @@ -157,10 +156,7 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) { let shortint_params: PBSParameters = param_fhe.into(); - let mut crs_data = vec![]; - public_params - .serialize_with_mode(&mut crs_data, Compress::No) - .unwrap(); + let crs_data = bincode::serialize(&public_params).unwrap(); println!("CRS size: {}", crs_data.len()); diff --git a/tfhe/docs/guides/zk-pok.md b/tfhe/docs/guides/zk-pok.md index 09d30d6706..1256425ed8 100644 --- a/tfhe/docs/guides/zk-pok.md +++ b/tfhe/docs/guides/zk-pok.md @@ -12,7 +12,7 @@ Using this feature is straightforward: during encryption, the client generates t ```rust use rand::prelude::*; -use tfhe::prelude::FheDecrypt; +use tfhe::prelude::*; use tfhe::set_server_key; use tfhe::zk::{CompactPkeCrs, ZkComputeLoad}; @@ -45,9 +45,8 @@ pub fn main() -> Result<(), Box> { // Verify the ciphertexts let expander = proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?; - - let a: tfhe::FheUint64 = expander.get(0).unwrap()?; - let b: tfhe::FheUint64 = expander.get(1).unwrap()?; + let a: tfhe::FheUint64 = expander.get(0)?.unwrap(); + let b: tfhe::FheUint64 = expander.get(1)?.unwrap(); a + b }; @@ -80,7 +79,7 @@ This works essentially in the same way as before. Additionally, you need to indi ```rust use rand::prelude::*; -use tfhe::prelude::FheDecrypt; +use tfhe::prelude::*; use tfhe::set_server_key; use tfhe::zk::{CompactPkeCrs, ZkComputeLoad}; @@ -119,9 +118,10 @@ pub fn main() -> Result<(), Box> { set_server_key(server_key); // Verify the ciphertexts - let expander = proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?; - let a: tfhe::FheUint64 = expander.get(0).unwrap()?; - let b: tfhe::FheUint64 = expander.get(1).unwrap()?; + let expander = + proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?; + let a: tfhe::FheUint64 = expander.get(0)?.unwrap(); + let b: tfhe::FheUint64 = expander.get(1)?.unwrap(); a + b }; diff --git a/tfhe/src/c_api/high_level_api/compact_list.rs b/tfhe/src/c_api/high_level_api/compact_list.rs index 4473e4e996..7e0f0ffeb5 100644 --- a/tfhe/src/c_api/high_level_api/compact_list.rs +++ b/tfhe/src/c_api/high_level_api/compact_list.rs @@ -15,6 +15,7 @@ use crate::c_api::high_level_api::utils::{ #[cfg(feature = "zk-pok")] use crate::c_api::high_level_api::zk::{CompactPkePublicParams, ZkComputeLoad}; use crate::c_api::utils::{catch_panic, get_mut_checked, get_ref_checked}; +use crate::prelude::CiphertextList; use std::ffi::c_int; pub struct CompactCiphertextListBuilder(crate::high_level_api::CompactCiphertextListBuilder); diff --git a/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs index 1b04e08dda..03224feee6 100644 --- a/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs @@ -8,6 +8,7 @@ use crate::c_api::high_level_api::utils::{ impl_destroy_on_type, impl_serialize_deserialize_on_type, }; use crate::c_api::utils::{catch_panic, get_mut_checked, get_ref_checked}; +use crate::prelude::CiphertextList; use std::ffi::c_int; pub struct CompressedCiphertextListBuilder(crate::high_level_api::CompressedCiphertextListBuilder); diff --git a/tfhe/src/c_api/high_level_api/zk.rs b/tfhe/src/c_api/high_level_api/zk.rs index dfa9a20440..3bccb128f3 100644 --- a/tfhe/src/c_api/high_level_api/zk.rs +++ b/tfhe/src/c_api/high_level_api/zk.rs @@ -1,7 +1,7 @@ use super::utils::*; use crate::c_api::high_level_api::config::Config; use crate::c_api::utils::get_ref_checked; -use crate::zk::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; +use crate::zk::Compressible; use std::ffi::c_int; #[repr(C)] @@ -41,16 +41,11 @@ pub unsafe extern "C" fn compact_pke_public_params_serialize( let wrapper = crate::c_api::utils::get_ref_checked(sself).unwrap(); - let compress = if compress { - Compress::Yes + let buffer = if compress { + bincode::serialize(&wrapper.0.compress()).unwrap() } else { - Compress::No + bincode::serialize(&wrapper.0).unwrap() }; - let mut buffer = vec![]; - wrapper - .0 - .serialize_with_mode(&mut buffer, compress) - .unwrap(); *result = buffer.into(); }) @@ -62,8 +57,6 @@ pub unsafe extern "C" fn compact_pke_public_params_serialize( #[no_mangle] pub unsafe extern "C" fn compact_pke_public_params_deserialize( buffer_view: crate::c_api::buffer::DynamicBufferView, - is_compressed: bool, - validate: bool, result: *mut *mut CompactPkePublicParams, ) -> ::std::os::raw::c_int { crate::c_api::utils::catch_panic(|| { @@ -71,20 +64,7 @@ pub unsafe extern "C" fn compact_pke_public_params_deserialize( *result = std::ptr::null_mut(); - let deserialized = crate::zk::CompactPkePublicParams::deserialize_with_mode( - buffer_view.as_slice(), - if is_compressed { - Compress::Yes - } else { - Compress::No - }, - if validate { - Validate::Yes - } else { - Validate::No - }, - ) - .unwrap(); + let deserialized = bincode::deserialize(buffer_view.as_slice()).unwrap(); let heap_allocated_object = Box::new(CompactPkePublicParams(deserialized)); diff --git a/tfhe/src/high_level_api/backward_compatibility/booleans.rs b/tfhe/src/high_level_api/backward_compatibility/booleans.rs index 2605b89659..8ef15a6622 100644 --- a/tfhe/src/high_level_api/backward_compatibility/booleans.rs +++ b/tfhe/src/high_level_api/backward_compatibility/booleans.rs @@ -7,6 +7,7 @@ use crate::high_level_api::booleans::{ InnerBoolean, InnerBooleanVersionOwned, InnerCompressedFheBool, }; use crate::integer::ciphertext::{CompactCiphertextList, DataKind}; +use crate::prelude::CiphertextList; use crate::{ CompactCiphertextList as HlCompactCiphertextList, CompressedFheBool, Error, FheBool, Tag, }; @@ -111,7 +112,7 @@ impl CompactFheBool { let block = list .inner .get::(0) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|b| b.ok_or_else(|| Error::new("Failed to expand compact list".to_string())))??; let mut ciphertext = FheBool::new(block, Tag::default()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); @@ -148,7 +149,9 @@ impl CompactFheBoolList { let block = list .inner .get::(idx) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|list| { + list.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; let mut ciphertext = FheBool::new(block, Tag::default()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); diff --git a/tfhe/src/high_level_api/backward_compatibility/compact_list.rs b/tfhe/src/high_level_api/backward_compatibility/compact_list.rs index b6f3bdf4f4..cc38b5a16a 100644 --- a/tfhe/src/high_level_api/backward_compatibility/compact_list.rs +++ b/tfhe/src/high_level_api/backward_compatibility/compact_list.rs @@ -17,8 +17,17 @@ impl Upgrade for CompactCiphertextListV0 { } } +#[cfg(feature = "zk-pok")] +use crate::ProvenCompactCiphertextList; + #[derive(VersionsDispatch)] pub enum CompactCiphertextListVersions { V0(CompactCiphertextListV0), V1(CompactCiphertextList), } + +#[cfg(feature = "zk-pok")] +#[derive(VersionsDispatch)] +pub enum ProvenCompactCiphertextListVersions { + V0(ProvenCompactCiphertextList), +} diff --git a/tfhe/src/high_level_api/backward_compatibility/integers.rs b/tfhe/src/high_level_api/backward_compatibility/integers.rs index de05f2b0a7..fca57b4ab4 100644 --- a/tfhe/src/high_level_api/backward_compatibility/integers.rs +++ b/tfhe/src/high_level_api/backward_compatibility/integers.rs @@ -16,6 +16,7 @@ use crate::integer::ciphertext::{ CompressedRadixCiphertext as IntegerCompressedRadixCiphertext, CompressedSignedRadixCiphertext as IntegerCompressedSignedRadixCiphertext, DataKind, }; +use crate::prelude::CiphertextList; use crate::shortint::ciphertext::CompressedModulusSwitchedCiphertext; use crate::shortint::{Ciphertext, ServerKey}; use crate::{CompactCiphertextList as HlCompactCiphertextList, Error, Tag}; @@ -277,7 +278,9 @@ where let ct = list .inner .get::(0) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|list| { + list.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheInt::new(ct, Tag::default())) } } @@ -316,7 +319,9 @@ where let ct = list .inner .get::(idx) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|list| { + list.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheInt::new(ct, Tag::default())) }) .collect::, _>>() @@ -353,7 +358,9 @@ where let ct = list .inner .get::(0) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|ct| { + ct.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheUint::new(ct, Tag::default())) } } @@ -391,7 +398,9 @@ where let ct = list .inner .get::(idx) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|ct| { + ct.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheUint::new(ct, Tag::default())) }) .collect::, _>>() diff --git a/tfhe/src/high_level_api/compact_list.rs b/tfhe/src/high_level_api/compact_list.rs index 836888cd8a..723cd7c29a 100644 --- a/tfhe/src/high_level_api/compact_list.rs +++ b/tfhe/src/high_level_api/compact_list.rs @@ -1,6 +1,8 @@ use tfhe_versionable::Versionize; use crate::backward_compatibility::compact_list::CompactCiphertextListVersions; +#[cfg(feature = "zk-pok")] +use crate::backward_compatibility::compact_list::ProvenCompactCiphertextListVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; use crate::core_crypto::prelude::Numeric; @@ -14,6 +16,7 @@ use crate::integer::parameters::{ IntegerCompactCiphertextListUnpackingMode, }; use crate::named::Named; +use crate::prelude::CiphertextList; use crate::shortint::MessageModulus; #[cfg(feature = "zk-pok")] pub use zk::ProvenCompactCiphertextList; @@ -194,7 +197,8 @@ impl ParameterSetConformant for CompactCiphertextList { mod zk { use super::*; - #[derive(Clone, Serialize, Deserialize)] + #[derive(Clone, Serialize, Deserialize, Versionize)] + #[versionize(ProvenCompactCiphertextListVersions)] pub struct ProvenCompactCiphertextList { pub(crate) inner: crate::integer::ciphertext::ProvenCompactCiphertextList, pub(crate) tag: Tag, @@ -366,27 +370,27 @@ pub struct CompactCiphertextListExpander { tag: Tag, } -impl CompactCiphertextListExpander { - pub fn len(&self) -> usize { +impl CiphertextList for CompactCiphertextListExpander { + fn len(&self) -> usize { self.inner.len() } - pub fn is_empty(&self) -> bool { + fn is_empty(&self) -> bool { self.len() == 0 } - pub fn get_kind_of(&self, index: usize) -> Option { + fn get_kind_of(&self, index: usize) -> Option { self.inner.get_kind_of(index).and_then(|data_kind| { crate::FheTypes::from_data_kind(data_kind, self.inner.message_modulus()) }) } - pub fn get(&self, index: usize) -> Option> + fn get(&self, index: usize) -> crate::Result> where T: Expandable + Tagged, { let mut expanded = self.inner.get::(index); - if let Some(Ok(inner)) = &mut expanded { + if let Ok(Some(inner)) = &mut expanded { inner.tag_mut().set_data(self.tag.data()); } expanded @@ -540,15 +544,15 @@ mod tests { let e: u8 = e.decrypt(&ck); assert_eq!(e, 3); - assert!(expander.get::(5).is_none()); + assert!(expander.get::(5).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } } @@ -602,15 +606,15 @@ mod tests { let e: u8 = e.decrypt(&ck); assert_eq!(e, 3); - assert!(expander.get::(5).is_none()); + assert!(expander.get::(5).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } } @@ -665,15 +669,15 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(expander.get::(4).is_none()); + assert!(expander.get::(4).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } let unverified_expander = compact_list.expand_without_verification().unwrap(); @@ -693,7 +697,7 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(unverified_expander.get::(4).is_none()); + assert!(unverified_expander.get::(4).unwrap().is_none()); } } @@ -754,15 +758,15 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(expander.get::(4).is_none()); + assert!(expander.get::(4).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } let unverified_expander = compact_list.expand_without_verification().unwrap(); @@ -782,7 +786,7 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(unverified_expander.get::(4).is_none()); + assert!(unverified_expander.get::(4).unwrap().is_none()); } } } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 83e2153447..2bb06ff956 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -12,7 +12,7 @@ use crate::integer::gpu::ciphertext::compressed_ciphertext_list::{ CudaCompressible, CudaExpandable, }; use crate::named::Named; -use crate::prelude::Tagged; +use crate::prelude::{CiphertextList, Tagged}; use crate::shortint::Ciphertext; use crate::{FheBool, FheInt, FheUint, Tag}; @@ -233,8 +233,8 @@ impl Tagged for CompressedCiphertextList { } } -impl CompressedCiphertextList { - pub fn len(&self) -> usize { +impl CiphertextList for CompressedCiphertextList { + fn len(&self) -> usize { match &self.inner { InnerCompressedCiphertextList::Cpu(inner) => inner.len(), #[cfg(feature = "gpu")] @@ -242,7 +242,7 @@ impl CompressedCiphertextList { } } - pub fn is_empty(&self) -> bool { + fn is_empty(&self) -> bool { match &self.inner { InnerCompressedCiphertextList::Cpu(inner) => inner.len() == 0, #[cfg(feature = "gpu")] @@ -250,7 +250,7 @@ impl CompressedCiphertextList { } } - pub fn get_kind_of(&self, index: usize) -> Option { + fn get_kind_of(&self, index: usize) -> Option { match &self.inner { InnerCompressedCiphertextList::Cpu(inner) => Some(match inner.get_kind_of(index)? { DataKind::Unsigned(n) => { @@ -342,7 +342,7 @@ impl CompressedCiphertextList { } } - pub fn get(&self, index: usize) -> crate::Result> + fn get(&self, index: usize) -> crate::Result> where T: HlExpandable + Tagged, { @@ -394,7 +394,9 @@ impl CompressedCiphertextList { } } } +} +impl CompressedCiphertextList { pub fn into_raw_parts(self) -> (crate::integer::ciphertext::CompressedCiphertextList, Tag) { let Self { inner, tag } = self; match inner { diff --git a/tfhe/src/high_level_api/prelude.rs b/tfhe/src/high_level_api/prelude.rs index 5baf64ad9f..128cea0a97 100644 --- a/tfhe/src/high_level_api/prelude.rs +++ b/tfhe/src/high_level_api/prelude.rs @@ -7,10 +7,10 @@ //! use tfhe::prelude::*; //! ``` pub use crate::high_level_api::traits::{ - BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin, - FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse, - OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight, - RotateRightAssign, Tagged, + BitSlice, CiphertextList, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, + FheMax, FheMin, FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, + FheTryTrivialEncrypt, IfThenElse, OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, + RotateLeftAssign, RotateRight, RotateRightAssign, Tagged, }; pub use crate::conformance::ParameterSetConformant; diff --git a/tfhe/src/high_level_api/traits.rs b/tfhe/src/high_level_api/traits.rs index 850ec3c952..d5b7f26116 100644 --- a/tfhe/src/high_level_api/traits.rs +++ b/tfhe/src/high_level_api/traits.rs @@ -4,6 +4,8 @@ use crate::error::InvalidRangeError; use crate::high_level_api::ClientKey; use crate::{FheBool, Tag}; +use super::compressed_ciphertext_list::HlExpandable; + /// Trait used to have a generic way of creating a value of a FHE type /// from a native value. /// @@ -199,3 +201,12 @@ pub trait Tagged { fn tag_mut(&mut self) -> &mut Tag; } + +pub trait CiphertextList { + fn len(&self) -> usize; + fn is_empty(&self) -> bool; + fn get_kind_of(&self, index: usize) -> Option; + fn get(&self, index: usize) -> crate::Result> + where + T: HlExpandable + Tagged; +} diff --git a/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs b/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs index f268099091..1a2459e83a 100644 --- a/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs +++ b/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs @@ -9,6 +9,8 @@ use crate::integer::ciphertext::{ CompressedModulusSwitchedSignedRadixCiphertext, DataKind, }; use crate::integer::BooleanBlock; +#[cfg(feature = "zk-pok")] +use crate::integer::ProvenCompactCiphertextList; use crate::shortint::ciphertext::CompressedModulusSwitchedCiphertext; #[derive(VersionsDispatch)] @@ -53,6 +55,12 @@ pub enum CompactCiphertextListVersions { V1(CompactCiphertextList), } +#[cfg(feature = "zk-pok")] +#[derive(VersionsDispatch)] +pub enum ProvenCompactCiphertextListVersions { + V0(ProvenCompactCiphertextList), +} + #[derive(VersionsDispatch)] pub enum DataKindVersions { V0(DataKind), diff --git a/tfhe/src/integer/ciphertext/compact_list.rs b/tfhe/src/integer/ciphertext/compact_list.rs index f286b6a547..33d730d544 100644 --- a/tfhe/src/integer/ciphertext/compact_list.rs +++ b/tfhe/src/integer/ciphertext/compact_list.rs @@ -2,6 +2,8 @@ use super::{DataKind, Expandable}; use crate::conformance::{ListSizeConstraint, ParameterSetConformant}; use crate::core_crypto::prelude::Numeric; use crate::integer::backward_compatibility::ciphertext::CompactCiphertextListVersions; +#[cfg(feature = "zk-pok")] +use crate::integer::backward_compatibility::ciphertext::ProvenCompactCiphertextListVersions; use crate::integer::block_decomposition::DecomposableInto; use crate::integer::encryption::{create_clear_radix_block_iterator, KnowsMessageModulus}; use crate::integer::parameters::CompactCiphertextListConformanceParams; @@ -285,12 +287,13 @@ impl CompactCiphertextListExpander { .map(|block| (block, current_info)) } - pub fn get(&self, index: usize) -> Option> + pub fn get(&self, index: usize) -> crate::Result> where T: Expandable, { self.blocks_of(index) .map(|(blocks, kind)| T::from_expanded_blocks(blocks.to_owned(), kind)) + .transpose() } pub(crate) fn message_modulus(&self) -> MessageModulus { @@ -554,7 +557,8 @@ impl CompactCiphertextList { } #[cfg(feature = "zk-pok")] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Versionize)] +#[versionize(ProvenCompactCiphertextListVersions)] pub struct ProvenCompactCiphertextList { pub(crate) ct_list: crate::shortint::ciphertext::ProvenCompactCiphertextList, // Integers stored can have a heterogeneous number of blocks and signedness diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs index bfec4879a2..6168e9e736 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs @@ -1051,12 +1051,13 @@ macro_rules! define_expander_get_method { #[wasm_bindgen] pub fn [] (&mut self, index: usize) -> Result<[], JsError> { catch_panic_result(|| { - self.0.get::]>(index) - .map_or_else( - || Err(JsError::new(&format!("Index {index} is out of bounds"))), - |a| a.map_err(into_js_error), - ) - .map([]) + self.0.get::]>(index) + .map_err(into_js_error) + .map(|val| + val.map_or_else( + || Err(JsError::new(&format!("Index {index} is out of bounds"))), + |val| Ok([](val)) + ))? }) } )* @@ -1077,11 +1078,12 @@ macro_rules! define_expander_get_method { pub fn [] (&mut self, index: usize) -> Result<[], JsError> { catch_panic_result(|| { self.0.get::]>(index) - .map_or_else( - || Err(JsError::new(&format!("Index {index} is out of bounds"))), - |a| a.map_err(into_js_error), - ) - .map([]) + .map_err(into_js_error) + .map(|val| + val.map_or_else( + || Err(JsError::new(&format!("Index {index} is out of bounds"))), + |val| Ok([](val)) + ))? }) } )* @@ -1103,11 +1105,13 @@ impl CompactCiphertextListExpander { catch_panic_result(|| { self.0 .get::(index) - .map_or_else( - || Err(JsError::new(&format!("Index {index} is out of bounds"))), - |a| a.map_err(into_js_error), - ) - .map(FheBool) + .map_err(into_js_error) + .map(|val| { + val.map_or_else( + || Err(JsError::new(&format!("Index {index} is out of bounds"))), + |val| Ok(FheBool(val)), + ) + })? }) } diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs index c18f8950dc..eb7b01a42e 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs @@ -3,7 +3,9 @@ use wasm_bindgen::prelude::*; use crate::js_on_wasm_api::js_high_level_api::config::TfheConfig; use crate::js_on_wasm_api::js_high_level_api::{catch_panic_result, into_js_error}; use crate::js_on_wasm_api::shortint::ShortintParameters; -use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; + +use crate::zk::Compressible; + #[derive(Copy, Clone, Eq, PartialEq)] #[wasm_bindgen] pub enum ZkComputeLoad { @@ -33,43 +35,23 @@ impl CompactPkePublicParams { #[wasm_bindgen] pub fn serialize(&self, compress: bool) -> Result, JsError> { catch_panic_result(|| { - let mut data = vec![]; - self.0 - .serialize_with_mode( - &mut data, - if compress { - Compress::Yes - } else { - Compress::No - }, - ) - .map_err(into_js_error)?; - Ok(data) + let data = if compress { + bincode::serialize(&self.0.compress()) + } else { + bincode::serialize(&self.0) + }; + data.map_err(into_js_error) }) } #[wasm_bindgen] - pub fn deserialize( - buffer: &[u8], - is_compressed: bool, - validate: bool, - ) -> Result { + pub fn deserialize(buffer: &[u8]) -> Result { + // If buffer is compressed it is automatically detected and uncompressed. + // TODO: handle validation catch_panic_result(|| { - crate::zk::CompactPkePublicParams::deserialize_with_mode( - buffer, - if is_compressed { - Compress::Yes - } else { - Compress::No - }, - if validate { - Validate::Yes - } else { - Validate::No - }, - ) - .map(CompactPkePublicParams) - .map_err(into_js_error) + bincode::deserialize(buffer) + .map(CompactPkePublicParams) + .map_err(into_js_error) }) } } diff --git a/tfhe/src/shortint/backward_compatibility/ciphertext/mod.rs b/tfhe/src/shortint/backward_compatibility/ciphertext/mod.rs index 947e33cfb4..93d416c760 100644 --- a/tfhe/src/shortint/backward_compatibility/ciphertext/mod.rs +++ b/tfhe/src/shortint/backward_compatibility/ciphertext/mod.rs @@ -64,6 +64,12 @@ pub enum CompactCiphertextListVersions { V1(CompactCiphertextList), } +#[cfg(feature = "zk-pok")] +#[derive(VersionsDispatch)] +pub enum ProvenCompactCiphertextListVersions { + V0(ProvenCompactCiphertextList), +} + #[derive(VersionsDispatch)] pub enum CompressedCiphertextVersions { V0(CompressedCiphertext), diff --git a/tfhe/src/shortint/ciphertext/zk.rs b/tfhe/src/shortint/ciphertext/zk.rs index 2b3e6cd234..ed993eafb9 100644 --- a/tfhe/src/shortint/ciphertext/zk.rs +++ b/tfhe/src/shortint/ciphertext/zk.rs @@ -1,4 +1,5 @@ use crate::core_crypto::algorithms::verify_lwe_compact_ciphertext_list; +use crate::shortint::backward_compatibility::ciphertext::ProvenCompactCiphertextListVersions; use crate::shortint::ciphertext::CompactCiphertextList; use crate::shortint::parameters::{ CompactPublicKeyEncryptionParameters, MessageModulus, ShortintCompactCiphertextListCastingMode, @@ -10,6 +11,7 @@ use crate::zk::{ }; use rayon::prelude::*; use serde::{Deserialize, Serialize}; +use tfhe_versionable::Versionize; impl CompactPkeCrs { /// Construct the CRS that corresponds to the given parameters @@ -49,7 +51,8 @@ impl CompactPkeCrs { /// A List of CompactCiphertext with their zero-knowledge proofs /// /// The proofs can only be generated during the encryption with a [CompactPublicKey] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Versionize)] +#[versionize(ProvenCompactCiphertextListVersions)] pub struct ProvenCompactCiphertextList { pub(crate) proved_lists: Vec<(CompactCiphertextList, CompactPkeProof)>, } diff --git a/tfhe/src/zk.rs b/tfhe/src/zk.rs index 865926b078..673dcfe81a 100644 --- a/tfhe/src/zk.rs +++ b/tfhe/src/zk.rs @@ -6,8 +6,8 @@ use std::collections::Bound; use std::fmt::Debug; use tfhe_zk_pok::proofs::pke::crs_gen; +pub use tfhe_zk_pok::curve_api::Compressible; pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad; -pub use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; type Curve = tfhe_zk_pok::curve_api::Bls12_446; pub type CompactPkeProof = tfhe_zk_pok::proofs::pke::Proof; pub type CompactPkePublicParams = tfhe_zk_pok::proofs::pke::PublicParams; diff --git a/tfhe/tests/backward_compatibility/high_level_api.rs b/tfhe/tests/backward_compatibility/high_level_api.rs index 86385d32ee..bf274b04e8 100644 --- a/tfhe/tests/backward_compatibility/high_level_api.rs +++ b/tfhe/tests/backward_compatibility/high_level_api.rs @@ -6,13 +6,17 @@ use tfhe::backward_compatibility::integers::{ CompactFheInt8, CompactFheInt8List, CompactFheUint8, CompactFheUint8List, }; -use tfhe::prelude::{FheDecrypt, FheEncrypt}; +use tfhe::prelude::{CiphertextList, FheDecrypt, FheEncrypt}; use tfhe::shortint::PBSParameters; +#[cfg(feature = "zk-pok")] +use tfhe::zk::CompactPkePublicParams; use tfhe::{ set_server_key, ClientKey, CompactCiphertextList, CompressedCiphertextList, CompressedCompactPublicKey, CompressedFheBool, CompressedFheInt8, CompressedFheUint8, CompressedPublicKey, CompressedServerKey, FheBool, FheInt8, FheUint8, }; +#[cfg(feature = "zk-pok")] +use tfhe::{CompactPublicKey, ProvenCompactCiphertextList}; use tfhe_backward_compat_data::load::{ load_versioned_auxiliary, DataFormat, TestFailure, TestResult, TestSuccess, }; @@ -20,7 +24,7 @@ use tfhe_backward_compat_data::{ DataKind, HlBoolCiphertextListTest, HlBoolCiphertextTest, HlCiphertextListTest, HlCiphertextTest, HlClientKeyTest, HlHeterogeneousCiphertextListTest, HlPublicKeyTest, HlServerKeyTest, HlSignedCiphertextListTest, HlSignedCiphertextTest, TestMetadata, - TestParameterSet, TestType, Testcase, + TestParameterSet, TestType, Testcase, ZkPkePublicParamsTest, }; use tfhe_versionable::Unversionize; @@ -259,6 +263,21 @@ pub fn test_hl_bool_ciphertext_list( } } +/// Test Zk Public params +pub fn test_zk_params( + dir: &Path, + test: &ZkPkePublicParamsTest, + format: DataFormat, +) -> Result { + #[cfg(feature = "zk-pok")] + let _loaded_params: CompactPkePublicParams = load_and_unversionize(dir, test, format)?; + + #[cfg(not(feature = "zk-pok"))] + let _ = dir; + + Ok(test.success(format)) +} + /// Test HL ciphertext list: loads the ciphertext list and compare the decrypted values to the ones /// in the metadata. pub fn test_hl_heterogeneous_ciphertext_list( @@ -276,14 +295,40 @@ pub fn test_hl_heterogeneous_ciphertext_list( set_server_key(server_key); if test.compressed { - test_hl_heterogeneous_ciphertext_list_compressed( - load_and_unversionize(dir, test, format)?, - &key, - test, - ) + let list: CompressedCiphertextList = load_and_unversionize(dir, test, format)?; + test_hl_heterogeneous_ciphertext_list_elements(list, &key, test) + } else if let Some(zk_info) = &test.proof_info { + #[cfg(feature = "zk-pok")] + { + let crs_file = dir.join(&*zk_info.params_filename); + let crs = CompactPkePublicParams::unversionize( + load_versioned_auxiliary(crs_file).map_err(|e| test.failure(e, format))?, + ) + .map_err(|e| test.failure(e, format))?; + + let pubkey_file = dir.join(&*zk_info.public_key_filename); + let pubkey = CompactPublicKey::unversionize( + load_versioned_auxiliary(pubkey_file).map_err(|e| test.failure(e, format))?, + ) + .map_err(|e| test.failure(e, format))?; + + let list: ProvenCompactCiphertextList = load_and_unversionize(dir, test, format)?; + test_hl_heterogeneous_ciphertext_list_elements( + list.verify_and_expand(&crs, &pubkey, zk_info.metadata.as_bytes()) + .map_err(|msg| test.failure(msg, format))?, + &key, + test, + ) + } + #[cfg(not(feature = "zk-pok"))] + { + let _ = zk_info; + Ok(()) + } } else { - test_hl_heterogeneous_ciphertext_list_compact( - load_and_unversionize(dir, test, format)?, + let list: CompactCiphertextList = load_and_unversionize(dir, test, format)?; + test_hl_heterogeneous_ciphertext_list_elements( + list.expand().map_err(|msg| test.failure(msg, format))?, &key, test, ) @@ -292,62 +337,15 @@ pub fn test_hl_heterogeneous_ciphertext_list( .map_err(|msg| test.failure(msg, format)) } -pub fn test_hl_heterogeneous_ciphertext_list_compact( - list: CompactCiphertextList, - key: &ClientKey, - test: &HlHeterogeneousCiphertextListTest, -) -> Result<(), String> { - let ct_list = list.expand().unwrap(); - - for idx in 0..(ct_list.len()) { - match test.data_kinds[idx] { - DataKind::Bool => { - let ct: FheBool = ct_list.get(idx).unwrap().unwrap(); - let clear = ct.decrypt(key); - if clear != (test.clear_values[idx] != 0) { - return Err(format!( - "Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}", - clear, test.clear_values[idx] - )); - } - } - DataKind::Signed => { - let ct: FheInt8 = ct_list.get(idx).unwrap().unwrap(); - let clear: i8 = ct.decrypt(key); - if clear != test.clear_values[idx] as i8 { - return Err(format!( - "Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}", - clear, - (test.clear_values[idx] as i8) - )); - } - } - DataKind::Unsigned => { - let ct: FheUint8 = ct_list.get(idx).unwrap().unwrap(); - let clear: u8 = ct.decrypt(key); - if clear != test.clear_values[idx] as u8 { - return Err(format!( - "Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}", - clear, test.clear_values[idx] - )); - } - } - }; - } - Ok(()) -} - -pub fn test_hl_heterogeneous_ciphertext_list_compressed( - list: CompressedCiphertextList, +pub fn test_hl_heterogeneous_ciphertext_list_elements( + list: CtList, key: &ClientKey, test: &HlHeterogeneousCiphertextListTest, ) -> Result<(), String> { - let ct_list = list; - - for idx in 0..(ct_list.len()) { + for idx in 0..(list.len()) { match test.data_kinds[idx] { DataKind::Bool => { - let ct: FheBool = ct_list.get(idx).unwrap().unwrap(); + let ct: FheBool = list.get(idx).unwrap().unwrap(); let clear = ct.decrypt(key); if clear != (test.clear_values[idx] != 0) { return Err(format!( @@ -357,7 +355,7 @@ pub fn test_hl_heterogeneous_ciphertext_list_compressed( } } DataKind::Signed => { - let ct: FheInt8 = ct_list.get(idx).unwrap().unwrap(); + let ct: FheInt8 = list.get(idx).unwrap().unwrap(); let clear: i8 = ct.decrypt(key); if clear != test.clear_values[idx] as i8 { return Err(format!( @@ -368,7 +366,7 @@ pub fn test_hl_heterogeneous_ciphertext_list_compressed( } } DataKind::Unsigned => { - let ct: FheUint8 = ct_list.get(idx).unwrap().unwrap(); + let ct: FheUint8 = list.get(idx).unwrap().unwrap(); let clear: u8 = ct.decrypt(key); if clear != test.clear_values[idx] as u8 { return Err(format!( @@ -543,6 +541,9 @@ impl TestedModule for Hl { TestMetadata::HlServerKey(test) => { test_hl_serverkey(test_dir.as_ref(), test, format).into() } + TestMetadata::ZkPkePublicParams(test) => { + test_zk_params(test_dir.as_ref(), test, format).into() + } _ => { println!("WARNING: missing test: {:?}", testcase.metadata); TestResult::Skipped(testcase.skip()) diff --git a/tfhe/tests/backward_compatibility/shortint.rs b/tfhe/tests/backward_compatibility/shortint.rs index 08b57e9874..c1ae1d3aab 100644 --- a/tfhe/tests/backward_compatibility/shortint.rs +++ b/tfhe/tests/backward_compatibility/shortint.rs @@ -1,11 +1,12 @@ use std::path::Path; +use tfhe::core_crypto::prelude::TUniform; use tfhe_backward_compat_data::load::{ load_versioned_auxiliary, DataFormat, TestFailure, TestResult, TestSuccess, }; use tfhe_backward_compat_data::{ - ShortintCiphertextTest, ShortintClientKeyTest, TestMetadata, TestParameterSet, TestType, - Testcase, + ShortintCiphertextTest, ShortintClientKeyTest, TestDistribution, TestMetadata, + TestParameterSet, TestType, Testcase, }; use tfhe::shortint::parameters::{ @@ -27,12 +28,8 @@ pub fn load_params(test_params: &TestParameterSet) -> ClassicPBSParameters { lwe_dimension: LweDimension(test_params.lwe_dimension), glwe_dimension: GlweDimension(test_params.glwe_dimension), polynomial_size: PolynomialSize(test_params.polynomial_size), - lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev( - test_params.lwe_noise_gaussian_stddev, - )), - glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev( - test_params.glwe_noise_gaussian_stddev, - )), + lwe_noise_distribution: convert_distribution(&test_params.lwe_noise_distribution), + glwe_noise_distribution: convert_distribution(&test_params.glwe_noise_distribution), pbs_base_log: DecompositionBaseLog(test_params.pbs_base_log), pbs_level: DecompositionLevelCount(test_params.pbs_level), ks_base_log: DecompositionBaseLog(test_params.ks_base_log), @@ -52,6 +49,17 @@ pub fn load_params(test_params: &TestParameterSet) -> ClassicPBSParameters { } } +fn convert_distribution(value: &TestDistribution) -> DynamicDistribution { + match value { + TestDistribution::Gaussian { stddev } => { + DynamicDistribution::new_gaussian_from_std_dev(StandardDev(*stddev)) + } + TestDistribution::TUniform { bound_log2 } => { + DynamicDistribution::TUniform(TUniform::new(*bound_log2)) + } + } +} + fn load_shortint_params(test_params: &TestParameterSet) -> ShortintParameterSet { ShortintParameterSet::new_pbs_param_set(PBSParameters::PBS(load_params(test_params))) }