Skip to content

Commit

Permalink
chore: make more add/sub test use variable num_blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Oct 29, 2024
1 parent df9fd6c commit 4b2c0ae
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 147 deletions.
46 changes: 24 additions & 22 deletions tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,15 +314,16 @@ where
executor.setup(&cks, sks.clone());

for num_blocks in 1..MAX_NB_CTXT {
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if modulus == 1 {
// Basically have one bit the sign bit can't really test
let half_modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if half_modulus <= 1 {
// The half_modulus (i.e modulus without sign bit) is such that the set of values
// is empty
continue;
}

for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let clear_0 = rng.gen::<i64>() % half_modulus;
let clear_1 = rng.gen::<i64>() % half_modulus;

let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);
Expand All @@ -334,34 +335,34 @@ where
assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");

let (expected_result, expected_overflowed) =
signed_overflowing_add_under_modulus(clear_0, clear_1, modulus);
signed_overflowing_add_under_modulus(clear_0, clear_1, half_modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
"Invalid result for add, for ({clear_0} + {clear_1}) % {half_modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_suv for ({clear_0} + {clear_1}) % {modulus} \
"Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {half_modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);

for _ in 0..nb_tests_smaller {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, modulus);
let clear_3 = random_non_zero_value(&mut rng, modulus);
let clear_2 = random_non_zero_value(&mut rng, half_modulus);
let clear_3 = random_non_zero_value(&mut rng, half_modulus);

let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);

let clear_lhs = signed_add_under_modulus(clear_0, clear_2, modulus);
let clear_rhs = signed_add_under_modulus(clear_1, clear_3, modulus);
let clear_lhs = signed_add_under_modulus(clear_0, clear_2, half_modulus);
let clear_rhs = signed_add_under_modulus(clear_1, clear_3, half_modulus);

let d0: i64 = cks.decrypt_signed(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
Expand All @@ -372,19 +373,19 @@ where
assert!(ct_res.block_carries_are_empty());

let (expected_result, expected_overflowed) =
signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus);
signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, half_modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
"Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {half_modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
"Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {half_modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
Expand Down Expand Up @@ -554,15 +555,16 @@ where
let mut clear;

for num_blocks in 1..MAX_NB_CTXT {
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if modulus == 1 {
// Basically have one bit the sign bit can't really test
let half_modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if half_modulus <= 1 {
// The half_modulus (i.e modulus without sign bit) is such that the set of values
// is empty
continue;
}

for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let clear_0 = rng.gen::<i64>() % half_modulus;
let clear_1 = rng.gen::<i64>() % half_modulus;

let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);
Expand All @@ -572,14 +574,14 @@ where
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct);

clear = signed_add_under_modulus(clear_0, clear_1, modulus);
clear = signed_add_under_modulus(clear_0, clear_1, half_modulus);

// println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1);
// add multiple times to raise the degree
for _ in 0..nb_tests_smaller {
ct_res = executor.execute((&ct_res, &ctxt_0));
assert!(ct_res.block_carries_are_empty());
clear = signed_add_under_modulus(clear, clear_0, modulus);
clear = signed_add_under_modulus(clear, clear_0, half_modulus);

let dec_res: i64 = cks.decrypt_signed(&ct_res);

Expand Down
210 changes: 111 additions & 99 deletions tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::integer::server_key::radix_parallel::tests_signed::{
};
use crate::integer::server_key::radix_parallel::tests_unsigned::{
nb_tests_for_params, nb_tests_smaller_for_params, nb_unchecked_tests_for_params,
CpuFunctionExecutor,
CpuFunctionExecutor, MAX_NB_CTXT,
};
use crate::integer::tests::create_parametrized_test;
use crate::integer::{
Expand Down Expand Up @@ -93,111 +93,117 @@ where

let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;

let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_suv for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
for num_blocks in 1..MAX_NB_CTXT {
let half_modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if half_modulus <= 1 {
// The half_modulus (i.e modulus without sign bit) is such that the set of values
// is empty
continue;
}

for _ in 0..nb_tests_smaller {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, modulus);
let clear_3 = random_non_zero_value(&mut rng, modulus);

let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);

let clear_lhs = signed_add_under_modulus(clear_0, clear_2, modulus);
let clear_rhs = signed_add_under_modulus(clear_1, clear_3, modulus);
let clear_0 = rng.gen::<i64>() % half_modulus;
let clear_1 = rng.gen::<i64>() % half_modulus;

let d0: i64 = cks.decrypt_signed(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
let d1: i64 = cks.decrypt_signed(&ctxt_1);
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
signed_overflowing_sub_under_modulus(clear_0, clear_1, half_modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
"Invalid result for sub, for ({clear_0} - {clear_1}) % {half_modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {half_modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);

for _ in 0..nb_tests_smaller {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, half_modulus);
let clear_3 = random_non_zero_value(&mut rng, half_modulus);

let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);

let clear_lhs = signed_add_under_modulus(clear_0, clear_2, half_modulus);
let clear_rhs = signed_add_under_modulus(clear_1, clear_3, half_modulus);

let d0: i64 = cks.decrypt_signed(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
let d1: i64 = cks.decrypt_signed(&ctxt_1);
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, half_modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {half_modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {half_modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
}
}

// Test with trivial inputs, as it was bugged at some point
for _ in 0..4 {
// Reduce maximum value of random number such that at least the last block is a trivial 0
// (This is how the reproducing case was found)
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
// Test with trivial inputs, as it was bugged at some point
for _ in 0..4 {
// Reduce maximum value of random number such that at least the last block is a trivial
// 0 (This is how the reproducing case was found)
let clear_0 = rng.gen::<i64>() % half_modulus;
let clear_1 = rng.gen::<i64>() % half_modulus;

let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, num_blocks);
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, num_blocks);

let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, half_modulus);

let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {half_modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {half_modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
);
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
}
}
}

Expand Down Expand Up @@ -442,37 +448,43 @@ where

let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

executor.setup(&cks, sks);

let mut clear;

for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;

let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);
for num_blocks in 1..MAX_NB_CTXT {
let half_modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if half_modulus <= 1 {
// The half_modulus (i.e modulus without sign bit) is such that the set of values
// is empty
continue;
}

let mut ct_res = executor.execute((&ctxt_0, &ctxt_1));
let tmp_ct = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct);
for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % half_modulus;
let clear_1 = rng.gen::<i64>() % half_modulus;

clear = signed_sub_under_modulus(clear_0, clear_1, modulus);
let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);

// sub multiple times to raise the degree
for _ in 0..nb_tests_smaller {
ct_res = executor.execute((&ct_res, &ctxt_0));
let mut ct_res = executor.execute((&ctxt_0, &ctxt_1));
let tmp_ct = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
clear = signed_sub_under_modulus(clear, clear_0, modulus);
assert_eq!(ct_res, tmp_ct);

clear = signed_sub_under_modulus(clear_0, clear_1, half_modulus);

// sub multiple times to raise the degree
for _ in 0..nb_tests_smaller {
ct_res = executor.execute((&ct_res, &ctxt_0));
assert!(ct_res.block_carries_are_empty());
clear = signed_sub_under_modulus(clear, clear_0, half_modulus);

let dec_res: i64 = cks.decrypt_signed(&ct_res);
let dec_res: i64 = cks.decrypt_signed(&ct_res);

// println!("clear = {}, dec_res = {}", clear, dec_res);
assert_eq!(clear, dec_res);
// println!("clear = {}, dec_res = {}", clear, dec_res);
assert_eq!(clear, dec_res);
}
}
}
}
Expand Down
Loading

0 comments on commit 4b2c0ae

Please sign in to comment.