diff --git a/tfhe/src/integer/bigint/algorithms.rs b/tfhe/src/integer/bigint/algorithms.rs index fd2b74c52c..56934b2f1b 100644 --- a/tfhe/src/integer/bigint/algorithms.rs +++ b/tfhe/src/integer/bigint/algorithms.rs @@ -143,20 +143,20 @@ pub(crate) fn bitxor_assign(lhs: &mut [u64], rhs: &[u64]) { } #[inline(always)] -pub(crate) fn add_with_carry(l: T, r: T, c: bool) -> (T, bool) { +pub(crate) fn wrapping_add_with_carry(l: T, r: T, c: bool) -> (T, bool) { let (lr, o0) = l.overflowing_add(r); let (lrc, o1) = lr.overflowing_add(T::cast_from(c)); (lrc, o0 | o1) } -pub(crate) fn add_assign_words(lhs: &mut [T], rhs: &[T]) { +pub(crate) fn wrapping_add_assign_words(lhs: &mut [T], rhs: &[T]) { let iter = lhs .iter_mut() .zip(rhs.iter().copied().chain(std::iter::repeat(T::ZERO))); let mut carry = false; for (lhs_block, rhs_block) in iter { - let (result, out_carry) = add_with_carry(*lhs_block, rhs_block, carry); + let (result, out_carry) = wrapping_add_with_carry(*lhs_block, rhs_block, carry); *lhs_block = result; carry = out_carry; } @@ -188,7 +188,7 @@ pub(crate) fn schoolbook_mul_assign(lhs: &mut [u64], rhs: &[u64]) { let mut result = terms.pop().unwrap(); for term in terms { - add_assign_words(&mut result, &term); + wrapping_add_assign_words(&mut result, &term); } for (lhs_block, result_block) in lhs.iter_mut().zip(result) { diff --git a/tfhe/src/integer/bigint/static_signed.rs b/tfhe/src/integer/bigint/static_signed.rs index 60afbeab7d..e38ee2fd74 100644 --- a/tfhe/src/integer/bigint/static_signed.rs +++ b/tfhe/src/integer/bigint/static_signed.rs @@ -115,7 +115,7 @@ impl std::ops::Add for StaticSignedBigInt { impl std::ops::AddAssign for StaticSignedBigInt { fn add_assign(&mut self, rhs: Self) { - super::algorithms::add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice()); + super::algorithms::wrapping_add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice()); } } diff --git a/tfhe/src/integer/bigint/static_unsigned.rs b/tfhe/src/integer/bigint/static_unsigned.rs index 80000401dd..d9b9652333 100644 --- a/tfhe/src/integer/bigint/static_unsigned.rs +++ b/tfhe/src/integer/bigint/static_unsigned.rs @@ -80,6 +80,16 @@ impl StaticUnsignedBigInt { pub fn ceil_ilog2(self) -> u32 { self.ilog2() + u32::from(!self.is_power_of_two()) } + + pub fn wrapping_sub(mut self, other: Self) -> Self { + let mut negated = !other; + super::algorithms::wrapping_add_assign_words( + negated.0.as_mut_slice(), + Self::from(1u64).0.as_slice(), + ); + super::algorithms::wrapping_add_assign_words(self.0.as_mut_slice(), negated.0.as_slice()); + self + } } #[cfg(test)] @@ -107,7 +117,7 @@ impl std::cmp::PartialOrd for StaticUnsignedBigInt { impl std::ops::AddAssign for StaticUnsignedBigInt { fn add_assign(&mut self, rhs: Self) { - super::algorithms::add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice()); + super::algorithms::wrapping_add_assign_words(self.0.as_mut_slice(), rhs.0.as_slice()); } } diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs index 88cd620b2d..9f702bd4f4 100644 --- a/tfhe/src/integer/server_key/radix/tests.rs +++ b/tfhe/src/integer/server_key/radix/tests.rs @@ -264,13 +264,13 @@ fn integer_smart_add_128_bits(param: ClassicPBSParameters) { // add the two ciphertexts let mut ct_res = sks.smart_add(&mut ctxt_0, &mut ctxt_1); - let mut clear_result = clear_0 + clear_1; + let mut clear_result = clear_0.wrapping_add(clear_1); // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1); //add multiple times to raise the degree for _ in 0..2 { ct_res = sks.smart_add(&mut ct_res, &mut ctxt_0); - clear_result += clear_0; + clear_result = clear_result.wrapping_add(clear_0); let dec_res: u128 = cks.decrypt_radix(&ct_res); // println!("clear = {}, dec_res = {}", clear, dec_res); @@ -629,7 +629,7 @@ fn integer_unchecked_scalar_decomposition_overflow(param: ClassicPBSParameters) let ct_res = sks.unchecked_scalar_add(&ct_0, scalar); let dec_res = cks.decrypt_radix(&ct_res); - assert_eq!((clear_0 + scalar as u128), dec_res); + assert_eq!(clear_0.wrapping_add(scalar as u128), dec_res); // Check subtraction // ----------------- @@ -640,7 +640,7 @@ fn integer_unchecked_scalar_decomposition_overflow(param: ClassicPBSParameters) let ct_res = sks.unchecked_scalar_sub(&ct_0, scalar); let dec_res = cks.decrypt_radix(&ct_res); - assert_eq!((clear_0 - scalar as u128), dec_res); + assert_eq!(clear_0.wrapping_sub(scalar as u128), dec_res); } #[test] @@ -666,7 +666,7 @@ fn integer_smart_scalar_mul_decomposition_overflow() { let ct_res = sks.smart_scalar_mul(&mut ct_0, scalar); let dec_res = cks.decrypt_radix(&ct_res); - assert_eq!((clear_0 * scalar as u128), dec_res); + assert_eq!(clear_0.wrapping_mul(scalar as u128), dec_res); } fn integer_default_overflowing_sub

(param: P) @@ -696,6 +696,9 @@ fn integer_create_trivial_min_max(param: impl Into) { // If num_bits_in_one_block is not a multiple of bit_size, then // the actual number of bits is not the same as bit size (we end up with more) let actual_num_bits = num_blocks * num_bits_in_one_block; + if actual_num_bits >= i128::BITS { + break; + } // Unsigned { diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index d544a80343..e2d87d20ec 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -1110,7 +1110,7 @@ impl ServerKey { } else { // u64::MAX is -1 in two's complement // We apply the modulus including the padding bit - u64::MAX % (1 << (block_modulus + 1)) + u64::MAX % (block_modulus * 2) } })] }; diff --git a/tfhe/src/integer/server_key/radix_parallel/comparison.rs b/tfhe/src/integer/server_key/radix_parallel/comparison.rs index 965f473e0a..762aa6e465 100644 --- a/tfhe/src/integer/server_key/radix_parallel/comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/comparison.rs @@ -544,7 +544,7 @@ impl ServerKey { } else { // u64::MAX is -1 in tow's complement // We apply the modulus including the padding bit - u64::MAX % (1 << (block_modulus + 1)) + u64::MAX % (block_modulus * 2) } })] }; diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs index 6c1a684dcc..1a374a5de7 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs @@ -881,8 +881,8 @@ impl ServerKey { let block_states = { for i in 1..grouping_size { let state_fn = |block| { - let r = (u64::MAX * u64::from(block != 0)) % (packed_modulus * 2); - r << (i - 1) + let r = u64::MAX * u64::from(block != 0); + (r << (i - 1)) % (packed_modulus * 2) }; first_grouping_luts.push(self.key.generate_lookup_table(state_fn)); } @@ -890,8 +890,8 @@ impl ServerKey { let other_block_state_luts = (0..grouping_size) .map(|i| { let state_fn = |block| { - let r = (u64::MAX * u64::from(block != 0)) % (packed_modulus * 2); - r << i + let r = u64::MAX * u64::from(block != 0); + (r << i) % (packed_modulus * 2) }; self.key.generate_lookup_table(state_fn) }) diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_div_mod.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_div_mod.rs index 4d1d3809d1..f3668c0010 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_div_mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_div_mod.rs @@ -42,6 +42,8 @@ pub trait MiniUnsignedInteger: fn ilog2(self) -> u32; fn is_power_of_two(self) -> bool; + + fn wrapping_sub(self, other: Self) -> Self; } impl MiniUnsignedInteger for T @@ -59,6 +61,10 @@ where fn is_power_of_two(self) -> bool { ::is_power_of_two(self) } + + fn wrapping_sub(self, other: Self) -> Self { + ::wrapping_sub(self, other) + } } impl MiniUnsignedInteger for StaticUnsignedBigInt { @@ -73,6 +79,10 @@ impl MiniUnsignedInteger for StaticUnsignedBigInt { fn is_power_of_two(self) -> bool { self.is_power_of_two() } + + fn wrapping_sub(self, other: Self) -> Self { + self.wrapping_sub(other) + } } pub trait Reciprocable: MiniUnsignedInteger { @@ -496,8 +506,9 @@ impl ServerKey { // The subtraction may overflow. // We then cast the result to a signed type. // Overall, this will work fine due to two's complement representation - let cst = chosen_multiplier.multiplier - - (::DoublePrecision::ONE << numerator_bits); + let cst = chosen_multiplier.multiplier.wrapping_sub( + ::DoublePrecision::ONE << numerator_bits, + ); let cst = T::DoublePrecision::cast_from(cst); // MULSH(m - 2^N, n) diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index a9fe430edf..98af540d2f 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -1215,12 +1215,12 @@ where let mut ctxt_0 = cks.encrypt(clear_0); let mut ct_res = executor.execute((&mut ctxt_0, clear_1)); - clear = (clear_0 - clear_1) % modulus; + clear = clear_0.wrapping_sub(clear_1) % modulus; // Sub multiple times to raise the degree for _ in 0..nb_tests_smaller { ct_res = executor.execute((&mut ct_res, clear_1)); - clear = (clear - clear_1) % modulus; + clear = clear.wrapping_sub(clear_1) % modulus; let dec_res: u64 = cks.decrypt(&ct_res); diff --git a/tfhe/src/shortint/server_key/tests/parametrized_test.rs b/tfhe/src/shortint/server_key/tests/parametrized_test.rs index d30d8c8f02..3d366817e0 100644 --- a/tfhe/src/shortint/server_key/tests/parametrized_test.rs +++ b/tfhe/src/shortint/server_key/tests/parametrized_test.rs @@ -1150,7 +1150,7 @@ where let dec_res = cks.decrypt(&ct_res); - assert_eq!((clear - scalar) % message_modulus, dec_res as u8); + assert_eq!(clear.wrapping_sub(scalar) % message_modulus, dec_res as u8); } } @@ -1174,11 +1174,11 @@ where let mut ct_res = sks.smart_scalar_sub(&mut ctxt_0, clear_1); - let mut clear = (clear_0 - clear_1) % modulus; + let mut clear = clear_0.wrapping_sub(clear_1) % modulus; for _ in 0..NB_SUB_TEST_SMART { ct_res = sks.smart_scalar_sub(&mut ct_res, clear_1); - clear = (clear - clear_1) % modulus; + clear = clear.wrapping_sub(clear_1) % modulus; let dec_res = cks.decrypt(&ct_res); @@ -1427,7 +1427,7 @@ where let dec = cks.decrypt(&ct_tmp); - let clear_result = (clear1 - clear2) % modulus; + let clear_result = clear1.wrapping_sub(clear2) % modulus; assert_eq!(clear_result, dec % modulus); } } @@ -1452,10 +1452,10 @@ where let mut ct_res = sks.smart_sub(&mut ct1, &mut ct2); - let mut clear_res = (clear1 - clear2) % modulus; + let mut clear_res = clear1.wrapping_sub(clear2) % modulus; for _ in 0..NB_SUB_TEST_SMART { ct_res = sks.smart_sub(&mut ct_res, &mut ct2); - clear_res = (clear_res - clear2) % modulus; + clear_res = clear_res.wrapping_sub(clear2) % modulus; } let dec_res = cks.decrypt(&ct_res); @@ -1625,7 +1625,7 @@ where let dec_res = cks.decrypt(&res); - let clear_mux = (msg_true - msg_false) * control_bit + msg_false; + let clear_mux = (msg_true.wrapping_sub(msg_false) * control_bit).wrapping_add(msg_false); println!("(msg_true - msg_false) * control_bit + msg_false = {clear_mux}, res = {dec_res}"); assert_eq!(clear_mux, dec_res); }