diff --git a/.github/workflows/aws_tfhe_fast_tests.yml b/.github/workflows/aws_tfhe_fast_tests.yml index 1a05675206..81db5e8ba0 100644 --- a/.github/workflows/aws_tfhe_fast_tests.yml +++ b/.github/workflows/aws_tfhe_fast_tests.yml @@ -248,6 +248,10 @@ jobs: run: | make test_safe_deserialization + - name: Run zk tests + run: | + make test_zk + - name: Slack Notification if: ${{ failure() }} continue-on-error: true diff --git a/Makefile b/Makefile index c2676caeab..80578d1446 100644 --- a/Makefile +++ b/Makefile @@ -754,6 +754,11 @@ test_safe_deserialization: install_rs_build_toolchain install_cargo_nextest RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p $(TFHE_SPEC) -- safe_deserialization:: +.PHONY: test_zk # Run the tests for the zk module of the TFHE-rs crate +test_zk: install_rs_build_toolchain install_cargo_nextest + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ + --features=$(TARGET_ARCH_FEATURE),shortint,zk-pok -p $(TFHE_SPEC) -- zk:: + .PHONY: test_integer # Run all the tests for integer test_integer: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ diff --git a/tfhe/src/shortint/ciphertext/zk.rs b/tfhe/src/shortint/ciphertext/zk.rs index 0fda8303c6..8446742d55 100644 --- a/tfhe/src/shortint/ciphertext/zk.rs +++ b/tfhe/src/shortint/ciphertext/zk.rs @@ -228,7 +228,11 @@ impl ParameterSetConformant for ProvenCompactCiphertextList { let mut remaining_len = *total_expected_lwe_count; - for (compact_ct_list, _proof) in proved_lists { + for (compact_ct_list, proof) in proved_lists { + if !proof.is_conformant(&()) { + return false; + } + if remaining_len == 0 { return false; } diff --git a/tfhe/src/zk.rs b/tfhe/src/zk.rs index 34347c8374..e44b52ebd4 100644 --- a/tfhe/src/zk.rs +++ b/tfhe/src/zk.rs @@ -1,6 +1,8 @@ +use crate::conformance::ParameterSetConformant; use crate::core_crypto::commons::math::random::BoundedDistribution; use crate::core_crypto::prelude::*; use crate::named::Named; +use crate::shortint::parameters::CompactPublicKeyEncryptionParameters; use rand_core::RngCore; use std::cmp::Ordering; use std::collections::Bound; @@ -16,6 +18,14 @@ impl Named for CompactPkeProof { const NAME: &'static str = "zk::CompactPkeProof"; } +impl ParameterSetConformant for CompactPkeProof { + type ParameterSet = (); + + fn is_conformant(&self, _parameter_set: &Self::ParameterSet) -> bool { + self.is_usable() + } +} + pub type CompactPkePublicParams = tfhe_zk_pok::proofs::pke::PublicParams; pub type SerializableCompactPkePublicParams = tfhe_zk_pok::serialization::SerializablePKEv1PublicParams; @@ -24,6 +34,66 @@ impl Named for CompactPkePublicParams { const NAME: &'static str = "zk::CompactPkePublicParams"; } +pub struct CompactPkePublicParamsConformanceParams { + lwe_dim: LweDimension, + max_num_message: usize, + noise_bound: u64, + ciphertext_modulus: u64, + plaintext_modulus: u64, + msbs_zero_padding_bit_count: ZkMSBZeroPaddingBitCount, +} + +impl CompactPkePublicParamsConformanceParams { + pub fn new>( + value: P, + max_num_message: usize, + ) -> Result + where + E: Into, + { + let params: CompactPublicKeyEncryptionParameters = + value.try_into().map_err(|e| e.into())?; + + let mut plaintext_modulus = (params.message_modulus.0 * params.carry_modulus.0) as u64; + // Add 1 bit of modulus for the padding bit + plaintext_modulus *= 2; + + let (lwe_dim, max_num_message, noise_bound, ciphertext_modulus, plaintext_modulus) = + CompactPkeCrs::prepare_crs_parameters( + params.encryption_lwe_dimension, + max_num_message, + params.encryption_noise_distribution, + params.ciphertext_modulus, + plaintext_modulus, + )?; + + Ok(Self { + lwe_dim, + max_num_message, + noise_bound, + ciphertext_modulus, + plaintext_modulus, + // CRS created from shortint params have 1 MSB 0bit + msbs_zero_padding_bit_count: ZkMSBZeroPaddingBitCount(1), + }) + } +} + +impl ParameterSetConformant for CompactPkePublicParams { + type ParameterSet = CompactPkePublicParamsConformanceParams; + + fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { + self.k <= self.d + && self.d == parameter_set.lwe_dim.0 + && self.k == parameter_set.max_num_message + && self.b == parameter_set.noise_bound + && self.q == parameter_set.ciphertext_modulus + && self.t == parameter_set.plaintext_modulus + && self.msbs_zero_padding_bit_count == parameter_set.msbs_zero_padding_bit_count.0 + && self.is_usable() + } +} + // If we call `CompactPkePublicParams::compress` we end up with a // `SerializableCompactPkePublicParams` that should also impl Named to be serializable with // `safe_serialization`. Since the `CompactPkePublicParams` is transformed into a @@ -195,3 +265,44 @@ impl CompactPkeCrs { &self.public_params } } + +#[cfg(all(test, feature = "shortint"))] +mod test { + use super::*; + use crate::shortint::parameters::compact_public_key_only::p_fail_2_minus_64::ks_pbs::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + use crate::shortint::{CarryModulus, MessageModulus}; + + #[test] + fn test_public_params_conformance() { + let params = PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + let mut bad_params = params; + bad_params.carry_modulus = CarryModulus(8); + bad_params.message_modulus = MessageModulus(8); + + let mut rng = rand::thread_rng(); + + let crs = CompactPkeCrs::new( + params.encryption_lwe_dimension, + 4, + params.encryption_noise_distribution, + params.ciphertext_modulus, + (params.message_modulus.0 * params.carry_modulus.0 * 2) as u64, + ZkMSBZeroPaddingBitCount(1), + &mut rng, + ) + .unwrap(); + + let conformance_params = CompactPkePublicParamsConformanceParams::new(params, 4).unwrap(); + + assert!(crs.public_params().is_conformant(&conformance_params)); + + let conformance_params = + CompactPkePublicParamsConformanceParams::new(bad_params, 4).unwrap(); + + assert!(!crs.public_params().is_conformant(&conformance_params)); + + let conformance_params = CompactPkePublicParamsConformanceParams::new(params, 2).unwrap(); + + assert!(!crs.public_params().is_conformant(&conformance_params)); + } +}