Skip to content

Commit

Permalink
chore(integer): brings the CPU and GPU comopression tests into line.
Browse files Browse the repository at this point in the history
- also implements Debug, Eq, PartialEq to CompressedCiphertextList
  • Loading branch information
pdroalves committed Oct 7, 2024
1 parent 256378f commit 478d126
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,

// Last GLWE
auto last_body_count = num_lwes % compression_params.polynomial_size;
in_len =
auto last_in_len =
compression_params.glwe_dimension * compression_params.polynomial_size +
last_body_count;
number_bits_to_pack = in_len * log_modulus;
Expand All @@ -75,10 +75,6 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,

dim3 grid(num_blocks);
dim3 threads(num_threads);
cuda_memset_async(array_out, 0,
num_glwes * (compression_params.glwe_dimension + 1) *
compression_params.polynomial_size * sizeof(Torus),
stream, gpu_index);
pack<Torus><<<grid, threads, 0, stream>>>(array_out, array_in, log_modulus,
num_coeffs, in_len, out_len);
check_cuda_error(cudaGetLastError());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use crate::core_crypto::prelude::*;
/// );
/// }
/// ```
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedModulusSwitchedGlweCiphertextVersions)]
pub struct CompressedModulusSwitchedGlweCiphertext<Scalar: UnsignedInteger> {
pub(crate) packed_integers: PackedIntegers<Scalar>,
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/entities/packed_integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::conformance::ParameterSetConformant;
use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions;
use crate::core_crypto::prelude::*;

#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(PackedIntegersVersions)]
pub struct PackedIntegers<Scalar: UnsignedInteger> {
pub(crate) packed_coeffs: Vec<Scalar>,
Expand Down
202 changes: 172 additions & 30 deletions tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl CompressedCiphertextListBuilder {
}
}

#[derive(Clone, Serialize, Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedCiphertextListVersions)]
pub struct CompressedCiphertextList {
pub(crate) packed_list: ShortintCompressedCiphertextList,
Expand Down Expand Up @@ -153,46 +153,188 @@ impl CompressedCiphertextList {
#[cfg(test)]
mod tests {
use super::*;
use crate::integer::ClientKey;
use crate::integer::{gen_keys_radix, ClientKey};
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use itertools::Itertools;
use rand::Rng;

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;
#[test]
fn test_heterogeneous_ciphertext_compression_ci_run_filter() {
fn test_ciphertext_compression() {
const NUM_BLOCKS: usize = 32;

let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let (_, radix_sks) =
gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, NUM_BLOCKS);

let private_compression_key =
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let (compression_key, decompression_key) =
cks.new_compression_decompression_keys(&private_compression_key);

let ct1 = cks.encrypt_radix(3_u32, 16);

let ct2 = cks.encrypt_signed_radix(-2, 16);

let ct3 = cks.encrypt_bool(true);

let compressed = CompressedCiphertextListBuilder::new()
.push(ct1)
.push(ct2)
.push(ct3)
.build(&compression_key);

let decompressed1 = compressed.get(0, &decompression_key).unwrap().unwrap();

let decrypted: u32 = cks.decrypt_radix(&decompressed1);

assert_eq!(decrypted, 3_u32);

let decompressed2 = compressed.get(1, &decompression_key).unwrap().unwrap();

let decrypted2: i32 = cks.decrypt_signed_radix(&decompressed2);

assert_eq!(decrypted2, -2);

let decompressed3 = compressed.get(2, &decompression_key).unwrap().unwrap();

assert!(cks.decrypt_bool(&decompressed3));
const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
.packing_ks_polynomial_size
.0
/ NUM_BLOCKS;

let mut rng = rand::thread_rng();

let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;

for _ in 0..NB_TESTS {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<u128>() % modulus)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_radix(*message, NUM_BLOCKS))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
let and_ct = radix_sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
}

let compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i128>() % modulus)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_signed_radix(*message, NUM_BLOCKS))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
builder.push(ct);
}

let compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Boolean
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i64>() % 2 != 0)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_bool(*message))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
builder.push(ct);
}

let cuda_compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = cuda_compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Hybrid
enum MessageType {
Unsigned(u128),
Signed(i128),
Boolean(bool),
}
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CompressedCiphertextListBuilder::new();

let nb_messages = 1 + (rng.gen::<u64>() % MAX_NB_MESSAGES as u64);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
let message = rng.gen::<u128>() % modulus;
let ct = cks.encrypt_radix(message, NUM_BLOCKS);
builder.push(ct);
messages.push(MessageType::Unsigned(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = cks.encrypt_signed_radix(message, NUM_BLOCKS);
builder.push(ct);
messages.push(MessageType::Signed(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = cks.encrypt_bool(message);
builder.push(ct);
messages.push(MessageType::Boolean(message));
}
}
}

let compressed = builder.build(&compression_key);

for (i, val) in messages.iter().enumerate() {
match val {
MessageType::Unsigned(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Signed(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Boolean(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}
}
}
Loading

0 comments on commit 478d126

Please sign in to comment.