diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh index 16d4a119fa..4ce3af34d5 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh @@ -62,7 +62,7 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index, // Last GLWE auto last_body_count = num_lwes % compression_params.polynomial_size; - in_len = + auto last_in_len = compression_params.glwe_dimension * compression_params.polynomial_size + last_body_count; number_bits_to_pack = in_len * log_modulus; @@ -75,10 +75,6 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index, dim3 grid(num_blocks); dim3 threads(num_threads); - cuda_memset_async(array_out, 0, - num_glwes * (compression_params.glwe_dimension + 1) * - compression_params.polynomial_size * sizeof(Torus), - stream, gpu_index); pack<<>>(array_out, array_in, log_modulus, num_coeffs, in_len, out_len); check_cuda_error(cudaGetLastError()); diff --git a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs index 963af08045..db2e7269c4 100644 --- a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs +++ b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs @@ -77,7 +77,7 @@ use crate::core_crypto::prelude::*; /// ); /// } /// ``` -#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedModulusSwitchedGlweCiphertextVersions)] pub struct CompressedModulusSwitchedGlweCiphertext { pub(crate) packed_integers: PackedIntegers, diff --git a/tfhe/src/core_crypto/entities/packed_integers.rs b/tfhe/src/core_crypto/entities/packed_integers.rs index 9f5e3ba2f5..0df76d79f3 100644 --- a/tfhe/src/core_crypto/entities/packed_integers.rs +++ b/tfhe/src/core_crypto/entities/packed_integers.rs @@ -4,7 +4,7 @@ use crate::conformance::ParameterSetConformant; use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions; use crate::core_crypto::prelude::*; -#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(PackedIntegersVersions)] pub struct PackedIntegers { pub(crate) packed_coeffs: Vec, diff --git a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs index 68811ef639..76bf0caf6d 100644 --- a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs @@ -95,7 +95,7 @@ impl CompressedCiphertextListBuilder { } } -#[derive(Clone, Serialize, Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Versionize)] #[versionize(CompressedCiphertextListVersions)] pub struct CompressedCiphertextList { pub(crate) packed_list: ShortintCompressedCiphertextList, @@ -153,46 +153,188 @@ impl CompressedCiphertextList { #[cfg(test)] mod tests { use super::*; - use crate::integer::ClientKey; + use crate::integer::{gen_keys_radix, ClientKey}; use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + use itertools::Itertools; + use rand::Rng; + const NB_TESTS: usize = 10; + const NB_OPERATOR_TESTS: usize = 10; #[test] - fn test_heterogeneous_ciphertext_compression_ci_run_filter() { + fn test_ciphertext_compression() { + const NUM_BLOCKS: usize = 32; + let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); + let (_, radix_sks) = + gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, NUM_BLOCKS); + let private_compression_key = cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); let (compression_key, decompression_key) = cks.new_compression_decompression_keys(&private_compression_key); - let ct1 = cks.encrypt_radix(3_u32, 16); - - let ct2 = cks.encrypt_signed_radix(-2, 16); - - let ct3 = cks.encrypt_bool(true); - - let compressed = CompressedCiphertextListBuilder::new() - .push(ct1) - .push(ct2) - .push(ct3) - .build(&compression_key); - - let decompressed1 = compressed.get(0, &decompression_key).unwrap().unwrap(); - - let decrypted: u32 = cks.decrypt_radix(&decompressed1); - - assert_eq!(decrypted, 3_u32); - - let decompressed2 = compressed.get(1, &decompression_key).unwrap().unwrap(); - - let decrypted2: i32 = cks.decrypt_signed_radix(&decompressed2); - - assert_eq!(decrypted2, -2); - - let decompressed3 = compressed.get(2, &decompression_key).unwrap().unwrap(); - - assert!(cks.decrypt_bool(&decompressed3)); + const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64 + .packing_ks_polynomial_size + .0 + / NUM_BLOCKS; + + let mut rng = rand::thread_rng(); + + let message_modulus: u128 = cks.parameters().message_modulus().0 as u128; + + for _ in 0..NB_TESTS { + // Unsigned + let modulus = message_modulus.pow(NUM_BLOCKS as u32); + for _ in 0..NB_OPERATOR_TESTS { + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); + let messages = (0..nb_messages) + .map(|_| rng.gen::() % modulus) + .collect::>(); + + let cts = messages + .iter() + .map(|message| cks.encrypt_radix(*message, NUM_BLOCKS)) + .collect_vec(); + + let mut builder = CompressedCiphertextListBuilder::new(); + + for ct in cts { + let and_ct = radix_sks.bitand_parallelized(&ct, &ct); + builder.push(and_ct); + } + + let compressed = builder.build(&compression_key); + + for (i, message) in messages.iter().enumerate() { + let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: u128 = cks.decrypt_radix(&decompressed); + assert_eq!(decrypted, *message); + } + } + + // Signed + let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128; + for _ in 0..NB_OPERATOR_TESTS { + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); + let messages = (0..nb_messages) + .map(|_| rng.gen::() % modulus) + .collect::>(); + + let cts = messages + .iter() + .map(|message| cks.encrypt_signed_radix(*message, NUM_BLOCKS)) + .collect_vec(); + + let mut builder = CompressedCiphertextListBuilder::new(); + + for ct in cts { + builder.push(ct); + } + + let compressed = builder.build(&compression_key); + + for (i, message) in messages.iter().enumerate() { + let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: i128 = cks.decrypt_signed_radix(&decompressed); + assert_eq!(decrypted, *message); + } + } + + // Boolean + for _ in 0..NB_OPERATOR_TESTS { + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); + let messages = (0..nb_messages) + .map(|_| rng.gen::() % 2 != 0) + .collect::>(); + + let cts = messages + .iter() + .map(|message| cks.encrypt_bool(*message)) + .collect_vec(); + + let mut builder = CompressedCiphertextListBuilder::new(); + + for ct in cts { + builder.push(ct); + } + + let cuda_compressed = builder.build(&compression_key); + + for (i, message) in messages.iter().enumerate() { + let decompressed = cuda_compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted = cks.decrypt_bool(&decompressed); + assert_eq!(decrypted, *message); + } + } + + // Hybrid + enum MessageType { + Unsigned(u128), + Signed(i128), + Boolean(bool), + } + for _ in 0..NB_OPERATOR_TESTS { + let mut builder = CompressedCiphertextListBuilder::new(); + + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); + let mut messages = vec![]; + for _ in 0..nb_messages { + let case_selector = rng.gen_range(0..3); + match case_selector { + 0 => { + // Unsigned + let modulus = message_modulus.pow(NUM_BLOCKS as u32); + let message = rng.gen::() % modulus; + let ct = cks.encrypt_radix(message, NUM_BLOCKS); + builder.push(ct); + messages.push(MessageType::Unsigned(message)); + } + 1 => { + // Signed + let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128; + let message = rng.gen::() % modulus; + let ct = cks.encrypt_signed_radix(message, NUM_BLOCKS); + builder.push(ct); + messages.push(MessageType::Signed(message)); + } + _ => { + // Boolean + let message = rng.gen::() % 2 != 0; + let ct = cks.encrypt_bool(message); + builder.push(ct); + messages.push(MessageType::Boolean(message)); + } + } + } + + let compressed = builder.build(&compression_key); + + for (i, val) in messages.iter().enumerate() { + match val { + MessageType::Unsigned(message) => { + let decompressed = + compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: u128 = cks.decrypt_radix(&decompressed); + assert_eq!(decrypted, *message); + } + MessageType::Signed(message) => { + let decompressed = + compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: i128 = cks.decrypt_signed_radix(&decompressed); + assert_eq!(decrypted, *message); + } + MessageType::Boolean(message) => { + let decompressed = + compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted = cks.decrypt_bool(&decompressed); + assert_eq!(decrypted, *message); + } + } + } + } + } } } diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index eb289218f8..130d511742 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -261,17 +261,17 @@ impl CudaCompressedCiphertextList { } impl CompressedCiphertextList { - /// ```rust - /// use tfhe::core_crypto::gpu::CudaStreams; - /// use tfhe::integer::ciphertext::CompressedCiphertextListBuilder; - /// use tfhe::integer::ClientKey; - /// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext}; - /// use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; - /// use tfhe::integer::gpu::gen_keys_radix_gpu; - /// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; - /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + ///```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::ciphertext::CompressedCiphertextListBuilder; + /// use tfhe::integer::ClientKey; + /// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext}; + /// use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; + /// use tfhe::integer::gpu::gen_keys_radix_gpu; + /// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; /// - /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); /// /// let private_compression_key = /// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); @@ -287,46 +287,49 @@ impl CompressedCiphertextList { /// let (compressed_compression_key, compressed_decompression_key) = /// radix_cks.new_compressed_compression_decompression_keys(&private_compression_key); /// - /// let cuda_decompression_key = - /// compressed_decompression_key.decompress_to_cuda( - /// radix_cks.parameters().glwe_dimension(), - /// radix_cks.parameters().polynomial_size(), - /// radix_cks.parameters().message_modulus(), - /// radix_cks.parameters().carry_modulus(), - /// radix_cks.parameters().ciphertext_modulus(), - /// &streams); + /// let cuda_decompression_key = compressed_decompression_key.decompress_to_cuda( + /// radix_cks.parameters().glwe_dimension(), + /// radix_cks.parameters().polynomial_size(), + /// radix_cks.parameters().message_modulus(), + /// radix_cks.parameters().carry_modulus(), + /// radix_cks.parameters().ciphertext_modulus(), + /// &streams + /// ); /// /// let compression_key = compressed_compression_key.decompress(); /// - /// let ct1 = radix_cks.encrypt(3_u32); - /// let ct2 = radix_cks.encrypt_signed(-2); - /// let ct3 = radix_cks.encrypt_bool(true); + /// let ct1 = radix_cks.encrypt(3_u32); + /// let ct2 = radix_cks.encrypt_signed(-2); + /// let ct3 = radix_cks.encrypt_bool(true); /// - /// let compressed = CompressedCiphertextListBuilder::new() - /// .push(ct1) - /// .push(ct2) - /// .push(ct3) - /// .build(&compression_key); + /// let compressed = CompressedCiphertextListBuilder::new() + /// .push(ct1) + /// .push(ct2) + /// .push(ct3) + /// .build(&compression_key); + /// + /// let cuda_compressed = compressed.to_cuda_compressed_ciphertext_list(&streams); + /// let recovered_cuda_compressed = cuda_compressed.to_compressed_ciphertext_list(&streams); /// - /// let cuda_compressed = compressed.to_cuda_compressed_ciphertext_list(&streams); + /// assert_eq!(recovered_cuda_compressed, compressed); /// - /// let d_decompressed1: CudaUnsignedRadixCiphertext = - /// cuda_compressed.get(0, &cuda_decompression_key, &streams).unwrap().unwrap(); - /// let decompressed1 = d_decompressed1.to_radix_ciphertext(&streams); - /// let decrypted: u32 = radix_cks.decrypt(&decompressed1); - /// assert_eq!(decrypted, 3_u32); + /// let d_decompressed1: CudaUnsignedRadixCiphertext = + /// cuda_compressed.get(0, &cuda_decompression_key, &streams).unwrap().unwrap(); + /// let decompressed1 = d_decompressed1.to_radix_ciphertext(&streams); + /// let decrypted: u32 = radix_cks.decrypt(&decompressed1); + /// assert_eq!(decrypted, 3_u32); /// - /// let d_decompressed2: CudaSignedRadixCiphertext = - /// cuda_compressed.get(1, &cuda_decompression_key, &streams).unwrap().unwrap(); - /// let decompressed2 = d_decompressed2.to_signed_radix_ciphertext(&streams); - /// let decrypted: i32 = radix_cks.decrypt_signed(&decompressed2); - /// assert_eq!(decrypted, -2); + /// let d_decompressed2: CudaSignedRadixCiphertext = + /// cuda_compressed.get(1, &cuda_decompression_key, &streams).unwrap().unwrap(); + /// let decompressed2 = d_decompressed2.to_signed_radix_ciphertext(&streams); + /// let decrypted: i32 = radix_cks.decrypt_signed(&decompressed2); + /// assert_eq!(decrypted, -2); /// - /// let d_decompressed3: CudaBooleanBlock = - /// cuda_compressed.get(2, &cuda_decompression_key, &streams).unwrap().unwrap(); - /// let decompressed3 = d_decompressed3.to_boolean_block(&streams); - /// let decrypted = radix_cks.decrypt_bool(&decompressed3); - /// assert!(decrypted); + /// let d_decompressed3: CudaBooleanBlock = + /// cuda_compressed.get(2, &cuda_decompression_key, &streams).unwrap().unwrap(); + /// let decompressed3 = d_decompressed3.to_boolean_block(&streams); + /// let decrypted = radix_cks.decrypt_bool(&decompressed3); + /// assert!(decrypted); /// ``` pub fn to_cuda_compressed_ciphertext_list( &self, @@ -523,6 +526,7 @@ mod tests { #[test] fn test_gpu_ciphertext_compression() { + const NUM_BLOCKS: usize = 32; let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); let private_compression_key = @@ -530,24 +534,28 @@ mod tests { let streams = CudaStreams::new_multi_gpu(); - let num_blocks = 32; - let (radix_cks, _) = gen_keys_radix_gpu( + let (radix_cks, radix_sks) = gen_keys_radix_gpu( PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - num_blocks, + NUM_BLOCKS, &streams, ); let (cuda_compression_key, cuda_decompression_key) = radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams); + const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64 + .packing_ks_polynomial_size + .0 + / NUM_BLOCKS; + let mut rng = rand::thread_rng(); let message_modulus: u128 = cks.parameters().message_modulus().0 as u128; for _ in 0..NB_TESTS { // Unsigned - let modulus = message_modulus.pow(num_blocks as u32); + let modulus = message_modulus.pow(NUM_BLOCKS as u32); for _ in 0..NB_OPERATOR_TESTS { - let nb_messages = 1 + (rng.gen::() % 6); + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); let messages = (0..nb_messages) .map(|_| rng.gen::() % modulus) .collect::>(); @@ -563,7 +571,8 @@ mod tests { let mut builder = CudaCompressedCiphertextListBuilder::new(); for d_ct in d_cts { - builder.push(d_ct, &streams); + let d_and_ct = radix_sks.bitand(&d_ct, &d_ct, &streams); + builder.push(d_and_ct, &streams); } let cuda_compressed = builder.build(&cuda_compression_key, &streams); @@ -580,9 +589,9 @@ mod tests { } // Signed - let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128; + let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128; for _ in 0..NB_OPERATOR_TESTS { - let nb_messages = 1 + (rng.gen::() % 6); + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); let messages = (0..nb_messages) .map(|_| rng.gen::() % modulus) .collect::>(); @@ -616,7 +625,7 @@ mod tests { // Boolean for _ in 0..NB_OPERATOR_TESTS { - let nb_messages = 1 + (rng.gen::() % 6); + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); let messages = (0..nb_messages) .map(|_| rng.gen::() % 2 != 0) .collect::>(); @@ -657,14 +666,14 @@ mod tests { for _ in 0..NB_OPERATOR_TESTS { let mut builder = CudaCompressedCiphertextListBuilder::new(); - let nb_messages = 1 + (rng.gen::() % 6); + let nb_messages = 1 + (rng.gen::() % MAX_NB_MESSAGES as u64); let mut messages = vec![]; for _ in 0..nb_messages { let case_selector = rng.gen_range(0..3); match case_selector { 0 => { // Unsigned - let modulus = message_modulus.pow(num_blocks as u32); + let modulus = message_modulus.pow(NUM_BLOCKS as u32); let message = rng.gen::() % modulus; let ct = radix_cks.encrypt(message); let d_ct = @@ -674,7 +683,7 @@ mod tests { } 1 => { // Signed - let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128; + let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128; let message = rng.gen::() % modulus; let ct = radix_cks.encrypt_signed(message); let d_ct = CudaSignedRadixCiphertext::from_signed_radix_ciphertext( diff --git a/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs index 304df183e7..7d43593351 100644 --- a/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs @@ -7,7 +7,7 @@ use crate::shortint::backward_compatibility::ciphertext::CompressedCiphertextLis use crate::shortint::parameters::CompressedCiphertextConformanceParams; use crate::shortint::{CarryModulus, MessageModulus}; -#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedCiphertextListVersions)] pub struct CompressedCiphertextList { pub modulus_switched_glwe_ciphertext_list: Vec>,