diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs index af7ef26cc6..972f3c53c1 100644 --- a/tfhe-zk-pok/src/curve_api.rs +++ b/tfhe-zk-pok/src/curve_api.rs @@ -116,6 +116,10 @@ pub trait CurveGroupOps: fn to_le_bytes(self) -> impl AsRef<[u8]>; fn double(self) -> Self; fn normalize(self) -> Self::Affine; + fn validate_projective(&self) -> bool { + Self::validate_affine(&self.normalize()) + } + fn validate_affine(affine: &Self::Affine) -> bool; } /// Mark that an element can be compressed, by storing only the 'x' coordinates of the affine @@ -231,6 +235,10 @@ impl CurveGroupOps for bls12_381::G1 { inner: self.inner.into_affine(), } } + + fn validate_affine(affine: &Self::Affine) -> bool { + affine.validate() + } } impl CurveGroupOps for bls12_381::G2 { @@ -271,6 +279,10 @@ impl CurveGroupOps for bls12_381::G2 { inner: self.inner.into_affine(), } } + + fn validate_affine(affine: &Self::Affine) -> bool { + affine.validate() + } } impl PairingGroupOps for bls12_381::Gt { @@ -368,6 +380,10 @@ impl CurveGroupOps for bls12_446::G1 { inner: self.inner.into_affine(), } } + + fn validate_affine(affine: &Self::Affine) -> bool { + affine.validate() + } } impl CurveGroupOps for bls12_446::G2 { @@ -408,6 +424,10 @@ impl CurveGroupOps for bls12_446::G2 { inner: self.inner.into_affine(), } } + + fn validate_affine(affine: &Self::Affine) -> bool { + affine.validate() + } } impl PairingGroupOps for bls12_446::Gt { diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs index 9a5c979bb1..be9c820607 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_381.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -90,6 +90,10 @@ mod g1 { .unwrap(), } } + + pub fn validate(&self) -> bool { + self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve() + } } #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] @@ -310,6 +314,10 @@ mod g2 { .unwrap(), } } + + pub fn validate(&self) -> bool { + self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve() + } } #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index 53ea960ef0..0b6ffcac71 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -92,6 +92,10 @@ mod g1 { .unwrap(), } } + + pub fn validate(&self) -> bool { + self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve() + } } #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] @@ -316,6 +320,10 @@ mod g2 { } } + pub fn validate(&self) -> bool { + self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve() + } + // m is an intermediate variable that's used in both the curve point addition and pairing // functions. we cache it since it requires a Zp division // https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index 4dc7a62038..a01b112a7c 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -1,4 +1,5 @@ use crate::backward_compatibility::GroupElementsVersions; + use crate::curve_api::{Compressible, Curve, CurveGroupOps, FieldOps, PairingGroupOps}; use crate::serialization::{ InvalidSerializedGroupElementsError, SerializableG1Affine, SerializableG2Affine, @@ -6,6 +7,7 @@ use crate::serialization::{ }; use core::ops::{Index, IndexMut}; use rand::{Rng, RngCore}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use tfhe_versionable::Versionize; #[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, Versionize)] @@ -108,6 +110,16 @@ impl GroupElements { message_len, } } + + /// Check if the elements are valid for their respective groups + pub fn is_valid(&self) -> bool { + let (g_list_valid, g_hat_list_valid) = rayon::join( + || self.g_list.0.par_iter().all(G::G1::validate_affine), + || self.g_hat_list.0.par_iter().all(G::G2::validate_affine), + ); + + g_list_valid && g_hat_list_valid + } } impl Compressible for GroupElements @@ -152,6 +164,8 @@ mod test { #![allow(non_snake_case)] use std::fmt::Display; + use ark_ec::{short_weierstrass, CurveConfig}; + use ark_ff::UniformRand; use bincode::ErrorKind; use rand::rngs::StdRng; use rand::Rng; @@ -359,4 +373,47 @@ mod test { PkeTestCiphertext { c1, c2 } } } + + /// Return a point with coordinates (x, y) that is randomly chosen and not on the curve + pub(super) fn point_not_on_curve( + rng: &mut StdRng, + ) -> short_weierstrass::Affine { + loop { + let fake_x = ::BaseField::rand(rng); + let fake_y = ::BaseField::rand(rng); + + let point = short_weierstrass::Affine::new_unchecked(fake_x, fake_y); + + if !point.is_on_curve() { + return point; + } + } + } + + /// Return a random point on the curve + pub(super) fn point_on_curve( + rng: &mut StdRng, + ) -> short_weierstrass::Affine { + loop { + let x = ::BaseField::rand(rng); + let is_positive = bool::rand(rng); + if let Some(point) = + short_weierstrass::Affine::get_point_from_x_unchecked(x, is_positive) + { + return point; + } + } + } + + /// Return a random point that is on the curve but not in the correct subgroup + pub(super) fn point_on_curve_wrong_subgroup( + rng: &mut StdRng, + ) -> short_weierstrass::Affine { + loop { + let point = point_on_curve(rng); + if !Config::is_in_correct_subgroup_assuming_on_curve(&point) { + return point; + } + } + } } diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index 4ea5456719..4590dc67ac 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -182,6 +182,15 @@ impl PublicParams { pub fn exclusive_max_noise(&self) -> u64 { self.b } + + /// Check if the crs can be used to generate or verify a proof + /// + /// This means checking that the points are: + /// - valid points of the curve + /// - in the correct subgroup + pub fn is_usable(&self) -> bool { + self.g_lists.is_valid() + } } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] @@ -197,6 +206,38 @@ pub struct Proof { pub(crate) compute_load_proof_fields: Option>, } +impl Proof { + /// Check if the proof can be used by the Verifier. + /// + /// This means checking that the points in the proof are: + /// - valid points of the curve + /// - in the correct subgroup + pub fn is_usable(&self) -> bool { + let &Proof { + c_hat, + c_y, + pi, + ref compute_load_proof_fields, + } = self; + + c_hat.validate_projective() + && c_y.validate_projective() + && pi.validate_projective() + && compute_load_proof_fields.as_ref().map_or( + true, + |&ComputeLoadProofFields { + c_hat_t, + c_h, + pi_kzg, + }| { + c_hat_t.validate_projective() + && c_h.validate_projective() + && pi_kzg.validate_projective() + }, + ) + } +} + /// These fields can be pre-computed on the prover side in the faster Verifier scheme. If that's the /// case, they should be included in the proof. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] @@ -1260,6 +1301,8 @@ pub fn verify( #[cfg(test)] mod tests { + use crate::curve_api::{self, bls12_446}; + use super::super::test::*; use super::*; use rand::rngs::StdRng; @@ -1312,7 +1355,7 @@ mod tests { let mut fake_metadata = [255u8; METADATA_LEN]; fake_metadata.fill_with(|| rng.gen::()); - type Curve = crate::curve_api::Bls12_446; + type Curve = curve_api::Bls12_446; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1429,7 +1472,7 @@ mod tests { }; let ct = testcase.encrypt(PKEV1_TEST_PARAMS); - type Curve = crate::curve_api::Bls12_446; + type Curve = curve_api::Bls12_446; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1491,7 +1534,7 @@ mod tests { let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS); let ct = testcase.encrypt(PKEV1_TEST_PARAMS); - type Curve = crate::curve_api::Bls12_446; + type Curve = curve_api::Bls12_446; let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -1526,4 +1569,148 @@ mod tests { verify(&proof, (&public_param, &public_commit), &testcase.metadata).unwrap() } } + + #[test] + fn test_proof_usable() { + let PkeTestParameters { + d, + k, + B, + q, + t, + msbs_zero_padding_bit_count, + } = PKEV1_TEST_PARAMS; + + let rng = &mut StdRng::seed_from_u64(0); + + let testcase = PkeTestcase::gen(rng, PKEV1_TEST_PARAMS); + let ct = testcase.encrypt(PKEV1_TEST_PARAMS); + + type Curve = curve_api::Bls12_446; + + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + testcase.a.clone(), + testcase.b.clone(), + ct.c1.clone(), + ct.c2.clone(), + testcase.r.clone(), + testcase.e1.clone(), + testcase.m.clone(), + testcase.e2.clone(), + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let valid_proof = prove( + (&public_param, &public_commit), + &private_commit, + &testcase.metadata, + load, + rng, + ); + + let compressed_proof = bincode::serialize(&valid_proof.compress()).unwrap(); + let proof_that_was_compressed: Proof = + Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap(); + + assert!(valid_proof.is_usable()); + assert!(proof_that_was_compressed.is_usable()); + + let not_on_curve_g1 = bls12_446::G1::projective(bls12_446::G1Affine { + inner: point_not_on_curve(rng), + }); + + let not_on_curve_g2 = bls12_446::G2::projective(bls12_446::G2Affine { + inner: point_not_on_curve(rng), + }); + + let not_in_group_g1 = bls12_446::G1::projective(bls12_446::G1Affine { + inner: point_on_curve_wrong_subgroup(rng), + }); + + let not_in_group_g2 = bls12_446::G2::projective(bls12_446::G2Affine { + inner: point_on_curve_wrong_subgroup(rng), + }); + + { + let mut proof = valid_proof.clone(); + proof.c_hat = not_on_curve_g2; + assert!(!proof.is_usable()); + proof.c_hat = not_in_group_g2; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.c_y = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.c_y = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.pi = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.pi = not_in_group_g1; + assert!(!proof.is_usable()); + } + + if let Some(ref valid_compute_proof_fields) = valid_proof.compute_load_proof_fields { + { + let mut proof = valid_proof.clone(); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + c_hat_t: not_on_curve_g2, + ..valid_compute_proof_fields.clone() + }); + + assert!(!proof.is_usable()); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + c_hat_t: not_in_group_g2, + ..valid_compute_proof_fields.clone() + }); + + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + c_h: not_on_curve_g1, + ..valid_compute_proof_fields.clone() + }); + + assert!(!proof.is_usable()); + + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + c_h: not_in_group_g1, + ..valid_compute_proof_fields.clone() + }); + + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + pi_kzg: not_on_curve_g1, + ..valid_compute_proof_fields.clone() + }); + + assert!(!proof.is_usable()); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + pi_kzg: not_in_group_g1, + ..valid_compute_proof_fields.clone() + }); + + assert!(!proof.is_usable()); + } + } + } + } } diff --git a/tfhe-zk-pok/src/proofs/pke_v2.rs b/tfhe-zk-pok/src/proofs/pke_v2.rs index d51a804308..f436db9f45 100644 --- a/tfhe-zk-pok/src/proofs/pke_v2.rs +++ b/tfhe-zk-pok/src/proofs/pke_v2.rs @@ -9,6 +9,7 @@ use crate::serialization::{ try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError, SerializableGroupElements, SerializablePKEv2PublicParams, }; + use core::marker::PhantomData; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -221,6 +222,15 @@ impl PublicParams { pub fn exclusive_max_noise(&self) -> u64 { self.B } + + /// Check if the crs can be used to generate or verify a proof + /// + /// This means checking that the points are: + /// - valid points of the curve + /// - in the correct subgroup + pub fn is_usable(&self) -> bool { + self.g_lists.is_valid() + } } /// This represents a proof that the given ciphertext is a valid encryptions of the input messages @@ -247,6 +257,48 @@ pub struct Proof { pub(crate) compute_load_proof_fields: Option>, } +impl Proof { + /// Check if the proof can be used by the Verifier. + /// + /// This means checking that the points in the proof are: + /// - valid points of the curve + /// - in the correct subgroup + pub fn is_usable(&self) -> bool { + let &Proof { + C_hat_e, + C_e, + C_r_tilde, + C_R, + C_hat_bin, + C_y, + C_h1, + C_h2, + C_hat_t, + pi, + pi_kzg, + ref compute_load_proof_fields, + } = self; + + C_hat_e.validate_projective() + && C_e.validate_projective() + && C_r_tilde.validate_projective() + && C_R.validate_projective() + && C_hat_bin.validate_projective() + && C_y.validate_projective() + && C_h1.validate_projective() + && C_h2.validate_projective() + && C_hat_t.validate_projective() + && pi.validate_projective() + && pi_kzg.validate_projective() + && compute_load_proof_fields.as_ref().map_or( + true, + |&ComputeLoadProofFields { C_hat_h3, C_hat_w }| { + C_hat_h3.validate_projective() && C_hat_w.validate_projective() + }, + ) + } +} + /// These fields can be pre-computed on the prover side in the faster Verifier scheme. If that's the /// case, they should be included in the proof. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] @@ -2368,6 +2420,8 @@ pub fn verify( #[cfg(test)] mod tests { + use crate::curve_api::{self, bls12_446}; + use super::super::test::*; use super::*; use rand::rngs::StdRng; @@ -2419,7 +2473,7 @@ mod tests { let mut fake_metadata = [255u8; METADATA_LEN]; fake_metadata.fill_with(|| rng.gen::()); - type Curve = crate::curve_api::Bls12_446; + type Curve = curve_api::Bls12_446; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2536,7 +2590,7 @@ mod tests { let ct = testcase.encrypt(PKEV2_TEST_PARAMS); - type Curve = crate::curve_api::Bls12_446; + type Curve = curve_api::Bls12_446; // To check management of bigger k_max from CRS during test let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2598,7 +2652,7 @@ mod tests { let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); let ct = testcase.encrypt(PKEV2_TEST_PARAMS); - type Curve = crate::curve_api::Bls12_446; + type Curve = curve_api::Bls12_446; let crs_k = k + 1 + (rng.gen::() % (d - k)); @@ -2633,4 +2687,196 @@ mod tests { verify(&proof, (&public_param, &public_commit), &testcase.metadata).unwrap() } } + + #[test] + fn test_proof_usable() { + let PkeTestParameters { + d, + k, + B, + q, + t, + msbs_zero_padding_bit_count, + } = PKEV2_TEST_PARAMS; + + let rng = &mut StdRng::seed_from_u64(0); + + let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS); + let ct = testcase.encrypt(PKEV2_TEST_PARAMS); + + type Curve = curve_api::Bls12_446; + + let crs_k = k + 1 + (rng.gen::() % (d - k)); + + let public_param = crs_gen::(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng); + + let (public_commit, private_commit) = commit( + testcase.a.clone(), + testcase.b.clone(), + ct.c1.clone(), + ct.c2.clone(), + testcase.r.clone(), + testcase.e1.clone(), + testcase.m.clone(), + testcase.e2.clone(), + &public_param, + rng, + ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let valid_proof = prove( + (&public_param, &public_commit), + &private_commit, + &testcase.metadata, + load, + rng, + ); + + let compressed_proof = bincode::serialize(&valid_proof.compress()).unwrap(); + let proof_that_was_compressed: Proof = + Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap(); + + assert!(valid_proof.is_usable()); + assert!(proof_that_was_compressed.is_usable()); + + let not_on_curve_g1 = bls12_446::G1::projective(bls12_446::G1Affine { + inner: point_not_on_curve(rng), + }); + + let not_on_curve_g2 = bls12_446::G2::projective(bls12_446::G2Affine { + inner: point_not_on_curve(rng), + }); + + let not_in_group_g1 = bls12_446::G1::projective(bls12_446::G1Affine { + inner: point_on_curve_wrong_subgroup(rng), + }); + + let not_in_group_g2 = bls12_446::G2::projective(bls12_446::G2Affine { + inner: point_on_curve_wrong_subgroup(rng), + }); + + { + let mut proof = valid_proof.clone(); + proof.C_hat_e = not_on_curve_g2; + assert!(!proof.is_usable()); + proof.C_hat_e = not_in_group_g2; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_e = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.C_e = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_r_tilde = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.C_r_tilde = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_R = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.C_R = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_hat_bin = not_on_curve_g2; + assert!(!proof.is_usable()); + proof.C_hat_bin = not_in_group_g2; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_y = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.C_y = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_h1 = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.C_h1 = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_h2 = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.C_h2 = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.C_hat_t = not_on_curve_g2; + assert!(!proof.is_usable()); + proof.C_hat_t = not_in_group_g2; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.pi = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.pi = not_in_group_g1; + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.pi_kzg = not_on_curve_g1; + assert!(!proof.is_usable()); + proof.pi_kzg = not_in_group_g1; + assert!(!proof.is_usable()); + } + + if let Some(ref valid_compute_proof_fields) = valid_proof.compute_load_proof_fields { + { + let mut proof = valid_proof.clone(); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + C_hat_h3: not_on_curve_g2, + C_hat_w: valid_compute_proof_fields.C_hat_w, + }); + + assert!(!proof.is_usable()); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + C_hat_h3: not_in_group_g2, + C_hat_w: valid_compute_proof_fields.C_hat_w, + }); + + assert!(!proof.is_usable()); + } + + { + let mut proof = valid_proof.clone(); + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + C_hat_h3: valid_compute_proof_fields.C_hat_h3, + C_hat_w: not_on_curve_g2, + }); + + assert!(!proof.is_usable()); + + proof.compute_load_proof_fields = Some(ComputeLoadProofFields { + C_hat_h3: valid_compute_proof_fields.C_hat_h3, + C_hat_w: not_in_group_g2, + }); + + assert!(!proof.is_usable()); + } + } + } + } }