Skip to content

Commit

Permalink
chore(integer): add extensive_trivial tests for sub
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Oct 16, 2024
1 parent f4e74b2 commit d09492d
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 23 deletions.
6 changes: 6 additions & 0 deletions tfhe/src/core_crypto/commons/numeric/unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ pub trait UnsignedInteger:
#[must_use]
fn overflowing_add(self, rhs: Self) -> (Self, bool);
#[must_use]
fn overflowing_sub(self, rhs: Self) -> (Self, bool);
#[must_use]
fn is_power_of_two(self) -> bool;
#[must_use]
fn next_power_of_two(self) -> Self;
Expand Down Expand Up @@ -221,6 +223,10 @@ macro_rules! implement {
self.overflowing_add(rhs)
}
#[inline]
fn overflowing_sub(self, rhs: Self) -> (Self, bool) {
self.overflowing_sub(rhs)
}
#[inline]
fn is_power_of_two(self) -> bool {
self.is_power_of_two()
}
Expand Down
18 changes: 9 additions & 9 deletions tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,19 +719,19 @@ pub(crate) fn signed_neg_under_modulus(lhs: i64, modulus: i64) -> i64 {
// This is to 'simulate' i8, i16, ixy using i64 integers
//
// lhs and rhs must be in [-modulus..modulus[
pub(crate) fn signed_sub_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 {
pub(crate) fn signed_sub_under_modulus<T: SignedInteger>(lhs: T, rhs: T, modulus: T) -> T {
signed_overflowing_sub_under_modulus(lhs, rhs, modulus).0
}

pub(crate) fn signed_overflowing_sub_under_modulus(
lhs: i64,
rhs: i64,
modulus: i64,
) -> (i64, bool) {
pub(crate) fn signed_overflowing_sub_under_modulus<T: SignedInteger>(
lhs: T,
rhs: T,
modulus: T,
) -> (T, bool) {
// Technically we should be able to call overflowing_add_under_modulus(lhs, -rhs, ...)
// but due to -rhs being a 'special case' when rhs == -modulus, we have to
// so the impl here
assert!(modulus > 0);
assert!(modulus > T::ZERO);
assert!((-modulus..modulus).contains(&lhs));

// The code below requires rhs and lhs to be in range -modulus..modulus
Expand All @@ -741,14 +741,14 @@ pub(crate) fn signed_overflowing_sub_under_modulus(
(lhs - rhs, false)
} else {
// 2*modulus to get all the bits
(lhs - (rhs % (2 * modulus)), true)
(lhs - (rhs % (T::TWO * modulus)), true)
};

if res < -modulus {
// rem_euclid(modulus) would also work
res = modulus + (res - -modulus);
overflowed = true;
} else if res > modulus - 1 {
} else if res > modulus - T::ONE {
res = -modulus + (res - modulus);
overflowed = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,10 +451,7 @@ where

let message_modulus = cks.parameters().message_modulus();
let block_num_bits = message_modulus.0.ilog2();
// Contrary to regular add, we do bit_size every block num_bits,
// otherwise the bit_size actually encrypted is not exactly the same
// leading to false test overflow results.
for bit_size in (2..=64u32).step_by(block_num_bits as usize) {
for bit_size in 2..=64u32 {
let num_blocks = bit_size.div_ceil(block_num_bits);
let modulus = (cks.parameters().message_modulus().0 as i128).pow(num_blocks) / 2;

Expand Down
124 changes: 124 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use std::sync::Arc;
create_parametrized_test!(integer_signed_unchecked_sub);
create_parametrized_test!(integer_signed_unchecked_overflowing_sub);
create_parametrized_test!(integer_signed_default_sub);
create_parametrized_test!(integer_extensive_trivial_signed_default_sub);
create_parametrized_test!(integer_signed_default_overflowing_sub);
create_parametrized_test!(integer_extensive_trivial_signed_default_overflowing_sub);

fn integer_signed_unchecked_sub<P>(param: P)
where
Expand All @@ -47,13 +49,30 @@ where
signed_default_sub_test(param, executor);
}

fn integer_extensive_trivial_signed_default_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized);
extensive_trivial_signed_default_sub_test(param, executor);
}

fn integer_extensive_trivial_signed_default_overflowing_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
extensive_trivial_signed_default_overflowing_sub_test(param, executor);
}

fn integer_signed_default_overflowing_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
signed_default_overflowing_sub_test(param, executor);
}

pub(crate) fn signed_default_overflowing_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
Expand Down Expand Up @@ -299,6 +318,63 @@ where
}
}

/// Although this uses the executor pattern and could be plugged in other backends,
/// It is not recommended to do so unless the backend is extremely fast on trivial ciphertexts
/// or extremely extremely fast in general, or if its plugged just as a one time thing.
pub(crate) fn extensive_trivial_signed_default_overflowing_sub_test<P, T>(
param: P,
mut overflowing_sub_executor: T,
) where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
>,
{
let param = param.into();
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));

sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);

let mut rng = rand::thread_rng();

overflowing_sub_executor.setup(&cks, sks.clone());

let message_modulus = cks.parameters().message_modulus();
let block_num_bits = message_modulus.0.ilog2();
for bit_size in 2..=64u32 {
let num_blocks = bit_size.div_ceil(block_num_bits);
let modulus = (cks.parameters().message_modulus().0 as i128).pow(num_blocks) / 2;

for _ in 0..50 {
let clear_0 = rng.gen::<i128>() % modulus;
let clear_1 = rng.gen::<i128>() % modulus;

let ctxt_0 = sks.create_trivial_radix(clear_0, num_blocks as usize);
let ctxt_1 = sks.create_trivial_radix(clear_1, num_blocks as usize);

let (ct_res, ct_overflow) = overflowing_sub_executor.execute((&ctxt_0, &ctxt_1));
let dec_res: i128 = cks.decrypt_signed(&ct_res);
let dec_overflow = cks.decrypt_bool(&ct_overflow);

let (expected_clear, expected_overflow) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
assert_eq!(
expected_clear, dec_res,
"Invalid result for {clear_0} + {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
num_blocks={num_blocks}, modulus={modulus}"
);
assert_eq!(
expected_overflow, dec_overflow,
"Invalid overflow result for {clear_0} + {clear_1}, expected: {expected_overflow}, got: {dec_overflow}\n\
num_blocks={num_blocks}, modulus={modulus}"
);
}
}
}

pub(crate) fn signed_unchecked_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
Expand Down Expand Up @@ -400,3 +476,51 @@ where
}
}
}

/// Although this uses the executor pattern and could be plugged in other backends,
/// It is not recommended to do so unless the backend is extremely fast on trivial ciphertexts
/// or extremely extremely fast in general, or if its plugged just as a one time thing.
pub(crate) fn extensive_trivial_signed_default_sub_test<P, T>(param: P, mut sub_executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
SignedRadixCiphertext,
>,
{
let param = param.into();
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));

sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);

let mut rng = rand::thread_rng();

sub_executor.setup(&cks, sks.clone());

let message_modulus = cks.parameters().message_modulus();
let block_num_bits = message_modulus.0.ilog2();
for bit_size in 2..=64u32 {
let num_blocks = bit_size.div_ceil(block_num_bits);
let modulus = (cks.parameters().message_modulus().0 as i128).pow(num_blocks) / 2;

for _ in 0..50 {
let clear_0 = rng.gen::<i128>() % modulus;
let clear_1 = rng.gen::<i128>() % modulus;

let ctxt_0 = sks.create_trivial_radix(clear_0, num_blocks as usize);
let ctxt_1 = sks.create_trivial_radix(clear_1, num_blocks as usize);

let ct_res = sub_executor.execute((&ctxt_0, &ctxt_1));
let dec_res: i128 = cks.decrypt_signed(&ct_res);

let expected_clear = signed_sub_under_modulus(clear_0, clear_1, modulus);
assert_eq!(
expected_clear, dec_res,
"Invalid result for {clear_0} - {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
num_blocks={num_blocks}, modulus={modulus}"
);
}
}
}
10 changes: 5 additions & 5 deletions tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ pub(crate) fn rotate_right_helper(value: u64, n: u32, actual_bit_size: u32) -> u
(rotated & mask) | ((rotated & shifted_mask) >> (u64::BITS - actual_bit_size))
}

pub(crate) fn overflowing_sub_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
assert!(
!(modulus.is_power_of_two() && (modulus - 1).overflowing_mul(2).1),
"If modulus is not a power of two, then must not overflow u64"
);
pub(crate) fn overflowing_sub_under_modulus<T: UnsignedInteger>(
lhs: T,
rhs: T,
modulus: T,
) -> (T, bool) {
let (result, overflowed) = lhs.overflowing_sub(rhs);
(result % modulus, overflowed)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,8 @@ where

let message_modulus = cks.parameters().message_modulus();
let block_num_bits = message_modulus.0.ilog2();
// Contrary to regular add, we do bit_size every block num_bits,
// otherwise the bit_size actually encrypted is not exactly the same
// leading to false test overflow results.
for bit_size in (1..=64u32).step_by(block_num_bits as usize) {

for bit_size in 1..=64u32 {
let num_blocks = bit_size.div_ceil(block_num_bits);
let modulus = unsigned_modulus_u128(cks.parameters().message_modulus(), num_blocks);

Expand Down
Loading

0 comments on commit d09492d

Please sign in to comment.