diff --git a/tfhe-zk-pok/Cargo.toml b/tfhe-zk-pok/Cargo.toml index a6e7fd1fe3..3d7ce79ff1 100644 --- a/tfhe-zk-pok/Cargo.toml +++ b/tfhe-zk-pok/Cargo.toml @@ -30,3 +30,4 @@ tfhe-versionable = { version = "0.3.0", path = "../utils/tfhe-versionable" } [dev-dependencies] serde_json = "~1.0" itertools = "0.11.0" +bincode = "1.3.3" diff --git a/tfhe-zk-pok/src/backward_compatibility/mod.rs b/tfhe-zk-pok/src/backward_compatibility/mod.rs index d547a40dc6..91753ea161 100644 --- a/tfhe-zk-pok/src/backward_compatibility/mod.rs +++ b/tfhe-zk-pok/src/backward_compatibility/mod.rs @@ -1,8 +1,8 @@ use tfhe_versionable::VersionsDispatch; -use crate::curve_api::Curve; -use crate::proofs::pke::Proof as PKEv1Proof; -use crate::proofs::pke_v2::Proof as PKEv2Proof; +use crate::curve_api::{Compressible, Curve}; +use crate::proofs::pke::{CompressedProof as PKEv1CompressedProof, Proof as PKEv1Proof}; +use crate::proofs::pke_v2::{CompressedProof as PKEv2CompressedProof, Proof as PKEv2Proof}; use crate::proofs::GroupElements; use crate::serialization::{ SerializableAffine, SerializableCubicExtField, SerializableFp, SerializableFp2, @@ -34,14 +34,32 @@ pub type SerializableG1AffineVersions = SerializableAffineVersions; pub type SerializableFp12Versions = SerializableQuadExtFieldVersions; +#[derive(VersionsDispatch)] +pub enum PKEv1ProofVersions { + V0(PKEv1Proof), +} + #[derive(VersionsDispatch)] pub enum PKEv2ProofVersions { V0(PKEv2Proof), } #[derive(VersionsDispatch)] -pub enum PKEv1ProofVersions { - V0(PKEv1Proof), +pub enum PKEv1CompressedProofVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(PKEv1CompressedProof), +} + +#[derive(VersionsDispatch)] +pub enum PKEv2CompressedProofVersions +where + G::G1: Compressible, + G::G2: Compressible, +{ + V0(PKEv2CompressedProof), } #[derive(VersionsDispatch)] diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs index 7544ddc622..8bf60f37ea 100644 --- a/tfhe-zk-pok/src/curve_api.rs +++ b/tfhe-zk-pok/src/curve_api.rs @@ -3,7 +3,6 @@ use ark_ec::short_weierstrass::Affine; use ark_ec::{AdditiveGroup as Group, CurveGroup, VariableBaseMSM}; use ark_ff::{BigInt, Field, MontFp, Zero}; use ark_poly::univariate::DensePolynomial; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use core::fmt; use core::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; use serde::{Deserialize, Serialize}; @@ -108,9 +107,7 @@ pub trait CurveGroupOps: + Sync + core::fmt::Debug + serde::Serialize - + for<'de> serde::Deserialize<'de> - + CanonicalSerialize - + CanonicalDeserialize; + + for<'de> serde::Deserialize<'de>; fn projective(affine: Self::Affine) -> Self; @@ -121,6 +118,16 @@ pub trait CurveGroupOps: fn normalize(self) -> Self::Affine; } +/// Mark that an element can be compressed, by storing only the 'x' coordinates of the affine +/// representation and getting the 'y' from the curve. +pub trait Compressible: Sized { + type Compressed; + type UncompressError; + + fn compress(&self) -> Self::Compressed; + fn uncompress(compressed: Self::Compressed) -> Result; +} + pub trait PairingGroupOps: Copy + Send @@ -139,8 +146,8 @@ pub trait PairingGroupOps: pub trait Curve: Clone { type Zp: FieldOps; - type G1: CurveGroupOps + CanonicalSerialize + CanonicalDeserialize; - type G2: CurveGroupOps + CanonicalSerialize + CanonicalDeserialize; + type G1: CurveGroupOps; + type G2: CurveGroupOps; type Gt: PairingGroupOps; } diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs index 4e18751dde..ccd8f1da18 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_381.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -41,19 +41,7 @@ mod g1 { use super::*; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Hash, - Versionize, - )] + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[versionize( SerializableG1AffineVersions, @@ -67,7 +55,7 @@ mod g1 { impl From for SerializableAffine { fn from(value: G1Affine) -> Self { - SerializableAffine::compressed(value.inner) + SerializableAffine::uncompressed(value.inner) } } @@ -81,6 +69,19 @@ mod g1 { } } + impl Compressible for G1Affine { + type Compressed = SerializableG1Affine; + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableG1Affine { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G1Affine { pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 { // SAFETY: interpreting a `repr(transparent)` pointer as its contents. @@ -96,18 +97,7 @@ mod g1 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - Hash, - CanonicalSerialize, - CanonicalDeserialize, - Versionize, - )] + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[versionize( SerializableG1AffineVersions, @@ -121,11 +111,11 @@ mod g1 { impl From for SerializableAffine { fn from(value: G1) -> Self { - SerializableAffine::compressed(value.inner.into_affine()) + SerializableAffine::uncompressed(value.inner.into_affine()) } } - impl TryFrom> for G1 { + impl TryFrom for G1 { type Error = InvalidSerializedAffineError; fn try_from(value: SerializableAffine) -> Result { @@ -135,6 +125,19 @@ mod g1 { } } + impl Compressible for G1 { + type Compressed = SerializableG1Affine; + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableG1Affine { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G1 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("G1") @@ -266,19 +269,7 @@ mod g2 { use super::*; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Hash, - Versionize, - )] + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[versionize( SerializableG2AffineVersions, @@ -292,7 +283,7 @@ mod g2 { impl From for SerializableAffine { fn from(value: G2Affine) -> Self { - SerializableAffine::compressed(value.inner) + SerializableAffine::uncompressed(value.inner) } } @@ -306,6 +297,20 @@ mod g2 { } } + impl Compressible for G2Affine { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableAffine { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G2Affine { pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 { // SAFETY: interpreting a `repr(transparent)` pointer as its contents. @@ -321,18 +326,7 @@ mod g2 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Hash, - Versionize, - )] + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[versionize( SerializableG2AffineVersions, @@ -344,13 +338,13 @@ mod g2 { pub(crate) inner: ark_bls12_381::G2Projective, } - impl From for SerializableAffine { + impl From for SerializableG2Affine { fn from(value: G2) -> Self { - SerializableAffine::compressed(value.inner.into_affine()) + SerializableAffine::uncompressed(value.inner.into_affine()) } } - impl TryFrom> for G2 { + impl TryFrom for G2 { type Error = InvalidSerializedAffineError; fn try_from(value: SerializableAffine) -> Result { @@ -360,6 +354,20 @@ mod g2 { } } + impl Compressible for G2 { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> SerializableAffine { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G2 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { #[allow(dead_code)] @@ -998,6 +1006,26 @@ mod tests { assert_eq!(g_hat_cur, g_hat_cur2); } + #[test] + fn test_compressed_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let g_cur2 = G1::uncompress( + serde_json::from_str(&serde_json::to_string(&g_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2 = G2::uncompress( + serde_json::from_str(&serde_json::to_string(&g_hat_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + #[test] fn test_hasher_and_eq() { // we need to make sure if the points are the same diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index 2ce7893be1..342ef68f09 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -41,19 +41,7 @@ mod g1 { use super::*; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Hash, - Versionize, - )] + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[versionize( SerializableG1AffineVersions, @@ -67,7 +55,7 @@ mod g1 { impl From for SerializableAffine { fn from(value: G1Affine) -> Self { - SerializableAffine::compressed(value.inner) + SerializableAffine::uncompressed(value.inner) } } @@ -81,6 +69,20 @@ mod g1 { } } + impl Compressible for G1Affine { + type Compressed = SerializableG1Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G1Affine { #[track_caller] pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G1 { @@ -97,18 +99,7 @@ mod g1 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Hash, - Versionize, - )] + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[versionize( SerializableG1AffineVersions, @@ -120,22 +111,36 @@ mod g1 { pub(crate) inner: crate::curve_446::g1::G1Projective, } - impl From for SerializableAffine { + impl From for SerializableG1Affine { fn from(value: G1) -> Self { - SerializableAffine::compressed(value.inner.into_affine()) + SerializableAffine::uncompressed(value.inner.into_affine()) } } - impl TryFrom> for G1 { + impl TryFrom for G1 { type Error = InvalidSerializedAffineError; - fn try_from(value: SerializableAffine) -> Result { + fn try_from(value: SerializableG1Affine) -> Result { Ok(Self { inner: Affine::try_from(value)?.into(), }) } } + impl Compressible for G1 { + type Compressed = SerializableG1Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G1 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("G1") @@ -268,19 +273,7 @@ mod g2 { use super::*; use crate::serialization::InvalidSerializedAffineError; - #[derive( - Copy, - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Hash, - Versionize, - )] + #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[versionize( SerializableG2AffineVersions, @@ -292,22 +285,36 @@ mod g2 { pub(crate) inner: crate::curve_446::g2::G2Affine, } - impl From for SerializableAffine { + impl From for SerializableG2Affine { fn from(value: G2Affine) -> Self { - SerializableAffine::compressed(value.inner) + SerializableAffine::uncompressed(value.inner) } } - impl TryFrom> for G2Affine { + impl TryFrom for G2Affine { type Error = InvalidSerializedAffineError; - fn try_from(value: SerializableAffine) -> Result { + fn try_from(value: SerializableG2Affine) -> Result { Ok(Self { inner: value.try_into()?, }) } } + impl Compressible for G2Affine { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl G2Affine { #[track_caller] pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> G2 { @@ -414,18 +421,7 @@ mod g2 { } } - #[derive( - Copy, - Clone, - PartialEq, - Eq, - Serialize, - Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Hash, - Versionize, - )] + #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[versionize( SerializableG2AffineVersions, @@ -437,22 +433,36 @@ mod g2 { pub(crate) inner: crate::curve_446::g2::G2Projective, } - impl From for SerializableAffine { + impl From for SerializableG2Affine { fn from(value: G2) -> Self { - SerializableAffine::compressed(value.inner.into_affine()) + SerializableAffine::uncompressed(value.inner.into_affine()) } } - impl TryFrom> for G2 { + impl TryFrom for G2 { type Error = InvalidSerializedAffineError; - fn try_from(value: SerializableAffine) -> Result { + fn try_from(value: SerializableG2Affine) -> Result { Ok(Self { inner: Affine::try_from(value)?.into(), }) } } + impl Compressible for G2 { + type Compressed = SerializableG2Affine; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + SerializableAffine::compressed(self.inner.into_affine()) + } + + fn uncompress(compressed: Self::Compressed) -> Result { + compressed.try_into() + } + } + impl fmt::Debug for G2 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { #[allow(dead_code)] @@ -1322,6 +1332,26 @@ mod tests { assert_eq!(g_hat_cur, g_hat_cur2); } + #[test] + fn test_compressed_serialization() { + let rng = &mut StdRng::seed_from_u64(0); + let alpha = Zp::rand(rng); + let g_cur = G1::GENERATOR.mul_scalar(alpha); + let g_hat_cur = G2::GENERATOR.mul_scalar(alpha); + + let g_cur2 = G1::uncompress( + serde_json::from_str(&serde_json::to_string(&g_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_cur, g_cur2); + + let g_hat_cur2 = G2::uncompress( + serde_json::from_str(&serde_json::to_string(&g_hat_cur.compress()).unwrap()).unwrap(), + ) + .unwrap(); + assert_eq!(g_hat_cur, g_hat_cur2); + } + #[test] fn test_hasher_and_eq() { // we need to make sure if the points are the same diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index f489f2a476..2976ab1feb 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -1,48 +1,17 @@ -use ark_serialize::{ - CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +use crate::backward_compatibility::GroupElementsVersions; +use crate::curve_api::{Compressible, Curve, CurveGroupOps, FieldOps, PairingGroupOps}; +use crate::serialization::{ + InvalidSerializedGroupElementsError, SerializableG1Affine, SerializableG2Affine, + SerializableGroupElements, }; +use core::ops::{Index, IndexMut}; use rand::{Rng, RngCore}; -use std::ops::{Index, IndexMut}; - use tfhe_versionable::{Unversionize, Versionize, VersionizeOwned}; -use crate::backward_compatibility::GroupElementsVersions; -use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps}; - -impl Valid for OneBased { - fn check(&self) -> Result<(), SerializationError> { - self.0.check() - } -} - #[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize)] #[repr(transparent)] pub(crate) struct OneBased(T); -impl CanonicalDeserialize for OneBased { - fn deserialize_with_mode( - reader: R, - compress: Compress, - validate: Validate, - ) -> Result { - T::deserialize_with_mode(reader, compress, validate).map(Self) - } -} - -impl CanonicalSerialize for OneBased { - fn serialize_with_mode( - &self, - writer: W, - compress: Compress, - ) -> Result<(), SerializationError> { - self.0.serialize_with_mode(writer, compress) - } - - fn serialized_size(&self, compress: Compress) -> usize { - self.0.serialized_size(compress) - } -} - // TODO: these impl could be removed by adding support for `repr(transparent)` in tfhe-versionable impl Versionize for OneBased { type Versioned<'vers> = T::Versioned<'vers> @@ -110,15 +79,7 @@ impl> IndexMut for OneBased { pub type Affine = >::Affine; -#[derive( - Clone, - Debug, - serde::Serialize, - serde::Deserialize, - CanonicalSerialize, - CanonicalDeserialize, - Versionize, -)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] #[serde(bound( deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" @@ -177,6 +138,34 @@ impl GroupElements { } } +impl Compressible for GroupElements +where + GroupElements: + TryFrom, + >::Affine: Compressible, + >::Affine: Compressible, +{ + type Compressed = SerializableGroupElements; + + type UncompressError = InvalidSerializedGroupElementsError; + + fn compress(&self) -> Self::Compressed { + let mut g_list = Vec::new(); + let mut g_hat_list = Vec::new(); + for idx in 0..self.message_len { + g_list.push(self.g_list[(idx * 2) + 1].compress()); + g_list.push(self.g_list[(idx * 2) + 2].compress()); + g_hat_list.push(self.g_hat_list[idx + 1].compress()) + } + + SerializableGroupElements { g_list, g_hat_list } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + Self::try_from(compressed) + } +} + pub const HASH_METADATA_LEN_BYTES: usize = 256; pub mod binary; diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index 7a8e523fad..9463ecbab7 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -1,14 +1,16 @@ // TODO: refactor copy-pasted code in proof/verify -use crate::backward_compatibility::{PKEv1ProofVersions, SerializablePKEv1PublicParamsVersions}; +use crate::backward_compatibility::{ + PKEv1CompressedProofVersions, PKEv1ProofVersions, SerializablePKEv1PublicParamsVersions, +}; use crate::serialization::{ - InvalidSerializedPublicParamsError, SerializableGroupElements, SerializablePKEv1PublicParams, + try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError, + SerializableGroupElements, SerializablePKEv1PublicParams, }; use super::*; use core::marker::PhantomData; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::error::Error; @@ -18,7 +20,7 @@ fn bit_iter(x: u64, nbits: u32) -> impl Iterator { (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) } -#[derive(Clone, Debug, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde( try_from = "SerializablePKEv1PublicParams", into = "SerializablePKEv1PublicParams", @@ -92,6 +94,96 @@ where } } +impl Compressible for PublicParams +where + GroupElements: Compressible< + Compressed = SerializableGroupElements, + UncompressError = InvalidSerializedGroupElementsError, + >, +{ + type Compressed = SerializablePKEv1PublicParams; + + type UncompressError = InvalidSerializedPublicParamsError; + + fn compress(&self) -> Self::Compressed { + let PublicParams { + g_lists, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_t, + hash_agg, + hash_lmap, + hash_z, + hash_w, + } = self; + SerializablePKEv1PublicParams { + g_lists: g_lists.compress(), + big_d: *big_d, + n: *n, + d: *d, + k: *k, + b: *b, + b_r: *b_r, + q: *q, + t: *t, + msbs_zero_padding_bit_count: *msbs_zero_padding_bit_count, + hash: hash.to_vec(), + hash_t: hash_t.to_vec(), + hash_agg: hash_agg.to_vec(), + hash_lmap: hash_lmap.to_vec(), + hash_z: hash_z.to_vec(), + hash_w: hash_w.to_vec(), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let SerializablePKEv1PublicParams { + g_lists, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_t, + hash_agg, + hash_lmap, + hash_z, + hash_w, + } = compressed; + Ok(Self { + g_lists: GroupElements::uncompress(g_lists)?, + big_d, + n, + d, + k, + b, + b_r, + q, + t, + msbs_zero_padding_bit_count, + hash: try_vec_to_array(hash)?, + hash_t: try_vec_to_array(hash_t)?, + hash_agg: try_vec_to_array(hash_agg)?, + hash_lmap: try_vec_to_array(hash_lmap)?, + hash_z: try_vec_to_array(hash_z)?, + hash_w: try_vec_to_array(hash_w)?, + }) + } +} + impl PublicParams { #[allow(clippy::too_many_arguments)] pub fn from_vec( @@ -153,6 +245,78 @@ pub struct Proof { pi_kzg: Option, } +type CompressedG2 = <::G2 as Compressible>::Compressed; +type CompressedG1 = <::G1 as Compressible>::Compressed; + +#[derive(Serialize, Deserialize, Versionize)] +#[serde(bound( + deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", + serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" +))] +#[versionize(PKEv1CompressedProofVersions)] +pub struct CompressedProof +where + G::G1: Compressible, + G::G2: Compressible, +{ + c_hat: CompressedG2, + c_y: CompressedG1, + pi: CompressedG1, + c_hat_t: Option>, + c_h: Option>, + pi_kzg: Option>, +} + +impl Compressible for Proof +where + G::G1: Compressible, + G::G2: Compressible, +{ + type Compressed = CompressedProof; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + let Proof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = self; + + CompressedProof { + c_hat: c_hat.compress(), + c_y: c_y.compress(), + pi: pi.compress(), + c_hat_t: c_hat_t.map(|val| val.compress()), + c_h: c_h.map(|val| val.compress()), + pi_kzg: pi_kzg.map(|val| val.compress()), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let CompressedProof { + c_hat, + c_y, + pi, + c_hat_t, + c_h, + pi_kzg, + } = compressed; + + Ok(Proof { + c_hat: G::G2::uncompress(c_hat)?, + c_y: G::G1::uncompress(c_y)?, + pi: G::G1::uncompress(pi)?, + c_hat_t: c_hat_t.map(G::G2::uncompress).transpose()?, + c_h: c_h.map(G::G1::uncompress).transpose()?, + pi_kzg: pi_kzg.map(G::G1::uncompress).transpose()?, + }) + } +} + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct PublicCommit { a: Vec, @@ -1097,7 +1261,7 @@ pub fn verify( #[cfg(test)] mod tests { use super::*; - use ark_serialize::{Compress, SerializationError, Validate}; + use bincode::ErrorKind; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -1234,15 +1398,17 @@ mod tests { type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1250,9 +1416,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, b_i, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for ( public_param, @@ -1448,21 +1614,23 @@ mod tests { let div = val.div_euclid(q); let rem = val.rem_euclid(q); let result = div as i64 + (rem > (q / 2)) as i64; - let result = result.rem_euclid(t as i64); + let result = result.rem_euclid(effective_cleartext_t as i64); m_roundtrip[i] = result; } type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1470,9 +1638,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, b_i, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for public_param in [ original_public_param, @@ -1505,4 +1673,156 @@ mod tests { } } } + + #[test] + fn test_proof_compression() { + let d = 2048; + let k = 320; + let big_b = 1048576; + let q = 0; + let t = 1024; + let msbs_zero_padding_bit_count = 1; + let effective_cleartext_t = t >> msbs_zero_padding_bit_count; + + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let rng = &mut StdRng::seed_from_u64(0); + + let polymul_rev = |a: &[i64], b: &[i64]| -> Vec { + assert_eq!(a.len(), b.len()); + let d = a.len(); + let mut c = vec![0i64; d]; + + for i in 0..d { + for j in 0..d { + if i + j < d { + c[i + j] = c[i + j].wrapping_add(a[i].wrapping_mul(b[d - j - 1])); + } else { + c[i + j - d] = c[i + j - d].wrapping_sub(a[i].wrapping_mul(b[d - j - 1])); + } + } + } + + c + }; + + let a = (0..d).map(|_| rng.gen::()).collect::>(); + let s = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let e = (0..d) + .map(|_| (rng.gen::() % (2 * big_b)) as i64 - big_b as i64) + .collect::>(); + let e1 = (0..d) + .map(|_| (rng.gen::() % (2 * big_b)) as i64 - big_b as i64) + .collect::>(); + let e2 = (0..k) + .map(|_| (rng.gen::() % (2 * big_b)) as i64 - big_b as i64) + .collect::>(); + + let r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + + let m = (0..k) + .map(|_| (rng.gen::() % effective_cleartext_t) as i64) + .collect::>(); + + let b = polymul_rev(&a, &s) + .into_iter() + .zip(e.iter()) + .map(|(x, e)| x.wrapping_add(*e)) + .collect::>(); + let c1 = polymul_rev(&a, &r) + .into_iter() + .zip(e1.iter()) + .map(|(x, e1)| x.wrapping_add(*e1)) + .collect::>(); + + let mut c2 = vec![0i64; k]; + + for i in 0..k { + let mut dot = 0i64; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot = dot.wrapping_add(r[d - j - 1].wrapping_mul(b)); + } + + c2[i] = dot + .wrapping_add(e2[i]) + .wrapping_add((delta * m[i] as u64) as i64); + } + + // One of our usecases uses 320 bits of additional metadata + const METADATA_LEN: usize = (320 / u8::BITS) as usize; + + let mut metadata = [0u8; METADATA_LEN]; + metadata.fill_with(|| rng.gen::()); + + let mut m_roundtrip = vec![0i64; k]; + for i in 0..k { + let mut dot = 0i128; + for j in 0..d { + let c = if i + j < d { + c1[d - j - i - 1] + } else { + c1[2 * d - j - i - 1].wrapping_neg() + }; + + dot += s[d - j - 1] as i128 * c as i128; + } + + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + let val = ((c2[i] as i128).wrapping_sub(dot)) * t as i128; + let div = val.div_euclid(q); + let rem = val.rem_euclid(q); + let result = div as i64 + (rem > (q / 2)) as i64; + let result = result.rem_euclid(effective_cleartext_t as i64); + m_roundtrip[i] = result; + } + + type Curve = crate::curve_api::Bls12_446; + + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let public_param = + crs_gen::(d, crs_k, big_b, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + a.clone(), + b.clone(), + c1.clone(), + c2.clone(), + r.clone(), + e1.clone(), + m.clone(), + e2.clone(), + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = prove( + (&public_param, &public_commit), + &private_commit, + &metadata, + load, + rng, + ); + + let compressed_proof = bincode::serialize(&proof.clone().compress()).unwrap(); + let proof = + Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap(); + + verify(&proof, (&public_param, &public_commit), &metadata).unwrap() + } + } } diff --git a/tfhe-zk-pok/src/proofs/pke_v2.rs b/tfhe-zk-pok/src/proofs/pke_v2.rs index 3ba260279f..b9312a0a9f 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2.rs @@ -2,25 +2,26 @@ #![allow(non_snake_case)] use super::*; -use crate::backward_compatibility::{PKEv2ProofVersions, SerializablePKEv2PublicParamsVersions}; +use crate::backward_compatibility::{ + PKEv2CompressedProofVersions, PKEv2ProofVersions, SerializablePKEv2PublicParamsVersions, +}; use crate::four_squares::*; use crate::serialization::{ - InvalidSerializedPublicParamsError, SerializableGroupElements, SerializablePKEv2PublicParams, + try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError, + SerializableGroupElements, SerializablePKEv2PublicParams, }; use core::marker::PhantomData; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::error::Error; -use tfhe_versionable::{ - Unversionize, UnversionizeError, Versionize, VersionizeOwned, VersionsDispatch, -}; +use tfhe_versionable::{UnversionizeError, VersionsDispatch}; fn bit_iter(x: u64, nbits: u32) -> impl Iterator { (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) } /// The CRS of the zk scheme -#[derive(Clone, Debug, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde( try_from = "SerializablePKEv2PublicParams", into = "SerializablePKEv2PublicParams", @@ -54,8 +55,7 @@ pub struct PublicParams { pub(crate) hash_chi: [u8; HASH_METADATA_LEN_BYTES], } -// Manual impl of Versionize because the proc macro has trouble handling conversion from/into types -// with generics +// Manual impl of Versionize because TryFrom + generics is currently badly handled by the proc macro impl Versionize for PublicParams where Self: Clone, @@ -101,6 +101,120 @@ where } } +impl Compressible for PublicParams +where + GroupElements: Compressible< + Compressed = SerializableGroupElements, + UncompressError = InvalidSerializedGroupElementsError, + >, +{ + type Compressed = SerializablePKEv2PublicParams; + + type UncompressError = InvalidSerializedPublicParamsError; + + fn compress(&self) -> Self::Compressed { + let PublicParams { + g_lists, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_R, + hash_t, + hash_w, + hash_agg, + hash_lmap, + hash_phi, + hash_xi, + hash_z, + hash_chi, + } = self; + SerializablePKEv2PublicParams { + g_lists: g_lists.compress(), + D: *D, + n: *n, + d: *d, + k: *k, + B: *B, + B_r: *B_r, + B_bound: *B_bound, + m_bound: *m_bound, + q: *q, + t: *t, + msbs_zero_padding_bit_count: *msbs_zero_padding_bit_count, + hash: hash.to_vec(), + hash_R: hash_R.to_vec(), + hash_t: hash_t.to_vec(), + hash_w: hash_w.to_vec(), + hash_agg: hash_agg.to_vec(), + hash_lmap: hash_lmap.to_vec(), + hash_phi: hash_phi.to_vec(), + hash_xi: hash_xi.to_vec(), + hash_z: hash_z.to_vec(), + hash_chi: hash_chi.to_vec(), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let SerializablePKEv2PublicParams { + g_lists, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash, + hash_R, + hash_t, + hash_w, + hash_agg, + hash_lmap, + hash_phi, + hash_xi, + hash_z, + hash_chi, + } = compressed; + Ok(Self { + g_lists: GroupElements::uncompress(g_lists)?, + D, + n, + d, + k, + B, + B_r, + B_bound, + m_bound, + q, + t, + msbs_zero_padding_bit_count, + hash: try_vec_to_array(hash)?, + hash_R: try_vec_to_array(hash_R)?, + hash_t: try_vec_to_array(hash_t)?, + hash_w: try_vec_to_array(hash_w)?, + hash_agg: try_vec_to_array(hash_agg)?, + hash_lmap: try_vec_to_array(hash_lmap)?, + hash_phi: try_vec_to_array(hash_phi)?, + hash_xi: try_vec_to_array(hash_xi)?, + hash_z: try_vec_to_array(hash_z)?, + hash_chi: try_vec_to_array(hash_chi)?, + }) + } +} + impl PublicParams { #[allow(clippy::too_many_arguments)] pub fn from_vec( @@ -182,6 +296,114 @@ pub struct Proof { C_hat_w: Option, } +type CompressedG2 = <::G2 as Compressible>::Compressed; +type CompressedG1 = <::G1 as Compressible>::Compressed; + +#[derive(Serialize, Deserialize, Versionize)] +#[serde(bound( + deserialize = "G: Curve, CompressedG1: serde::Deserialize<'de>, CompressedG2: serde::Deserialize<'de>", + serialize = "G: Curve, CompressedG1: serde::Serialize, CompressedG2: serde::Serialize" +))] +#[versionize(PKEv2CompressedProofVersions)] +pub struct CompressedProof +where + G::G1: Compressible, + G::G2: Compressible, +{ + C_hat_e: CompressedG2, + C_e: CompressedG1, + C_r_tilde: CompressedG1, + C_R: CompressedG1, + C_hat_bin: CompressedG2, + C_y: CompressedG1, + C_h1: CompressedG1, + C_h2: CompressedG1, + C_hat_t: CompressedG2, + pi: CompressedG1, + pi_kzg: CompressedG1, + + C_hat_h3: Option>, + C_hat_w: Option>, +} + +impl Compressible for Proof +where + G::G1: Compressible, + G::G2: Compressible, +{ + type Compressed = CompressedProof; + + type UncompressError = InvalidSerializedAffineError; + + fn compress(&self) -> Self::Compressed { + let Proof { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + C_hat_h3, + C_hat_w, + } = self; + + CompressedProof { + C_hat_e: C_hat_e.compress(), + C_e: C_e.compress(), + C_r_tilde: C_r_tilde.compress(), + C_R: C_R.compress(), + C_hat_bin: C_hat_bin.compress(), + C_y: C_y.compress(), + C_h1: C_h1.compress(), + C_h2: C_h2.compress(), + C_hat_t: C_hat_t.compress(), + pi: pi.compress(), + pi_kzg: pi_kzg.compress(), + C_hat_h3: C_hat_h3.map(|val| val.compress()), + C_hat_w: C_hat_w.map(|val| val.compress()), + } + } + + fn uncompress(compressed: Self::Compressed) -> Result { + let CompressedProof { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + C_hat_h3, + C_hat_w, + } = compressed; + + Ok(Proof { + C_hat_e: G::G2::uncompress(C_hat_e)?, + C_e: G::G1::uncompress(C_e)?, + C_r_tilde: G::G1::uncompress(C_r_tilde)?, + C_R: G::G1::uncompress(C_R)?, + C_hat_bin: G::G2::uncompress(C_hat_bin)?, + C_y: G::G1::uncompress(C_y)?, + C_h1: G::G1::uncompress(C_h1)?, + C_h2: G::G1::uncompress(C_h2)?, + C_hat_t: G::G2::uncompress(C_hat_t)?, + pi: G::G1::uncompress(pi)?, + pi_kzg: G::G1::uncompress(pi_kzg)?, + C_hat_h3: C_hat_h3.map(G::G2::uncompress).transpose()?, + C_hat_w: C_hat_w.map(G::G2::uncompress).transpose()?, + }) + } +} + /// This is the public part of the commitment. `a` and `b` are the mask and body of the public key, /// `c1` and `c2` are the mask and body of the ciphertext. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] @@ -2136,7 +2358,7 @@ pub fn verify( #[cfg(test)] mod tests { use super::*; - use ark_serialize::{Compress, SerializationError, Validate}; + use bincode::ErrorKind; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -2273,15 +2495,17 @@ mod tests { type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2289,9 +2513,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for ( public_param, @@ -2375,12 +2599,14 @@ mod tests { let B = 1048576; let q = 0; let t = 1024; + let msbs_zero_padding_bit_count = 1; let effective_cleartext_t = t >> msbs_zero_padding_bit_count; let delta = { let q = if q == 0 { 1i128 << 64 } else { q as i128 }; // delta takes the encoding with the padding bit + (q / t as i128) as u64 }; @@ -2476,15 +2702,17 @@ mod tests { type Curve = crate::curve_api::Bls12_446; - let serialize_then_deserialize = - |public_param: &PublicParams, - compress: Compress| - -> Result, SerializationError> { - let mut data = Vec::new(); - public_param.serialize_with_mode(&mut data, compress)?; - - PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) - }; + let serialize_then_deserialize = |public_param: &PublicParams, + compress: bool| + -> bincode::Result> { + match compress { + true => PublicParams::uncompress(bincode::deserialize(&bincode::serialize( + &public_param.clone().compress(), + )?)?) + .map_err(|e| Box::new(ErrorKind::Custom(format!("Failed to uncompress: {}", e)))), + false => bincode::deserialize(&bincode::serialize(&public_param)?), + } + }; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2492,9 +2720,9 @@ mod tests { let original_public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); let public_param_that_was_compressed = - serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + serialize_then_deserialize(&original_public_param, true).unwrap(); let public_param_that_was_not_compressed = - serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + serialize_then_deserialize(&original_public_param, false).unwrap(); for public_param in [ original_public_param, @@ -2527,4 +2755,156 @@ mod tests { } } } + + #[test] + fn test_proof_compression() { + let d = 2048; + let k = 320; + let B = 1048576; + let q = 0; + let t = 1024; + + let msbs_zero_padding_bit_count = 1; + let effective_cleartext_t = t >> msbs_zero_padding_bit_count; + + let delta = { + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + (q / t as i128) as u64 + }; + + let rng = &mut StdRng::seed_from_u64(0); + + let polymul_rev = |a: &[i64], b: &[i64]| -> Vec { + assert_eq!(a.len(), b.len()); + let d = a.len(); + let mut c = vec![0i64; d]; + + for i in 0..d { + for j in 0..d { + if i + j < d { + c[i + j] = c[i + j].wrapping_add(a[i].wrapping_mul(b[d - j - 1])); + } else { + c[i + j - d] = c[i + j - d].wrapping_sub(a[i].wrapping_mul(b[d - j - 1])); + } + } + } + + c + }; + + let a = (0..d).map(|_| rng.gen::()).collect::>(); + let s = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + let e = (0..d) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + let e1 = (0..d) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + let e2 = (0..k) + .map(|_| (rng.gen::() % (2 * B)) as i64 - B as i64) + .collect::>(); + + let r = (0..d) + .map(|_| (rng.gen::() % 2) as i64) + .collect::>(); + + let m = (0..k) + .map(|_| (rng.gen::() % effective_cleartext_t) as i64) + .collect::>(); + + let b = polymul_rev(&a, &s) + .into_iter() + .zip(e.iter()) + .map(|(x, e)| x.wrapping_add(*e)) + .collect::>(); + let c1 = polymul_rev(&a, &r) + .into_iter() + .zip(e1.iter()) + .map(|(x, e1)| x.wrapping_add(*e1)) + .collect::>(); + + let mut c2 = vec![0i64; k]; + + for i in 0..k { + let mut dot = 0i64; + for j in 0..d { + let b = if i + j < d { + b[d - j - i - 1] + } else { + b[2 * d - j - i - 1].wrapping_neg() + }; + + dot = dot.wrapping_add(r[d - j - 1].wrapping_mul(b)); + } + + c2[i] = dot + .wrapping_add(e2[i]) + .wrapping_add((delta * m[i] as u64) as i64); + } + + // One of our usecases uses 320 bits of additional metadata + const METADATA_LEN: usize = (320 / u8::BITS) as usize; + + let mut metadata = [0u8; METADATA_LEN]; + metadata.fill_with(|| rng.gen::()); + + let mut m_roundtrip = vec![0i64; k]; + for i in 0..k { + let mut dot = 0i128; + for j in 0..d { + let c = if i + j < d { + c1[d - j - i - 1] + } else { + c1[2 * d - j - i - 1].wrapping_neg() + }; + + dot += s[d - j - 1] as i128 * c as i128; + } + + let q = if q == 0 { 1i128 << 64 } else { q as i128 }; + let val = ((c2[i] as i128).wrapping_sub(dot)) * t as i128; + let div = val.div_euclid(q); + let rem = val.rem_euclid(q); + let result = div as i64 + (rem > (q / 2)) as i64; + let result = result.rem_euclid(effective_cleartext_t as i64); + m_roundtrip[i] = result; + } + + type Curve = crate::curve_api::Bls12_446; + + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + a.clone(), + b.clone(), + c1.clone(), + c2.clone(), + r.clone(), + e1.clone(), + m.clone(), + e2.clone(), + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = prove( + (&public_param, &public_commit), + &private_commit, + &metadata, + load, + rng, + ); + + let compressed_proof = bincode::serialize(&proof.clone().compress()).unwrap(); + let proof = + Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap(); + + verify(&proof, (&public_param, &public_commit), &metadata).unwrap() + } + } } diff --git a/tfhe-zk-pok/src/serialization.rs b/tfhe-zk-pok/src/serialization.rs index 3a14fc381c..9699bd0b2f 100644 --- a/tfhe-zk-pok/src/serialization.rs +++ b/tfhe-zk-pok/src/serialization.rs @@ -42,7 +42,9 @@ impl Error for InvalidArraySizeError {} /// Tries to convert a Vec into a constant size array, and returns an [`InvalidArraySizeError`] if /// the size does not match -fn try_vec_to_array(vec: Vec) -> Result<[T; N], InvalidArraySizeError> { +pub(crate) fn try_vec_to_array( + vec: Vec, +) -> Result<[T; N], InvalidArraySizeError> { let len = vec.len(); vec.try_into().map_err(|_| InvalidArraySizeError { @@ -75,24 +77,6 @@ impl, const N: usize> TryFrom for Fp { } } -#[derive(Debug)] -pub struct InvalidSerializedFpError { - expected_len: usize, - found_len: usize, -} - -impl Display for InvalidSerializedFpError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Invalid serialized FP: found array of size {}, expected {}", - self.found_len, self.expected_len - ) - } -} - -impl Error for InvalidSerializedFpError {} - #[derive(Debug)] pub enum InvalidSerializedAffineError { InvalidFp(InvalidArraySizeError), @@ -421,16 +405,16 @@ pub struct SerializablePKEv2PublicParams { pub t: u64, pub msbs_zero_padding_bit_count: u64, // We use Vec since serde does not support fixed size arrays of 256 elements - hash: Vec, - hash_R: Vec, - hash_t: Vec, - hash_w: Vec, - hash_agg: Vec, - hash_lmap: Vec, - hash_phi: Vec, - hash_xi: Vec, - hash_z: Vec, - hash_chi: Vec, + pub(crate) hash: Vec, + pub(crate) hash_R: Vec, + pub(crate) hash_t: Vec, + pub(crate) hash_w: Vec, + pub(crate) hash_agg: Vec, + pub(crate) hash_lmap: Vec, + pub(crate) hash_phi: Vec, + pub(crate) hash_xi: Vec, + pub(crate) hash_z: Vec, + pub(crate) hash_chi: Vec, } impl From> for SerializablePKEv2PublicParams diff --git a/tfhe/src/c_api/high_level_api/zk.rs b/tfhe/src/c_api/high_level_api/zk.rs index 2c81137b94..3bccb128f3 100644 --- a/tfhe/src/c_api/high_level_api/zk.rs +++ b/tfhe/src/c_api/high_level_api/zk.rs @@ -1,7 +1,7 @@ use super::utils::*; use crate::c_api::high_level_api::config::Config; use crate::c_api::utils::get_ref_checked; -use crate::zk::{Compress, Validate}; +use crate::zk::Compressible; use std::ffi::c_int; #[repr(C)] @@ -41,16 +41,11 @@ pub unsafe extern "C" fn compact_pke_public_params_serialize( let wrapper = crate::c_api::utils::get_ref_checked(sself).unwrap(); - let compress = if compress { - Compress::Yes + let buffer = if compress { + bincode::serialize(&wrapper.0.compress()).unwrap() } else { - Compress::No + bincode::serialize(&wrapper.0).unwrap() }; - let mut buffer = vec![]; - wrapper - .0 - .serialize_with_mode(&mut buffer, compress) - .unwrap(); *result = buffer.into(); }) @@ -62,8 +57,6 @@ pub unsafe extern "C" fn compact_pke_public_params_serialize( #[no_mangle] pub unsafe extern "C" fn compact_pke_public_params_deserialize( buffer_view: crate::c_api::buffer::DynamicBufferView, - is_compressed: bool, - validate: bool, result: *mut *mut CompactPkePublicParams, ) -> ::std::os::raw::c_int { crate::c_api::utils::catch_panic(|| { @@ -71,20 +64,7 @@ pub unsafe extern "C" fn compact_pke_public_params_deserialize( *result = std::ptr::null_mut(); - let deserialized = crate::zk::CompactPkePublicParams::deserialize_with_mode( - buffer_view.as_slice(), - if is_compressed { - Compress::Yes - } else { - Compress::No - }, - if validate { - Validate::Yes - } else { - Validate::No - }, - ) - .unwrap(); + let deserialized = bincode::deserialize(buffer_view.as_slice()).unwrap(); let heap_allocated_object = Box::new(CompactPkePublicParams(deserialized)); diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs index db896e3d93..eb7b01a42e 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs @@ -3,7 +3,9 @@ use wasm_bindgen::prelude::*; use crate::js_on_wasm_api::js_high_level_api::config::TfheConfig; use crate::js_on_wasm_api::js_high_level_api::{catch_panic_result, into_js_error}; use crate::js_on_wasm_api::shortint::ShortintParameters; -use tfhe_zk_pok::{Compress, Validate}; + +use crate::zk::Compressible; + #[derive(Copy, Clone, Eq, PartialEq)] #[wasm_bindgen] pub enum ZkComputeLoad { @@ -33,43 +35,23 @@ impl CompactPkePublicParams { #[wasm_bindgen] pub fn serialize(&self, compress: bool) -> Result, JsError> { catch_panic_result(|| { - let mut data = vec![]; - self.0 - .serialize_with_mode( - &mut data, - if compress { - Compress::Yes - } else { - Compress::No - }, - ) - .map_err(into_js_error)?; - Ok(data) + let data = if compress { + bincode::serialize(&self.0.compress()) + } else { + bincode::serialize(&self.0) + }; + data.map_err(into_js_error) }) } #[wasm_bindgen] - pub fn deserialize( - buffer: &[u8], - is_compressed: bool, - validate: bool, - ) -> Result { + pub fn deserialize(buffer: &[u8]) -> Result { + // If buffer is compressed it is automatically detected and uncompressed. + // TODO: handle validation catch_panic_result(|| { - crate::zk::CompactPkePublicParams::deserialize_with_mode( - buffer, - if is_compressed { - Compress::Yes - } else { - Compress::No - }, - if validate { - Validate::Yes - } else { - Validate::No - }, - ) - .map(CompactPkePublicParams) - .map_err(into_js_error) + bincode::deserialize(buffer) + .map(CompactPkePublicParams) + .map_err(into_js_error) }) } } diff --git a/tfhe/src/zk.rs b/tfhe/src/zk.rs index 0d2422ddae..5c68cb1d98 100644 --- a/tfhe/src/zk.rs +++ b/tfhe/src/zk.rs @@ -6,6 +6,7 @@ use std::collections::Bound; use std::fmt::Debug; use tfhe_zk_pok::proofs::pke::crs_gen; +pub use tfhe_zk_pok::curve_api::Compressible; pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad; pub use tfhe_zk_pok::{Compress, Validate}; type Curve = tfhe_zk_pok::curve_api::Bls12_446;