diff --git a/tfhe/src/integer/gpu/server_key/radix/add.rs b/tfhe/src/integer/gpu/server_key/radix/add.rs index 826f29acc1..b96585852e 100644 --- a/tfhe/src/integer/gpu/server_key/radix/add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/add.rs @@ -11,9 +11,14 @@ use crate::integer::gpu::{ unchecked_signed_overflowing_add_or_sub_radix_kb_assign_async, unchecked_sum_ciphertexts_integer_radix_kb_assign_async, PBSType, }; -use crate::integer::server_key::radix_parallel::sub::SignedOperation; use crate::shortint::ciphertext::NoiseLevel; +#[derive(Copy, Clone, PartialEq, Eq)] +pub(crate) enum SignedOperation { + Addition, + Subtraction, +} + impl CudaServerKey { /// Computes homomorphically an addition between two ciphertexts encrypting integer values. /// diff --git a/tfhe/src/integer/gpu/server_key/radix/sub.rs b/tfhe/src/integer/gpu/server_key/radix/sub.rs index 2a30cc7944..46a1232dc1 100644 --- a/tfhe/src/integer/gpu/server_key/radix/sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/sub.rs @@ -1,3 +1,4 @@ +use super::add::SignedOperation; use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; use crate::core_crypto::gpu::CudaStreams; use crate::core_crypto::prelude::{CiphertextModulus, LweBskGroupingFactor, LweCiphertextCount}; @@ -11,7 +12,6 @@ use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; use crate::integer::gpu::{ unchecked_unsigned_overflowing_sub_integer_radix_kb_assign_async, PBSType, }; -use crate::integer::server_key::radix_parallel::sub::SignedOperation; use crate::shortint::ciphertext::NoiseLevel; impl CudaServerKey { diff --git a/tfhe/src/integer/server_key/radix/add.rs b/tfhe/src/integer/server_key/radix/add.rs index 329df0377c..4e6d322365 100644 --- a/tfhe/src/integer/server_key/radix/add.rs +++ b/tfhe/src/integer/server_key/radix/add.rs @@ -1,5 +1,5 @@ use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::integer::server_key::radix_parallel::sub::SignedOperation; +use crate::integer::server_key::radix_parallel::OutputFlag; use crate::integer::server_key::CheckError; use crate::integer::{BooleanBlock, ServerKey, SignedRadixCiphertext}; use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel}; @@ -267,6 +267,16 @@ impl ServerKey { lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext, ) -> (SignedRadixCiphertext, BooleanBlock) { - self.unchecked_signed_overflowing_add_or_sub(lhs, rhs, SignedOperation::Addition) + let mut result = lhs.clone(); + let overflowed = self + .advanced_add_assign_with_carry_sequential_parallelized( + &mut result.blocks, + &rhs.blocks, + None, + OutputFlag::from_signedness(true), + ) + .expect("overflow flag was requested"); + + (result, overflowed) } } diff --git a/tfhe/src/integer/server_key/radix/scalar_sub.rs b/tfhe/src/integer/server_key/radix/scalar_sub.rs index c844941ecf..0023bca31c 100644 --- a/tfhe/src/integer/server_key/radix/scalar_sub.rs +++ b/tfhe/src/integer/server_key/radix/scalar_sub.rs @@ -65,7 +65,7 @@ impl ServerKey { // - `None` if scalar is zero // - `Some` if scalar is non-zero // - fn create_negated_block_decomposer( + pub(crate) fn create_negated_block_decomposer( &self, scalar: Scalar, ) -> Option> diff --git a/tfhe/src/integer/server_key/radix/sub.rs b/tfhe/src/integer/server_key/radix/sub.rs index 0235bb3a9d..0efc1805e2 100644 --- a/tfhe/src/integer/server_key/radix/sub.rs +++ b/tfhe/src/integer/server_key/radix/sub.rs @@ -1,9 +1,8 @@ use crate::integer::ciphertext::IntegerRadixCiphertext; -use 
crate::integer::server_key::radix_parallel::sub::SignedOperation; +use crate::integer::server_key::radix_parallel::OutputFlag; use crate::integer::server_key::CheckError; use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext}; use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel}; -use crate::shortint::Ciphertext; impl ServerKey { /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. @@ -420,124 +419,23 @@ impl ServerKey { (result, overflowed) } - pub(crate) fn unchecked_signed_overflowing_add_or_sub( - &self, - lhs: &SignedRadixCiphertext, - rhs: &SignedRadixCiphertext, - signed_operation: SignedOperation, - ) -> (SignedRadixCiphertext, BooleanBlock) { - let mut result = lhs.clone(); - - let num_blocks = result.blocks.len(); - if num_blocks == 0 { - return (result, self.create_trivial_boolean_block(false)); - } - - fn block_add_assign_returning_carry( - sks: &ServerKey, - lhs: &mut Ciphertext, - rhs: &Ciphertext, - ) -> Ciphertext { - sks.key.unchecked_add_assign(lhs, rhs); - let (carry, message) = rayon::join( - || sks.key.carry_extract(lhs), - || sks.key.message_extract(lhs), - ); - - *lhs = message; - - carry - } - - // 2_2, 3_3, 4_4 - // If we have at least 2 bits and at least as much carries - if self.key.message_modulus.0 >= 4 && self.key.carry_modulus.0 >= self.key.message_modulus.0 - { - if signed_operation == SignedOperation::Subtraction { - self.unchecked_sub_assign(&mut result, rhs); - } else { - self.unchecked_add_assign(&mut result, rhs); - } - - let mut input_carry = self.key.create_trivial(0); - - // For the first block do the first step of overflow computation in parallel - let (_, last_block_inner_propagation) = rayon::join( - || { - input_carry = - block_add_assign_returning_carry(self, &mut result.blocks[0], &input_carry); - }, - || { - self.generate_last_block_inner_propagation( - &lhs.blocks[num_blocks - 1], - &rhs.blocks[num_blocks - 1], - signed_operation, - ) - }, - ); - - for block in result.blocks[1..num_blocks - 1].iter_mut() { - input_carry = block_add_assign_returning_carry(self, block, &input_carry); - } - - // Treat the last block separately to handle last step of overflow behavior - let output_carry = block_add_assign_returning_carry( - self, - &mut result.blocks[num_blocks - 1], - &input_carry, - ); - let overflowed = self.resolve_signed_overflow( - last_block_inner_propagation, - &BooleanBlock::new_unchecked(input_carry), - &BooleanBlock::new_unchecked(output_carry), - ); - - return (result, overflowed); - } - - // 1_X parameters - // - // Same idea as other algorithms, however since we have 1 bit per block - // we do not have to resolve any inner propagation but it adds one more - // sequential PBS - if self.key.message_modulus.0 == 2 { - if signed_operation == SignedOperation::Subtraction { - self.unchecked_sub_assign(&mut result, rhs); - } else { - self.unchecked_add_assign(&mut result, rhs); - } - - let mut input_carry = self.key.create_trivial(0); - for block in result.blocks[..num_blocks - 1].iter_mut() { - input_carry = block_add_assign_returning_carry(self, block, &input_carry); - } - - let output_carry = block_add_assign_returning_carry( - self, - &mut result.blocks[num_blocks - 1], - &input_carry, - ); - - // Encode the rule - // "Overflow occurred if the carry into the last bit is different than the carry out - // of the last bit" - let overflowed = self.key.not_equal(&output_carry, &input_carry); - return (result, BooleanBlock::new_unchecked(overflowed)); - } - - 
panic!( - "Invalid combo of message modulus ({}) and carry modulus ({}) \n\ - This function requires the message modulus >= 2 and carry modulus >= message_modulus \n\ - I.e. PARAM_MESSAGE_X_CARRY_Y where X >= 1 and Y >= X.", - self.key.message_modulus.0, self.key.carry_modulus.0 - ); - } pub fn unchecked_signed_overflowing_sub( &self, lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext, ) -> (SignedRadixCiphertext, BooleanBlock) { - self.unchecked_signed_overflowing_add_or_sub(lhs, rhs, SignedOperation::Subtraction) + let flipped_rhs = self.bitnot(rhs); + let carry = self.create_trivial_boolean_block(true); + let mut result = lhs.clone(); + let overflowed = self + .advanced_add_assign_with_carry_sequential_parallelized( + &mut result.blocks, + &flipped_rhs.blocks, + Some(&carry), + OutputFlag::from_signedness(true), + ) + .expect("overflow flag was requested"); + (result, overflowed) } pub fn signed_overflowing_sub( diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index bf76103fc5..fd750489ac 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -1,49 +1,46 @@ use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::integer::server_key::radix_parallel::sub::SignedOperation; use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext}; use crate::shortint::ciphertext::Degree; use crate::shortint::Ciphertext; use rayon::prelude::*; -#[repr(u64)] -#[derive(PartialEq, Eq)] -pub(crate) enum OutputCarry { - /// The block does not generate nor propagate a carry - None = 0, - /// The block generates a carry - Generated = 1, - /// The block will propagate a carry if it ever - /// receives one - Propagated = 2, +/// Possible output flag that the advanced_add_assign_with_carry family of +/// functions can compute. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub(crate) enum OutputFlag { + /// Request no flag at all + None, + /// The overflow flag is the flag that tells whether the input carry bit onto the last bit + /// is different from the output carry bit. + /// + /// This is useful to know if a signed addition overflowed (in 2's complement) + Overflow, + /// The carry flag is simply the carry bit that is output by the last pair of blocks + /// in an addition. + /// + /// This is useful to know if an unsigned addition overflowed. + Carry, } -/// Function to create the LUT used in parallel prefix sum -/// to compute carry propagation -/// -/// If msb propagates it take the value of lsb, -/// this means: -/// - if lsb propagates, msb will propagate (but we don't know yet if there will actually be a carry -/// to propagate), -/// - if lsb generates a carry, as msb propagates it, lsb will generate a carry.
Note that this lsb -/// generates might be due to x propagating ('resolved' by an earlier iteration of the loop) -/// - if lsb does not output a carry, msb will have nothing to propagate -/// -/// Otherwise, msb either does not generate, or it does generate, -/// but it means it won't propagate -fn prefix_sum_carry_propagation(msb: u64, lsb: u64) -> u64 { - if msb == OutputCarry::Propagated as u64 { - lsb - } else { - msb +impl OutputFlag { + /// Returns which flag shall be computed in order to get the flag + /// telling the overflow status + pub(crate) const fn from_signedness(is_signed: bool) -> Self { + if is_signed { + Self::Overflow + } else { + Self::Carry + } } } -fn should_hillis_steele_propagation_be_faster(num_blocks: usize, num_threads: usize) -> bool { - // Measures have shown that using a parallelized algorithm degrades - // the latency of a PBS, so we take that into account. - // (This factor is a bit pessimistic). - const PARALLEL_LATENCY_PENALTY: usize = 2; +fn should_parallel_propagation_be_faster( + full_modulus: usize, + num_blocks: usize, + num_threads: usize, +) -> bool { + const PARALLEL_LATENCY_PENALTY: usize = 1; // However that penalty only kicks in when certain level of // parallelism is used let penalty_threshold = num_threads / 2; @@ -59,14 +56,37 @@ fn should_hillis_steele_propagation_be_faster(num_blocks: usize, num_threads: us }; // Estimate the latency of the parallelized algorithm - let mut parallel_expected_latency = 2 * compute_latency_of_one_layer(num_blocks, num_threads); - let max_depth = num_blocks.ceil_ilog2(); - let mut space = 1; - for _ in 0..max_depth { - let num_block_at_iter = num_blocks - space; - let iter_latency = compute_latency_of_one_layer(num_block_at_iter, num_threads); - parallel_expected_latency += iter_latency; - space *= 2; + // One pre-processing layer, one layer to compute what happens in each grouping, + // one final post processing layer + let mut parallel_expected_latency = 3 * compute_latency_of_one_layer(num_blocks, num_threads); + + let grouping_size = full_modulus.ilog2(); + let num_groups = num_blocks.div_ceil(grouping_size as usize); + + let num_carry_to_resolve = num_groups.saturating_sub(1); + + let sequential_depth = (num_carry_to_resolve.saturating_sub(1) as u32) / (grouping_size - 1); + let hillis_steel_depth = if num_carry_to_resolve == 0 { + 0 + } else { + num_carry_to_resolve.ceil_ilog2() + }; + + let parallel_algo_uses_sequential_to_resolve_grouping_carries = + sequential_depth <= hillis_steel_depth; + + if parallel_algo_uses_sequential_to_resolve_grouping_carries { + parallel_expected_latency += sequential_depth as usize + * compute_latency_of_one_layer(grouping_size as usize, num_threads); + } else { + let max_depth = num_blocks.ceil_ilog2(); + let mut space = 1; + for _ in 0..max_depth { + let num_block_at_iter = num_blocks - space; + let iter_latency = compute_latency_of_one_layer(num_block_at_iter, num_threads); + parallel_expected_latency += iter_latency; + space *= 2; + } } // the other algorithm has num_blocks latency @@ -231,14 +251,10 @@ impl ServerKey { } }; - if self.is_eligible_for_parallel_single_carry_propagation(lhs) { - let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, rhs); - } else { - self.unchecked_add_assign(lhs, rhs); - self.full_propagate_parallelized(lhs); - } + self.add_assign_with_carry_parallelized(lhs, rhs, None); } - /// Computes the addition of two unsigned ciphertexts and returns the overflow flag + + /// Computes the addition of two ciphertexts and returns the 
overflow flag /// /// # Example /// @@ -265,23 +281,25 @@ impl ServerKey { /// assert_eq!(dec_result, expected_result); /// assert_eq!(dec_overflowed, expected_overflow); /// ``` - pub fn unsigned_overflowing_add_parallelized( - &self, - ct_left: &RadixCiphertext, - ct_right: &RadixCiphertext, - ) -> (RadixCiphertext, BooleanBlock) { + pub fn overflowing_add_parallelized(&self, ct_left: &T, ct_right: &T) -> (T, BooleanBlock) + where + T: IntegerRadixCiphertext, + { let mut ct_res = ct_left.clone(); - let overflowed = self.unsigned_overflowing_add_assign_parallelized(&mut ct_res, ct_right); + let overflowed = self.overflowing_add_assign_parallelized(&mut ct_res, ct_right); (ct_res, overflowed) } - pub fn unsigned_overflowing_add_assign_parallelized( + pub fn overflowing_add_assign_parallelized( &self, - ct_left: &mut RadixCiphertext, - ct_right: &RadixCiphertext, - ) -> BooleanBlock { - let mut tmp_rhs: RadixCiphertext; - if ct_left.blocks.is_empty() || ct_right.blocks.is_empty() { + ct_left: &mut T, + ct_right: &T, + ) -> BooleanBlock + where + T: IntegerRadixCiphertext, + { + let mut tmp_rhs: T; + if ct_left.blocks().is_empty() || ct_right.blocks().is_empty() { return self.create_trivial_boolean_block(false); } @@ -309,30 +327,50 @@ impl ServerKey { } }; - self.unchecked_add_assign_parallelized(lhs, rhs); - self.unsigned_overflowing_propagate_addition_carry(lhs) + self.overflowing_add_assign_with_carry(lhs, rhs, None) } - /// This function takes a ciphertext resulting from an addition of 2 clean ciphertexts + /// Computes the addition of two unsigned ciphertexts and returns the overflow flag + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg1 = u8::MAX; + /// let msg2 = 1; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// let (ct_res, overflowed) = sks.unsigned_overflowing_add_parallelized(&ct1, &ct2); /// - /// It propagates the carries in-place, making the ciphertext clean and returns - /// the boolean indicating overflow - pub(in crate::integer) fn unsigned_overflowing_propagate_addition_carry( + /// // Decrypt: + /// let dec_result: u8 = cks.decrypt(&ct_res); + /// let dec_overflowed = cks.decrypt_bool(&overflowed); + /// let (expected_result, expected_overflow) = msg1.overflowing_add(msg2); + /// assert_eq!(dec_result, expected_result); + /// assert_eq!(dec_overflowed, expected_overflow); + /// ``` + pub fn unsigned_overflowing_add_parallelized( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> (RadixCiphertext, BooleanBlock) { + self.overflowing_add_parallelized(ct_left, ct_right) + } + + pub fn unsigned_overflowing_add_assign_parallelized( &self, - ct: &mut RadixCiphertext, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, ) -> BooleanBlock { - if self.is_eligible_for_parallel_single_carry_propagation(ct) { - let carry = self.propagate_single_carry_parallelized_low_latency(&mut ct.blocks); - BooleanBlock::new_unchecked(carry) - } else { - let len = ct.blocks.len(); - for i in 0..len - 1 { - let _ = self.propagate_parallelized(ct, i); - } - let mut carry = self.propagate_parallelized(ct, len - 1); - carry.degree = Degree::new(1); - BooleanBlock::new_unchecked(carry) - } + self.overflowing_add_assign_parallelized(ct_left, 
ct_right) } pub fn signed_overflowing_add_parallelized( @@ -340,36 +378,7 @@ impl ServerKey { ct_left: &SignedRadixCiphertext, ct_right: &SignedRadixCiphertext, ) -> (SignedRadixCiphertext, BooleanBlock) { - let mut tmp_lhs: SignedRadixCiphertext; - let mut tmp_rhs: SignedRadixCiphertext; - - let (lhs, rhs) = match ( - ct_left.block_carries_are_empty(), - ct_right.block_carries_are_empty(), - ) { - (true, true) => (ct_left, ct_right), - (true, false) => { - tmp_rhs = ct_right.clone(); - self.full_propagate_parallelized(&mut tmp_rhs); - (ct_left, &tmp_rhs) - } - (false, true) => { - tmp_lhs = ct_left.clone(); - self.full_propagate_parallelized(&mut tmp_lhs); - (&tmp_lhs, ct_right) - } - (false, false) => { - tmp_lhs = ct_left.clone(); - tmp_rhs = ct_right.clone(); - rayon::join( - || self.full_propagate_parallelized(&mut tmp_lhs), - || self.full_propagate_parallelized(&mut tmp_rhs), - ); - (&tmp_lhs, &tmp_rhs) - } - }; - - self.unchecked_signed_overflowing_add_parallelized(lhs, rhs) + self.overflowing_add_parallelized(ct_left, ct_right) } pub fn unchecked_signed_overflowing_add_parallelized( @@ -386,188 +395,996 @@ impl ServerKey { ); assert!(!ct_left.blocks.is_empty(), "inputs cannot be empty"); - if self.is_eligible_for_parallel_single_carry_propagation(ct_left) { - self.unchecked_signed_overflowing_add_or_sub_parallelized_impl( - ct_left, - ct_right, - SignedOperation::Addition, - ) - } else { - self.unchecked_signed_overflowing_add_or_sub( - ct_left, - ct_right, - SignedOperation::Addition, - ) - } + let mut result = ct_left.clone(); + let overflowed = self.overflowing_add_assign_with_carry(&mut result, ct_right, None); + (result, overflowed) } - pub fn add_parallelized_work_efficient(&self, ct_left: &T, ct_right: &T) -> T - where - T: IntegerRadixCiphertext, - { - let mut ct_res = ct_left.clone(); - self.add_assign_parallelized_work_efficient(&mut ct_res, ct_right); - ct_res + pub(crate) fn is_eligible_for_parallel_single_carry_propagation( + &self, + num_blocks: usize, + ) -> bool { + // having 4-bits is a hard requirement + // as the parallel implementation uses a bivariate BPS where individual values need + // 2 bits + let total_modulus = self.key.message_modulus.0 * self.key.carry_modulus.0; + let has_enough_bits_per_block = total_modulus >= (1 << 4); + if !has_enough_bits_per_block { + return false; + } + + should_parallel_propagation_be_faster( + self.message_modulus().0 * self.carry_modulus().0, + num_blocks, + rayon::current_num_threads(), + ) } - pub fn add_assign_parallelized_work_efficient(&self, ct_left: &mut T, ct_right: &T) - where + /// Does lhs += (rhs + carry) + pub fn add_assign_with_carry_parallelized( + &self, + lhs: &mut T, + rhs: &T, + input_carry: Option<&BooleanBlock>, + ) where T: IntegerRadixCiphertext, { - let mut tmp_rhs: T; + if !lhs.block_carries_are_empty() { + self.full_propagate_parallelized(lhs); + } - let (lhs, rhs) = match ( - ct_left.block_carries_are_empty(), - ct_right.block_carries_are_empty(), - ) { - (true, true) => (ct_left, ct_right), - (true, false) => { - tmp_rhs = ct_right.clone(); - self.full_propagate_parallelized(&mut tmp_rhs); - (ct_left, &tmp_rhs) - } - (false, true) => { - self.full_propagate_parallelized(ct_left); - (ct_left, ct_right) - } - (false, false) => { - tmp_rhs = ct_right.clone(); - rayon::join( - || self.full_propagate_parallelized(ct_left), - || self.full_propagate_parallelized(&mut tmp_rhs), - ); - (ct_left, &tmp_rhs) - } + let mut cloned_rhs; + + let rhs = if rhs.block_carries_are_empty() { + rhs + } else { + 
cloned_rhs = rhs.clone(); + self.full_propagate_parallelized(&mut cloned_rhs); + &cloned_rhs }; - self.unchecked_add_assign_parallelized_work_efficient(lhs, rhs); + self.advanced_add_assign_with_carry_parallelized( + lhs.blocks_mut(), + rhs.blocks(), + input_carry, + OutputFlag::None, + ); } - pub(crate) fn is_eligible_for_parallel_single_carry_propagation(&self, ct: &T) -> bool + /// Does lhs += (rhs + carry) + /// + /// Returns a boolean block that encrypts `true` if overflow happened + pub fn overflowing_add_assign_with_carry( + &self, + lhs: &mut T, + rhs: &T, + input_carry: Option<&BooleanBlock>, + ) -> BooleanBlock where T: IntegerRadixCiphertext, { - // having 4-bits is a hard requirement - // as the parallel implementation uses a bivariate BPS where individual values need - // 2 bits - let total_modulus = self.key.message_modulus.0 * self.key.carry_modulus.0; - let has_enough_bits_per_block = total_modulus >= (1 << 4); - if !has_enough_bits_per_block { - return false; - } + self.advanced_add_assign_with_carry_parallelized( + lhs.blocks_mut(), + rhs.blocks(), + input_carry, + OutputFlag::from_signedness(T::IS_SIGNED), + ) + .expect("internal error, overflow computation was not returned as was requested") + } - should_hillis_steele_propagation_be_faster(ct.blocks().len(), rayon::current_num_threads()) + pub(crate) fn propagate_single_carry_parallelized(&self, radix: &mut [Ciphertext]) { + self.advanced_add_assign_with_carry_at_least_4_bits(radix, &[], None, OutputFlag::None); } - /// This add_assign two numbers - /// - /// It uses the Hillis and Steele algorithm to do - /// prefix sum / cumulative sum in parallel. + /// Computes the result of `lhs += rhs + input_carry` /// - /// It it not "work efficient" as in, it adds a lot - /// of work compared to the single threaded approach, - /// however it is highly parallelized and so is the fastest - /// assuming enough threads are available. + /// This will selects what seems to be best algorithm to propagate carries + /// (fully parallel vs sequential) by looking at the number of blocks and + /// number of threads. /// - /// At most num_block - 1 threads are used + /// - `lhs` and `rhs` must have the same `len()`, empty is allowed + /// - `blocks of lhs` and `rhs` must all be without carry + /// - blocks must have at least one bit of message and one bit of carry /// - /// Returns the output carry that can be used to check for unsigned addition - /// overflow. 
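Aside on the `OutputFlag::Carry` / `OutputFlag::Overflow` distinction that `from_signedness` selects above: it is the usual pair of CPU flags, where the carry flag is the bit carried out of the most significant bit (unsigned overflow) and the overflow flag compares the carry into the sign bit with the carry out of it (signed overflow in two's complement). A clear-value sketch of that rule on 8-bit values, with illustrative names, not part of the patch:

```rust
/// Clear-value illustration of the two flags selectable through `OutputFlag`.
fn add_with_flags(lhs: u8, rhs: u8, input_carry: u8) -> (u8, bool, bool) {
    let full = u16::from(lhs) + u16::from(rhs) + u16::from(input_carry);
    // Carry flag: bit carried out of the most significant bit (unsigned overflow).
    let carry = ((full >> 8) & 1) == 1;
    // Overflow flag: the carry into the sign bit differs from the carry out of it
    // (signed overflow in two's complement).
    let carry_into_msb =
        (((u16::from(lhs & 0x7f) + u16::from(rhs & 0x7f) + u16::from(input_carry)) >> 7) & 1) == 1;
    (full as u8, carry, carry_into_msb != carry)
}

fn main() {
    // 200 + 100 wraps as unsigned (carry set) but is fine as signed: -56 + 100 = 44.
    assert_eq!(add_with_flags(200, 100, 0), (44, true, false));
    // 100 + 100 is fine as unsigned but overflows as signed: the result reads as -56.
    assert_eq!(add_with_flags(100, 100, 0), (200, false, true));
}
```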
+ /// Returns `Some(...)` if requested_flag != ComputationFlags::None + pub(crate) fn advanced_add_assign_with_carry_parallelized( + &self, + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + input_carry: Option<&BooleanBlock>, + requested_flag: OutputFlag, + ) -> Option { + if self.is_eligible_for_parallel_single_carry_propagation(lhs.len()) { + self.advanced_add_assign_with_carry_at_least_4_bits( + lhs, + rhs, + input_carry, + requested_flag, + ) + } else { + self.advanced_add_assign_with_carry_sequential_parallelized( + lhs, + rhs, + input_carry, + requested_flag, + ) + } + } + + /// Computes the result of `lhs += rhs + input_carry` /// - /// # Requirements + /// This uses the sequential algorithm to propagate the carries /// - /// - The parameters have 4 bits in total - /// - Adding rhs to lhs must not consume more than one carry + /// - `lhs` and `rhs` must have the same `len()`, empty is allowed + /// - `blocks of lhs` and `rhs` must all be without carry + /// - blocks must have at least one bit of message and one bit of carry /// - /// # Output + /// Returns `Some(...)` if requested_flag != ComputationFlags::None + pub(crate) fn advanced_add_assign_with_carry_sequential_parallelized( + &self, + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + input_carry: Option<&BooleanBlock>, + requested_flag: OutputFlag, + ) -> Option { + assert_eq!( + lhs.len(), + rhs.len(), + "Both operands must have the same number of blocks" + ); + + if lhs.is_empty() { + return if requested_flag == OutputFlag::None { + None + } else { + Some(self.create_trivial_boolean_block(false)) + }; + } + + let carry = + input_carry.map_or_else(|| self.create_trivial_boolean_block(false), Clone::clone); + + // 2_2, 3_3, 4_4 + // If we have at least 2 bits and at least as much carries + // + // The num blocks == 1 + requested_flag == OverflowFlag will actually result in one more + // PBS of latency than num_blocks == 1 && requested_flag != OverflowFlag + // + // It happens because the computation of the overflow flag requires 2 steps, + // and we insert these two steps in parallel to normal carry propagation. + // The first step is done when processing the first block, + // the second step is done when processing the last block. + // So if the number of block is smaller than 2 then, + // the overflow computation adds additional layer of PBS. + if self.key.message_modulus.0 >= 4 && self.key.carry_modulus.0 >= self.key.message_modulus.0 + { + self.advanced_add_assign_sequential_at_least_4_bits( + requested_flag, + lhs, + rhs, + carry, + input_carry, + ) + } else if self.key.message_modulus.0 == 2 + && self.key.carry_modulus.0 >= self.key.message_modulus.0 + { + self.advanced_add_assign_sequential_at_least_2_bits(lhs, rhs, carry, requested_flag) + } else { + panic!( + "Invalid combo of message modulus ({}) and carry modulus ({}) \n\ + This function requires the message modulus >= 2 and carry modulus >= message_modulus \n\ + I.e. 
PARAM_MESSAGE_X_CARRY_Y where X >= 1 and Y >= X.", + self.key.message_modulus.0, self.key.carry_modulus.0 + ); + } + } + + /// Computes lhs += (rhs + carry) using the sequential propagation of carries /// - /// - lhs will have its carries empty - pub(crate) fn unchecked_add_assign_parallelized_low_latency( + /// parameters of blocks must have 4 bits, parameters in the form X_Y where X >= 2 && Y >= X + fn advanced_add_assign_sequential_at_least_4_bits( &self, - lhs: &mut T, - rhs: &T, - ) -> Ciphertext - where - T: IntegerRadixCiphertext, - { - let degree_after_add_does_not_go_beyond_first_carry = lhs - .blocks() - .iter() - .zip(rhs.blocks().iter()) - .all(|(bl, br)| { - let degree_after_add = bl.degree.get() + br.degree.get(); - degree_after_add < (self.key.message_modulus.0 * 2) + requested_flag: OutputFlag, + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + carry: BooleanBlock, + input_carry: Option<&BooleanBlock>, + ) -> Option { + let mut carry = carry.0; + + let mut overflow_flag = if requested_flag == OutputFlag::Overflow { + let mut block = self + .key + .unchecked_scalar_mul(lhs.last().as_ref().unwrap(), self.message_modulus().0 as u8); + self.key + .unchecked_add_assign(&mut block, rhs.last().as_ref().unwrap()); + Some(block) + } else { + None + }; + + // Handle the first block + self.key.unchecked_add_assign(&mut lhs[0], &rhs[0]); + self.key.unchecked_add_assign(&mut lhs[0], &carry); + + // To be able to use carry_extract_assign in it + carry.clone_from(&lhs[0]); + rayon::scope(|s| { + s.spawn(|_| { + self.key.message_extract_assign(&mut lhs[0]); }); - assert!(degree_after_add_does_not_go_beyond_first_carry); - self.unchecked_add_assign_parallelized(lhs, rhs); - self.propagate_single_carry_parallelized_low_latency(lhs.blocks_mut()) + s.spawn(|_| { + self.key.carry_extract_assign(&mut carry); + }); + + if requested_flag == OutputFlag::Overflow { + s.spawn(|_| { + // Computing the overflow flag requires an extra step for the first block + + let overflow_flag = overflow_flag.as_mut().unwrap(); + let num_bits_in_message = self.message_modulus().0.ilog2() as u64; + let lut = self.key.generate_lookup_table(|lhs_rhs| { + let lhs = lhs_rhs / self.message_modulus().0 as u64; + let rhs = lhs_rhs % self.message_modulus().0 as u64; + overflow_flag_preparation_lut(lhs, rhs, num_bits_in_message) + }); + self.key.apply_lookup_table_assign(overflow_flag, &lut); + }); + } + }); + + let num_blocks = lhs.len(); + + // We did the first block before, the last block is done after this if, + // so we need 3 blocks at least to enter this + if num_blocks >= 3 { + for (lhs_b, rhs_b) in lhs[1..num_blocks - 1] + .iter_mut() + .zip(rhs[1..num_blocks - 1].iter()) + { + self.key.unchecked_add_assign(lhs_b, rhs_b); + self.key.unchecked_add_assign(lhs_b, &carry); + + carry.clone_from(lhs_b); + rayon::join( + || self.key.message_extract_assign(lhs_b), + || self.key.carry_extract_assign(&mut carry), + ); + } + } + + if num_blocks >= 2 { + // Handle the last block + self.key + .unchecked_add_assign(&mut lhs[num_blocks - 1], &rhs[num_blocks - 1]); + self.key + .unchecked_add_assign(&mut lhs[num_blocks - 1], &carry); + } + + if let Some(block) = overflow_flag.as_mut() { + if num_blocks == 1 && input_carry.is_some() { + self.key + .unchecked_add_assign(block, input_carry.map(|b| &b.0).unwrap()); + } else { + self.key.unchecked_add_assign(block, &carry); + } + } + + // To be able to use carry_extract_assign in it + carry.clone_from(&lhs[num_blocks - 1]); + + // Note that here when num_blocks == 1 && requested_flag != 
Overflow nothing + // will actually be spawned. + rayon::scope(|s| { + if num_blocks >= 2 { + // These would already have been done when the first block was processed + s.spawn(|_| { + self.key.message_extract_assign(&mut lhs[num_blocks - 1]); + }); + + s.spawn(|_| { + self.key.carry_extract_assign(&mut carry); + }); + } + + if requested_flag == OutputFlag::Overflow { + s.spawn(|_| { + let overflow_flag_block = overflow_flag.as_mut().unwrap(); + // Computing the overflow flag requires and extra step for the first block + let overflow_flag_lut = self.key.generate_lookup_table(|block| { + let input_carry = block & 1; + let does_overflow_if_carry_is_1 = (block >> 3) & 1; + let does_overflow_if_carry_is_0 = (block >> 2) & 1; + if input_carry == 1 { + does_overflow_if_carry_is_1 + } else { + does_overflow_if_carry_is_0 + } + }); + + self.key + .apply_lookup_table_assign(overflow_flag_block, &overflow_flag_lut); + }); + } + }); + + match requested_flag { + OutputFlag::None => None, + OutputFlag::Overflow => { + assert!( + overflow_flag.is_some(), + "internal error, overflow_flag should exist" + ); + overflow_flag.map(BooleanBlock::new_unchecked) + } + OutputFlag::Carry => { + carry.degree = Degree::new(1); + Some(BooleanBlock::new_unchecked(carry)) + } + } + } + + /// Computes lhs += (rhs + carry) using the sequential propagation of carries + /// + /// parameters of blocks must have 2 bits, parameters in the form X_Y where X >= 1 && Y >= X + // so 1_X parameters + // + // Same idea as other algorithms, however since we have 1 bit per block + // we do not have to resolve any inner propagation but it adds one more + // sequential PBS when we are interested in the OverflowFlag + fn advanced_add_assign_sequential_at_least_2_bits( + &self, + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + carry: BooleanBlock, + requested_flag: OutputFlag, + ) -> Option { + let mut carry = carry.0; + + fn block_add_assign_returning_carry( + sks: &ServerKey, + lhs: &mut Ciphertext, + rhs: &Ciphertext, + carry: &Ciphertext, + ) -> Ciphertext { + sks.key.unchecked_add_assign(lhs, rhs); + sks.key.unchecked_add_assign(lhs, carry); + let (carry, message) = rayon::join( + || sks.key.carry_extract(lhs), + || sks.key.message_extract(lhs), + ); + + *lhs = message; + + carry + } + let num_blocks = lhs.len(); + for (lhs_b, rhs_b) in lhs[..num_blocks - 1] + .iter_mut() + .zip(rhs[..num_blocks - 1].iter()) + { + carry = block_add_assign_returning_carry(self, lhs_b, rhs_b, &carry); + } + + let mut output_carry = block_add_assign_returning_carry( + self, + &mut lhs[num_blocks - 1], + &rhs[num_blocks - 1], + &carry, + ); + + match requested_flag { + OutputFlag::None => None, + OutputFlag::Overflow => { + let overflowed = self.key.not_equal(&output_carry, &carry); + Some(BooleanBlock::new_unchecked(overflowed)) + } + OutputFlag::Carry => { + output_carry.degree = Degree::new(1); + Some(BooleanBlock::new_unchecked(output_carry)) + } + } } - /// This function takes an input slice of shortint ciphertext (aka blocks) - /// for which at most one bit of carry is consumed in each block, and - /// it does the carry propagation in place. 
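The two-step overflow computation interleaved into the sequential 4-bit path above can be checked on clear values: the first step packs the two most significant blocks and records, in bits 2 and 3 of a block, whether overflow would occur for an incoming carry of 0 or of 1; the second step simply selects one of those two bits once the real carry into the last block (bit 0) is known. A sketch mirroring that encoding, with illustrative function names, not part of the patch:

```rust
// Step 1: mirror of `overflow_flag_preparation_lut` on clear values.
// Bit 2 holds "overflows if the incoming carry is 0", bit 3 holds "... if it is 1".
fn preparation(last_lhs: u64, last_rhs: u64, num_bits_in_message: u64) -> u64 {
    let mask = (1 << (num_bits_in_message - 1)) - 1;
    let overflows = |input_carry: u64| {
        let output_carry = ((last_lhs + last_rhs + input_carry) >> num_bits_in_message) & 1;
        let carry_into_last_bit =
            (((last_lhs & mask) + (last_rhs & mask) + input_carry) >> (num_bits_in_message - 1)) & 1;
        u64::from(carry_into_last_bit != output_carry)
    };
    (overflows(1) << 3) | (overflows(0) << 2)
}

// Step 2: mirror of the final selection LUT, once the real carry (bit 0) is known.
fn select(prepared_plus_carry: u64) -> u64 {
    if prepared_plus_carry & 1 == 1 {
        (prepared_plus_carry >> 3) & 1
    } else {
        (prepared_plus_carry >> 2) & 1
    }
}

fn main() {
    // 2-bit message blocks: the MSB blocks 0b00 and 0b01 of two non-negative values
    // only overflow if a carry arrives from the lower blocks and flips the sign bit.
    let prepared = preparation(0b00, 0b01, 2);
    assert_eq!(select(prepared), 0); // no incoming carry -> no overflow
    assert_eq!(select(prepared | 1), 1); // incoming carry -> overflow
}
```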
+ /// Does lhs += (rhs + carry) /// - /// It returns the output carry of the last block + /// acts like the ADC assemby op, except, the flags have to be explicitely requested + /// as they incur additional PBSes /// - /// Used in (among other) 'default' addition: - /// - first unchecked_add - /// - at this point at most on bit of carry is taken - /// - use this function to propagate them in parallel - pub(crate) fn propagate_single_carry_parallelized_low_latency( + /// - Parameters must have at least 2 bits of message, 2 bits of carry + /// - blocks of lhs and rhs must be clean (no carries) + /// - lhs and rhs must have the same length + pub(crate) fn advanced_add_assign_with_carry_at_least_4_bits( &self, - blocks: &mut [Ciphertext], - ) -> Ciphertext { - let generates_or_propagates = self.generate_init_carry_array(blocks); - let (input_carries, output_carry) = - self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates); + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + input_carry: Option<&BooleanBlock>, + requested_flag: OutputFlag, + ) -> Option { + // Empty rhs is a specially allowed 'weird' case to + // act like a 'propagate single carry' function. + // This is not made explicit in the docs as we have a + // `propagate_single_carry_parallelized` function which wraps this special case + if rhs.is_empty() { + // Techinically, CarryFlag is computable, but OverflowFlag is not + assert_eq!( + requested_flag, + OutputFlag::None, + "Cannot compute flags when called in propagation mode" + ); + } else { + assert_eq!( + lhs.len(), + rhs.len(), + "Both operands must have the same number of blocks" + ); + } - blocks - .par_iter_mut() - .zip(input_carries.par_iter()) - .for_each(|(block, input_carry)| { - self.key.unchecked_add_assign(block, input_carry); - self.key.message_extract_assign(block); + if lhs.is_empty() { + // Then both are empty + if requested_flag == OutputFlag::None { + return None; + } + return Some(self.create_trivial_boolean_block(false)); + } + + let saved_last_blocks = if requested_flag == OutputFlag::Overflow { + Some((lhs.last().cloned().unwrap(), rhs.last().cloned().unwrap())) + } else { + None + }; + + // Perform the block additions + for (lhs_b, rhs_b) in lhs.iter_mut().zip(rhs.iter()) { + self.key.unchecked_add_assign(lhs_b, rhs_b); + } + if let Some(carry) = input_carry { + self.key.unchecked_add_assign(&mut lhs[0], &carry.0); + } + + let blocks = lhs; + let num_blocks = blocks.len(); + + let message_modulus = self.message_modulus().0 as u64; + let num_bits_in_message = message_modulus.ilog2() as u64; + + let block_modulus = self.message_modulus().0 * self.carry_modulus().0; + let num_bits_in_block = block_modulus.ilog2(); + + // Just in case we compare with max noise level, but it should always be num_bits_in_blocks + // with the parameters we provide + let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get()); + + let num_groupings = num_blocks.div_ceil(grouping_size); + + let num_carry_to_resolve = num_groupings - 1; + + let sequential_depth = + (num_carry_to_resolve.saturating_sub(1) as u32) / (grouping_size as u32 - 1); + let hillis_steel_depth = if num_carry_to_resolve == 0 { + 0 + } else { + num_carry_to_resolve.ceil_ilog2() + }; + + let use_sequential_algorithm_to_resolved_grouping_carries = + sequential_depth <= hillis_steel_depth; + + let mut output_flag = None; + + // First step + let (shifted_blocks, block_states) = match requested_flag { + OutputFlag::None => { + let (shifted_blocks, mut block_states) = + 
self.compute_shifted_blocks_and_block_states(blocks); + let _ = block_states.pop().unwrap(); + (shifted_blocks, block_states) + } + OutputFlag::Overflow => { + let (block, (shifted_blocks, block_states)) = rayon::join( + || { + // When used on the last block of `lhs` and `rhs`, this will create a + // block that encodes the 2 values needed to later know if overflow did + // happen depending on the input carry of the last block. + let lut = self.key.generate_lookup_table_bivariate(|lhs, rhs| { + overflow_flag_preparation_lut(lhs, rhs, num_bits_in_message) + }); + let (last_lhs_block, last_rhs_block) = saved_last_blocks.as_ref().unwrap(); + self.key.unchecked_apply_lookup_table_bivariate( + last_lhs_block, + last_rhs_block, + &lut, + ) + }, + || { + let (shifted_blocks, mut block_states) = + self.compute_shifted_blocks_and_block_states(blocks); + let _ = block_states.pop().unwrap(); + (shifted_blocks, block_states) + }, + ); + + output_flag = Some(block); + (shifted_blocks, block_states) + } + OutputFlag::Carry => { + let (shifted_blocks, mut block_states) = + self.compute_shifted_blocks_and_block_states(blocks); + let last_block_state = block_states.pop().unwrap(); + output_flag = Some(last_block_state); + (shifted_blocks, block_states) + } + }; + + // Second step + let (mut prepared_blocks, groupings_pgns) = { + // This stores, the LUTs that given a cum sum block in the first grouping + // tells if a carry is generated or not + let first_grouping_inner_propagation_luts = (0..grouping_size - 1) + .map(|index| { + self.key.generate_lookup_table(|propa_cum_sum_block| { + let carry = (propa_cum_sum_block >> index) & 1; + if carry != 0 { + 2 // Generates + } else { + 0 // Nothing + } + }) + }) + .collect::>(); + + // This stores, the LUTs that given a cum sum in non first grouping + // tells if a carry is generated or propagated or neither of these + let other_groupings_inner_propagation_luts = (0..grouping_size) + .map(|index| { + self.key.generate_lookup_table(|propa_cum_sum_block| { + let mask = (2 << index) - 1; + if propa_cum_sum_block >= (2 << index) { + 2 // Generates + } else if (propa_cum_sum_block & mask) == mask { + 1 // Propagate + } else { + 0 + } + }) + }) + .collect::>(); + + // This stores the LUT that outputs the propagation result of the first grouping + let first_grouping_outer_propagation_lut = self.key.generate_lookup_table(|block| { + // Check if the last bit of the block is set + (block >> (num_bits_in_block - 1)) & 1 }); - output_carry + + // This stores the LUTs that output the propagation result of the other groupings + let grouping_chunk_pgn_luts = if use_sequential_algorithm_to_resolved_grouping_carries { + // When using the sequential algorithm for the propagation of one grouping to the + // other we need to shift the PGN state to the correct position, so we later, when + // using them only lwe_add is needed and so noise management is easy + // + // Also, these LUTs are 'negacylic', they are made to exploit the padding bit + // resulting blocks from these LUTs must be added the constant `1 << index`. + (0..grouping_size - 1) + .map(|i| { + self.key.generate_lookup_table(|block| { + // All bits set to 1 (e.g. 
0b1111), means propagate + if block == (block_modulus - 1) as u64 { + 0 + } else { + // u64::MAX is -1 in tow's complement + // We apply the modulus including the padding bit + (u64::MAX << i) % (1 << (num_bits_in_block + 1)) + } + }) + }) + .collect::>() + } else { + // This LUT is for when we are using Hillis-Steele prefix-scan to propagate carries + // between groupings. When using this propagation, the encoding of the states + // are a bit different. + // + // Also, these LUTs are 'negacylic', they are made to exploit the padding bit + // resulting blocks from these LUTs must be added the constant `1`. + vec![self.key.generate_lookup_table(|block| { + if block == (block_modulus - 1) as u64 { + // All bits set to 1 (e.g. 0b1111), means propagate + 2 + } else { + // u64::MAX is -1 in tow's complement + // We apply the modulus including the padding bit + u64::MAX % (1 << (block_modulus + 1)) + } + })] + }; + + let mut propagation_cum_sums = Vec::with_capacity(num_blocks); + block_states.chunks(grouping_size).for_each(|grouping| { + propagation_cum_sums.push(grouping[0].clone()); + for other in &grouping[1..] { + let mut result = other.clone(); + self.key + .unchecked_add_assign(&mut result, propagation_cum_sums.last().unwrap()); + + propagation_cum_sums.push(result); + } + }); + + // Compute the cum sum arrays, + // each grouping is independent from other groupings + // but we store everything flattened (Vec<_>) instead of nested (Vec>) + propagation_cum_sums + .par_iter_mut() + .enumerate() + .for_each(|(i, cum_sum_block)| { + let grouping_index = i / grouping_size; + let is_in_first_grouping = grouping_index == 0; + let index_in_grouping = i % grouping_size; + + let lut = if is_in_first_grouping { + if index_in_grouping == grouping_size - 1 { + &first_grouping_outer_propagation_lut + } else { + &first_grouping_inner_propagation_luts[index_in_grouping] + } + } else if index_in_grouping == grouping_size - 1 { + if use_sequential_algorithm_to_resolved_grouping_carries { + &grouping_chunk_pgn_luts[(grouping_index - 1) % (grouping_size - 1)] + } else { + &grouping_chunk_pgn_luts[0] + } + } else { + &other_groupings_inner_propagation_luts[index_in_grouping] + }; + + self.key.apply_lookup_table_assign(cum_sum_block, lut); + + let may_have_its_padding_bit_set = + !is_in_first_grouping && index_in_grouping == grouping_size - 1; + if may_have_its_padding_bit_set { + if use_sequential_algorithm_to_resolved_grouping_carries { + self.key.unchecked_scalar_add_assign( + cum_sum_block, + 1 << ((grouping_index - 1) % (grouping_size - 1)), + ); + } else { + self.key.unchecked_scalar_add_assign(cum_sum_block, 1); + } + cum_sum_block.degree = Degree::new(message_modulus as usize - 1); + } + }); + + let num_groupings = num_blocks / grouping_size; + let mut groupings_pgns = Vec::with_capacity(num_groupings); + let mut propagation_simulators = Vec::with_capacity(num_blocks); + + // First block does not get a carry from + propagation_simulators.push(self.key.create_trivial(0)); + for block in propagation_cum_sums.drain(..) 
{ + if propagation_simulators.len() % grouping_size == 0 { + groupings_pgns.push(block); + // The first block in each grouping has its simulator set to 0 + // because it always receives any input borrow that may be generated from + // previous grouping + propagation_simulators.push(self.key.create_trivial(1)); + } else { + propagation_simulators.push(block); + } + } + + let mut prepared_blocks = shifted_blocks; + prepared_blocks + .iter_mut() + .zip(propagation_simulators.iter()) + .for_each(|(block, simulator)| { + self.key.unchecked_add_assign(block, simulator); + }); + + match requested_flag { + OutputFlag::None => {} + OutputFlag::Overflow => { + let block = output_flag.as_mut().unwrap(); + self.key + .unchecked_add_assign(block, &propagation_simulators[num_blocks - 1]); + } + OutputFlag::Carry => { + let block = output_flag.as_mut().unwrap(); + self.key + .unchecked_add_assign(block, &propagation_simulators[num_blocks - 1]); + } + } + + (prepared_blocks, groupings_pgns) + }; + + // Third step: resolving carry propagation between the groups + let resolved_carries = if groupings_pgns.is_empty() { + vec![self.key.create_trivial(0)] + } else if use_sequential_algorithm_to_resolved_grouping_carries { + self.resolve_carries_of_groups_sequentially(groupings_pgns, grouping_size) + } else { + self.resolve_carries_of_groups_using_hillis_steele(groupings_pgns) + }; + + // Final step: adding resolved carries and cleaning result + let mut add_carries_and_cleanup = || { + let message_extract_lut = self + .key + .generate_lookup_table(|block| (block >> 1) % message_modulus); + + prepared_blocks + .par_iter_mut() + .enumerate() + .for_each(|(i, block)| { + let grouping_index = i / grouping_size; + let carry = &resolved_carries[grouping_index]; + self.key.unchecked_add_assign(block, carry); + + self.key + .apply_lookup_table_assign(block, &message_extract_lut) + }); + }; + + match requested_flag { + OutputFlag::None => { + add_carries_and_cleanup(); + } + OutputFlag::Overflow => { + let overflow_flag_lut = self.key.generate_lookup_table(|block| { + let input_carry = (block >> 1) & 1; + let does_overflow_if_carry_is_1 = (block >> 3) & 1; + let does_overflow_if_carry_is_0 = (block >> 2) & 1; + if input_carry == 1 { + does_overflow_if_carry_is_1 + } else { + does_overflow_if_carry_is_0 + } + }); + rayon::join( + || { + let block = output_flag.as_mut().unwrap(); + self.key.unchecked_add_assign( + block, + &resolved_carries[resolved_carries.len() - 1], + ); + self.key + .apply_lookup_table_assign(block, &overflow_flag_lut); + }, + add_carries_and_cleanup, + ); + } + OutputFlag::Carry => { + let carry_flag_lut = self.key.generate_lookup_table(|block| (block >> 2) & 1); + + rayon::join( + || { + let block = output_flag.as_mut().unwrap(); + self.key.unchecked_add_assign( + block, + &resolved_carries[resolved_carries.len() - 1], + ); + self.key.apply_lookup_table_assign(block, &carry_flag_lut); + }, + add_carries_and_cleanup, + ); + } + } + + blocks.clone_from_slice(&prepared_blocks); + + match requested_flag { + OutputFlag::None => None, + OutputFlag::Overflow | OutputFlag::Carry => { + output_flag.map(BooleanBlock::new_unchecked) + } + } } - /// Backbone algorithm of parallel carry (only one bit) propagation + /// This resolves the carries using a Hillis-Steele algorithm /// - /// Uses the Hillis and Steele prefix scan + /// Blocks must have a value in + /// - 2 or 1 for generate + /// - 3 for propagate + /// - 0 for no carry /// - /// Requires the blocks to have at least 4 bits - pub(crate) fn 
compute_carry_propagation_parallelized_low_latency( + /// The returned Vec of blocks encrypting 1 if a carry is generated, 0 if not + fn resolve_carries_of_groups_using_hillis_steele( &self, - generates_or_propagates: Vec, - ) -> (Vec, Ciphertext) { - if generates_or_propagates.is_empty() { - return (vec![], self.key.create_trivial(0)); - } - - let lut_carry_propagation_sum = self - .key - .generate_lookup_table_bivariate(prefix_sum_carry_propagation); - // Type annotations are required, otherwise we get confusing errors - // "implementation of `FnOnce` is not general enough" - let sum_function = |block_carry: &mut Ciphertext, previous_block_carry: &Ciphertext| { + groupings_pgns: Vec, + ) -> Vec { + let lut_carry_propagation_sum = + self.key + .generate_lookup_table_bivariate(|msb: u64, lsb: u64| -> u64 { + if msb == 2 { + 1 // Remap Generate to 1 + } else if msb == 3 { + // MSB propagates + if lsb == 2 { + 1 + } else { + lsb + } // also remap here + } else { + msb + } + }); + let sum_function = |block_borrow: &mut Ciphertext, previous_block_borrow: &Ciphertext| { self.key.unchecked_apply_lookup_table_bivariate_assign( - block_carry, - previous_block_carry, + block_borrow, + previous_block_borrow, &lut_carry_propagation_sum, ); }; + let mut resolved_carries = + self.compute_prefix_sum_hillis_steele(groupings_pgns, sum_function); + resolved_carries.insert(0, self.key.create_trivial(0)); + resolved_carries + } + + /// This resolves the carries using a sequential algorithm + /// where each iteration resolves grouping_size - 1 "PGN" + /// + /// Blocks must have a value in + /// - 2 for generate + /// - 1 for propagate + /// - 0 for no carry + /// + /// This value must be shifted by the position in the block's group. + /// + /// The block of the first group (so groupings_pgns[0]) must have a value in + /// - 1 for generate + /// - 0 for no carry + /// + /// The returned Vec of blocks encrypting 1 if a carry is generated, 0 if not + fn resolve_carries_of_groups_sequentially( + &self, + mut groupings_pgns: Vec, + grouping_size: usize, + ) -> Vec { + let luts = (0..grouping_size - 1) + .map(|index| { + self.key.generate_lookup_table(|propa_cum_sum_block| { + (propa_cum_sum_block >> (index + 1)) & 1 + }) + }) + .collect::>(); + + groupings_pgns.rotate_left(1); + let mut resolved_carries = vec![self.key.create_trivial(0), groupings_pgns.pop().unwrap()]; + + for chunk in groupings_pgns.chunks(grouping_size - 1) { + let mut cum_sums = chunk.to_vec(); + self.key + .unchecked_add_assign(&mut cum_sums[0], resolved_carries.last().unwrap()); + + if chunk.len() > 1 { + let mut accumulator = cum_sums[0].clone(); + for block in cum_sums[1..].iter_mut() { + self.key.unchecked_add_assign(&mut accumulator, block); + block.clone_from(&accumulator); + } + } + + cum_sums + .par_iter_mut() + .zip(luts.par_iter()) + .for_each(|(cum_sum_block, lut)| { + self.key.apply_lookup_table_assign(cum_sum_block, lut); + }); + + // Cum sums now contains the output carries + resolved_carries.append(&mut cum_sums); + } + + resolved_carries + } + + fn compute_shifted_blocks_and_block_states( + &self, + blocks: &[Ciphertext], + ) -> (Vec, Vec) { + let num_blocks = blocks.len(); + + let message_modulus = self.message_modulus().0 as u64; + + let block_modulus = self.message_modulus().0 * self.carry_modulus().0; + let num_bits_in_block = block_modulus.ilog2(); + + let grouping_size = num_bits_in_block as usize; + + let shift_block_fn = |block| (block % message_modulus) << 1; + let mut first_grouping_luts = vec![{ + let 
first_block_state_fn = |block| { + if block >= message_modulus { + 1 // Generates + } else { + 0 // Nothing + } + }; + self.key + .generate_many_lookup_table(&[&first_block_state_fn, &shift_block_fn]) + }]; + for i in 1..grouping_size { + let state_fn = |block| { + let r = if block >= message_modulus { + 2 // Generates Carry + } else if block == message_modulus - 1 { + 1 // Propagates a carry + } else { + 0 // Does not borrow + }; + + r << (i - 1) + }; + first_grouping_luts.push( + self.key + .generate_many_lookup_table(&[&state_fn, &shift_block_fn]), + ); + } + + let other_block_state_luts = (0..grouping_size) + .map(|i| { + let state_fn = |block| { + let r = if block >= message_modulus { + 2 // Generates Carry + } else if block == message_modulus - 1 { + 1 // Propagates a carry + } else { + 0 // Does not borrow + }; + + r << i + }; + self.key + .generate_many_lookup_table(&[&state_fn, &shift_block_fn]) + }) + .collect::>(); + + let last_block_luts = { + if blocks.len() == 1 { + let first_block_state_fn = |block| { + if block >= message_modulus { + 2 << 1 // Generates + } else { + 0 // Nothing + } + }; + self.key + .generate_many_lookup_table(&[&first_block_state_fn, &shift_block_fn]) + } else { + first_grouping_luts[2].clone() + } + }; + + let tmp = blocks + .par_iter() + .enumerate() + .map(|(index, block)| { + let grouping_index = index / grouping_size; + let is_in_first_grouping = grouping_index == 0; + let index_in_grouping = index % (grouping_size); + let is_last_index = index == blocks.len() - 1; + + let luts = if is_last_index { + &last_block_luts + } else if is_in_first_grouping { + &first_grouping_luts[index_in_grouping] + } else { + &other_block_state_luts[index_in_grouping] + }; + self.key.apply_many_lookup_table(block, luts) + }) + .collect::>(); + + let mut shifted_blocks = Vec::with_capacity(num_blocks); + let mut block_states = Vec::with_capacity(num_blocks); + for mut blocks in tmp { + assert_eq!(blocks.len(), 2); + shifted_blocks.push(blocks.pop().unwrap()); + block_states.push(blocks.pop().unwrap()); + } - let num_blocks = generates_or_propagates.len(); - let mut carries_out = - self.compute_prefix_sum_hillis_steele(generates_or_propagates, sum_function); - let mut last_block_out_carry = self.key.create_trivial(0); - std::mem::swap(&mut carries_out[num_blocks - 1], &mut last_block_out_carry); - last_block_out_carry.degree = Degree::new(1); - // The output carry of block i-1 becomes the input - // carry of block i - carries_out.rotate_right(1); - (carries_out, last_block_out_carry) + (shifted_blocks, block_states) } /// Computes a prefix sum/scan in parallel using Hillis & Steel algorithm @@ -581,8 +1398,8 @@ impl ServerKey { { debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 4)); - if blocks.is_empty() { - return vec![]; + if blocks.is_empty() || blocks.len() == 1 { + return blocks; } let num_blocks = blocks.len(); @@ -607,172 +1424,51 @@ impl ServerKey { blocks } +} - /// This add_assign two numbers - /// - /// It is after the Blelloch algorithm to do - /// prefix sum / cumulative sum in parallel. - /// - /// It is not "work efficient" as in, it does not adds - /// that much work compared to other parallel algorithm, - /// thus requiring less threads. - /// - /// However it is slower. 
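On clear values, the carry resolution performed by `resolve_carries_of_groups_using_hillis_steele` reduces to a Hillis & Steele inclusive prefix scan over the group states documented above (0 = no carry, 2 or 1 = generate, 3 = propagate, with the first group already reduced to 0 or 1 by its outer-propagation LUT). A sketch with illustrative names, not part of the patch:

```rust
// Same combination rule as `lut_carry_propagation_sum`: a generating group yields 1,
// a propagating group forwards the state coming from the less significant side
// (remapping a generate encoded as 2 to an output carry of 1).
fn combine(msb: u64, lsb: u64) -> u64 {
    if msb == 2 {
        1
    } else if msb == 3 {
        if lsb == 2 { 1 } else { lsb }
    } else {
        msb
    }
}

// Hillis & Steele scan: log2(n) rounds, each round combining every element with the
// one `space` positions to its left, using the values from the previous round.
fn hillis_steele_scan(mut states: Vec<u64>) -> Vec<u64> {
    let mut space = 1;
    while space < states.len() {
        let snapshot = states.clone();
        for i in space..states.len() {
            states[i] = combine(states[i], snapshot[i - space]);
        }
        space *= 2;
    }
    states
}

fn main() {
    // Groups: [generate, propagate, propagate, nothing]
    let carries = hillis_steele_scan(vec![1, 3, 3, 0]);
    // The carry generated by the first group rides through the two propagating
    // groups and is absorbed by the last one.
    assert_eq!(carries, vec![1, 1, 1, 0]);
}
```

The encrypted version then prepends a trivial zero, since the first group never receives an incoming carry.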
- /// - /// At most num_block / 2 threads are used - /// - /// # Requirements - /// - /// - The parameters have 4 bits in total - /// - Adding rhs to lhs must not consume more than one carry - /// - /// # Output - /// - /// - lhs will have its carries empty - pub(crate) fn unchecked_add_assign_parallelized_work_efficient(&self, lhs: &mut T, rhs: &T) - where - T: IntegerRadixCiphertext, - { - let degree_after_add_does_not_go_beyond_first_carry = lhs - .blocks() - .iter() - .zip(rhs.blocks().iter()) - .all(|(bl, br)| { - let degree_after_add = bl.degree.get() + br.degree.get(); - degree_after_add < (self.key.message_modulus.0 * 2) - }); - assert!(degree_after_add_does_not_go_beyond_first_carry); - debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 3)); - - self.unchecked_add_assign_parallelized(lhs, rhs); - let generates_or_propagates = self.generate_init_carry_array(lhs.blocks()); - let carry_out = - self.compute_carry_propagation_parallelized_work_efficient(generates_or_propagates); - - lhs.blocks_mut() - .par_iter_mut() - .zip(carry_out.par_iter()) - .for_each(|(block, carry_in)| { - self.key.unchecked_add_assign(block, carry_in); - self.key.message_extract_assign(block); - }); - } - - pub(crate) fn compute_carry_propagation_parallelized_work_efficient( - &self, - mut carry_out: Vec, - ) -> Vec { - debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 3)); - - let num_blocks = carry_out.len(); - let num_steps = carry_out.len().ilog2() as usize; - - let lut_carry_propagation_sum = self - .key - .generate_lookup_table_bivariate(prefix_sum_carry_propagation); - - for i in 0..num_steps { - let two_pow_i_plus_1 = 2usize.checked_pow((i + 1) as u32).unwrap(); - let two_pow_i = 2usize.checked_pow(i as u32).unwrap(); - - carry_out - .par_chunks_exact_mut(two_pow_i_plus_1) - .for_each(|carry_out| { - let (last, head) = carry_out.split_last_mut().unwrap(); - let current_block = last; - let previous_block = &head[two_pow_i - 1]; - - self.key.unchecked_apply_lookup_table_bivariate_assign( - current_block, - previous_block, - &lut_carry_propagation_sum, - ); - }); - } - - // Down-Sweep phase - let mut buffer = Vec::with_capacity(num_blocks / 2); - self.key - .create_trivial_assign(&mut carry_out[num_blocks - 1], 0); - for i in (0..num_steps).rev() { - let two_pow_i_plus_1 = 2usize.checked_pow((i + 1) as u32).unwrap(); - let two_pow_i = 2usize.checked_pow(i as u32).unwrap(); - - (0..num_blocks) - .into_par_iter() - .step_by(two_pow_i_plus_1) - .map(|k| { - // Since our carry_propagation LUT ie sum function - // is not commutative we have to reverse operands - self.key.unchecked_apply_lookup_table_bivariate( - &carry_out[k + two_pow_i - 1], - &carry_out[k + two_pow_i_plus_1 - 1], - &lut_carry_propagation_sum, - ) - }) - .collect_into_vec(&mut buffer); - - let mut drainer = buffer.drain(..); - for k in (0..num_blocks).step_by(two_pow_i_plus_1) { - let b = drainer.next().unwrap(); - carry_out.swap(k + two_pow_i - 1, k + two_pow_i_plus_1 - 1); - carry_out[k + two_pow_i_plus_1 - 1] = b; - } - drop(drainer); - assert!(buffer.is_empty()); - } - - // The first step of the Down-Sweep phase sets the - // first block to 0, so no need to re-do it - carry_out - } - - pub(super) fn generate_init_carry_array(&self, sum_blocks: &[Ciphertext]) -> Vec { - let modulus = self.key.message_modulus.0 as u64; - - // This is used for the first pair of blocks - // as this pair can either generate or not, but never propagate - let lut_does_block_generate_carry = 
self.key.generate_lookup_table(|x| { - if x >= modulus { - OutputCarry::Generated as u64 - } else { - OutputCarry::None as u64 - } - }); - - let lut_does_block_generate_or_propagate = self.key.generate_lookup_table(|x| { - if x >= modulus { - OutputCarry::Generated as u64 - } else if x == (modulus - 1) { - OutputCarry::Propagated as u64 - } else { - OutputCarry::None as u64 - } - }); - - let mut generates_or_propagates = Vec::with_capacity(sum_blocks.len()); - sum_blocks - .par_iter() - .enumerate() - .map(|(i, block)| { - if i == 0 { - // The first block can only output a carry - self.key - .apply_lookup_table(block, &lut_does_block_generate_carry) - } else { - self.key - .apply_lookup_table(block, &lut_does_block_generate_or_propagate) - } - }) - .collect_into_vec(&mut generates_or_propagates); +/// This function is meant to be used to create the lookup table that prepares +/// the overflow flag. +// Computing the overflow flag is a bit more complex than the carry flag. +// +// The overflow flag is computed by comparing the input carry onto the last bit +// with the output carry of the last bit. +// +// Since we have blocks that encrypt multiple bits, +// we have to compute and encode what the input carry onto the last +// bit is depending on the input carry onto the last block. +// +// So this function creates a lookup table that, when applied to a block +// packing the last blocks (MSB) of 2 numbers, will result in a block +// where: +// +// - at bit index 2 is stored whether overflow happens if the input block carry is '0' +// - at bit index 3 is stored whether overflow happens if the input block carry is '1' +fn overflow_flag_preparation_lut( + last_lhs_block: u64, + last_rhs_block: u64, + num_bits_in_message: u64, +) -> u64 { + let mask = (1 << (num_bits_in_message - 1)) - 1; + let lhs_except_last_bit = last_lhs_block & mask; + let rhs_except_last_bit = last_rhs_block & mask; + + let overflows_with_given_input_carry = |input_carry| { + let output_carry = + ((last_lhs_block + last_rhs_block + input_carry) >> num_bits_in_message) & 1; + + let input_carry_to_last_bit = ((lhs_except_last_bit + rhs_except_last_bit + input_carry) + >> (num_bits_in_message - 1)) + & 1; + + u64::from(input_carry_to_last_bit != output_carry) + }; - generates_or_propagates - } + (overflows_with_given_input_carry(1) << 3) | (overflows_with_given_input_carry(0) << 2) } #[cfg(test)] mod tests { - use super::should_hillis_steele_propagation_be_faster; + use super::should_parallel_propagation_be_faster; use crate::integer::gen_keys_radix; use crate::shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS; @@ -781,61 +1477,121 @@ mod tests { // Parameters and num blocks do not matter here let (_, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, 4); - let carry = sks.propagate_single_carry_parallelized_low_latency([].as_mut_slice()); - + sks.propagate_single_carry_parallelized(&mut []); // The most interesting part we test is that the code does not panic - assert!(carry.is_trivial()); - assert_eq!(carry.decrypt_trivial().unwrap(), 0u64); } #[test] - fn test_hillis_steele_choice_128_threads() { - // m6i.metal like number of threads - const NUM_THREADS: usize = 128; - // 16, 32, 64, 128, 256 512 bits - for num_blocks in [8, 16, 32, 64, 128, 256] { - assert!( - should_hillis_steele_propagation_be_faster(num_blocks, NUM_THREADS), - "Expected hillis and steele to be chosen for {num_blocks} blocks and {NUM_THREADS} threads" - ); + fn test_propagation_choice_ci_run_filter() { + struct ExpectedChoices { + num_threads: usize,
bit_sizes: Vec<(usize, bool)>, } - // 8 bits - assert!(!should_hillis_steele_propagation_be_faster(4, NUM_THREADS),); - } - #[test] - fn test_hillis_steele_choice_12_threads() { - const NUM_THREADS: usize = 12; - // 8, 16, 32, 64, 128, 256, 512 bits - for num_blocks in [4, 8, 16, 32, 64, 128, 256] { - assert!( - !should_hillis_steele_propagation_be_faster(num_blocks, NUM_THREADS), - "Expected hillis and steele to *not* be chosen for {num_blocks} blocks and {NUM_THREADS} threads" - ); + // These cases have been tested in real conditions by running benchmarks for + // add_parallelized with `RAYON_NUM_THREADS` + let cases = [ + ExpectedChoices { + num_threads: 2, + bit_sizes: vec![ + (2, false), + (4, false), + (8, false), + (16, false), + (32, false), + (64, false), + (128, false), + (256, false), + (512, false), + ], + }, + ExpectedChoices { + num_threads: 4, + bit_sizes: vec![ + (2, false), + (4, false), + (8, true), + (16, true), + (32, true), + (64, true), + (128, true), + (256, false), + (512, false), + ], + }, + ExpectedChoices { + num_threads: 8, + bit_sizes: vec![ + (2, false), + (4, false), + (8, true), + (16, true), + (32, true), + (64, true), + (128, true), + (256, false), + (512, false), + ], + }, + ExpectedChoices { + num_threads: 12, + bit_sizes: vec![ + (2, false), + (4, false), + (8, true), + (16, true), + (32, true), + (64, true), + (128, true), + (256, true), + (512, true), + ], + }, + ExpectedChoices { + num_threads: 128, + bit_sizes: vec![ + (2, false), + (4, false), + (8, true), + (16, true), + (32, true), + (64, true), + (128, true), + (256, true), + (512, true), + ], + }, + ]; + + const FULL_MODULUS: usize = 32; // This is 2_2 parameters + + fn bool_to_algo_name(parallel_chosen: bool) -> &'static str { + if parallel_chosen { + "parallel" + } else { + "sequential" + } } - } - #[test] - fn test_hillis_steele_choice_8_threads() { - const NUM_THREADS: usize = 8; - // 8, 16, 32, 64, 128, 256, 512 bits - for num_blocks in [4, 8, 16, 32, 64, 128, 256] { - assert!( - !should_hillis_steele_propagation_be_faster(num_blocks, NUM_THREADS), - "Expected hillis and steele to *not* be chosen for {num_blocks} blocks and {NUM_THREADS} threads" - ); - } - } - - #[test] - fn test_hillis_steele_choice_4_threads() { - const NUM_THREADS: usize = 4; - // 8, 16, 32, 64, 128, 256, 512 bits - for num_blocks in [4, 8, 16, 32, 64, 128, 256] { - assert!( - !should_hillis_steele_propagation_be_faster(num_blocks, NUM_THREADS), - "Expected hillis and steele to *not* be chosen for {num_blocks} blocks and {NUM_THREADS} threads" - ); + for case in cases { + for (bit_size, expect_parallel) in case.bit_sizes { + let num_blocks = bit_size / 2; + let chose_parallel = should_parallel_propagation_be_faster( + FULL_MODULUS, + num_blocks, + case.num_threads, + ); + assert_eq!( + chose_parallel, + expect_parallel, + "Wrong propagation algorithm chosen for {bit_size} bits ({num_blocks} blocks) and {} threads\n\ + Expected '{}' but '{}' was chosen\ + ", + case.num_threads, + bool_to_algo_name(expect_parallel), + bool_to_algo_name(chose_parallel) + ); + } } } } diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs index d86694d213..dff8f30ea2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -33,6 +33,7 @@ mod vector_find; use super::ServerKey; use crate::integer::ciphertext::IntegerRadixCiphertext; +pub(crate) use add::OutputFlag; use rayon::prelude::*; pub use 
scalar_div_mod::{MiniUnsignedInteger, Reciprocable}; pub use vector_find::MatchValues; @@ -123,33 +124,26 @@ impl ServerKey { ) }; - if self.is_eligible_for_parallel_single_carry_propagation(ctxt) { + if self.is_eligible_for_parallel_single_carry_propagation(ctxt.blocks().len()) { let highest_degree = ctxt.blocks()[start_index..] .iter() .max_by(|block_a, block_b| block_a.degree.get().cmp(&block_b.degree.get())) .map(|block| block.degree.get()) .unwrap(); // We checked for emptiness earlier - if highest_degree <= (self.key.message_modulus.0 - 1) * 2 { - let _ = self.propagate_single_carry_parallelized_low_latency( - &mut ctxt.blocks_mut()[start_index..], - ); - } else { + + if highest_degree >= (self.key.message_modulus.0 - 1) * 2 { // At least one of the blocks has more than one carry, // we need to extract message and carries, then add + propagate let (mut message_blocks, carry_blocks) = extract_message_and_carry_blocks(&ctxt.blocks()[start_index..]); - ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks); - for (block, carry) in ctxt.blocks_mut()[start_index + 1..] - .iter_mut() - .zip(carry_blocks.iter()) - { - self.key.unchecked_add_assign(block, carry); - } - // We can start propagation one index later as we already did the first block - let _ = self.propagate_single_carry_parallelized_low_latency( - &mut ctxt.blocks_mut()[start_index + 1..], - ); + ctxt.blocks_mut()[start_index] = message_blocks.remove(0); + let mut lhs = T::from(message_blocks); + let rhs = T::from(carry_blocks); + self.add_assign_with_carry_parallelized(&mut lhs, &rhs, None); + ctxt.blocks_mut()[start_index + 1..].clone_from_slice(lhs.blocks()); + } else { + self.propagate_single_carry_parallelized(&mut ctxt.blocks_mut()[start_index..]); } } else { let maybe_highest_degree = ctxt diff --git a/tfhe/src/integer/server_key/radix_parallel/neg.rs b/tfhe/src/integer/server_key/radix_parallel/neg.rs index c8d36807ee..c9dea00261 100644 --- a/tfhe/src/integer/server_key/radix_parallel/neg.rs +++ b/tfhe/src/integer/server_key/radix_parallel/neg.rs @@ -88,14 +88,8 @@ impl ServerKey { &tmp_ctxt }; - if self.is_eligible_for_parallel_single_carry_propagation(ct) { - let mut ct = self.unchecked_neg(ct); - let _carry = self.propagate_single_carry_parallelized_low_latency(ct.blocks_mut()); - ct - } else { - let mut ct = self.unchecked_neg(ct); - self.full_propagate_parallelized(&mut ct); - ct - } + let mut ct = self.unchecked_neg(ct); + self.full_propagate_parallelized(&mut ct); + ct } } diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs index c900994b54..7176c01a0c 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs @@ -16,14 +16,28 @@ impl ServerKey { self.full_propagate_parallelized(lhs); } - self.unchecked_scalar_add_assign(lhs, scalar); - let overflowed = self.unsigned_overflowing_propagate_addition_carry(lhs); + let bits_in_message = self.key.message_modulus.0.ilog2(); + let mut scalar_blocks = BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message) + .iter_as::() + .map(|v| self.key.create_trivial(u64::from(v))) + .collect::>(); - let num_scalar_block = - BlockDecomposer::with_early_stop_at_zero(scalar, self.key.message_modulus.0.ilog2()) - .count(); + let trivially_overflowed = match scalar_blocks.len().cmp(&lhs.blocks.len()) { + std::cmp::Ordering::Less => { + scalar_blocks.resize_with(lhs.blocks.len(), || self.key.create_trivial(0)); + 
false + } + std::cmp::Ordering::Equal => false, + std::cmp::Ordering::Greater => { + scalar_blocks.truncate(lhs.blocks.len()); + true + } + }; + + let rhs = RadixCiphertext::from(scalar_blocks); + let overflowed = self.overflowing_add_assign_with_carry(lhs, &rhs, None); - if num_scalar_block > lhs.blocks.len() { + if trivially_overflowed { // Scalar has more blocks so addition counts as overflowing BooleanBlock::new_unchecked(self.key.create_trivial(1)) } else { @@ -255,9 +269,9 @@ impl ServerKey { self.full_propagate_parallelized(ct); }; - if self.is_eligible_for_parallel_single_carry_propagation(ct) { + if self.is_eligible_for_parallel_single_carry_propagation(ct.blocks().len()) { self.unchecked_scalar_add_assign(ct, scalar); - let _carry = self.propagate_single_carry_parallelized_low_latency(ct.blocks_mut()); + self.propagate_single_carry_parallelized(ct.blocks_mut()) } else { self.unchecked_scalar_add_assign(ct, scalar); self.full_propagate_parallelized(ct); diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs index a0deda93ca..04d8db15e0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs @@ -105,13 +105,18 @@ impl ServerKey { self.full_propagate_parallelized(ct); }; - self.unchecked_scalar_sub_assign(ct, scalar); + let Some(decomposer) = self.create_negated_block_decomposer(scalar) else { + // subtraction by zero + return; + }; - if self.is_eligible_for_parallel_single_carry_propagation(ct) { - let _carry = self.propagate_single_carry_parallelized_low_latency(ct.blocks_mut()); - } else { - self.full_propagate_parallelized(ct); - } + let blocks = decomposer + .take(ct.blocks().len()) + .map(|v| self.key.create_trivial(u64::from(v))) + .collect::>(); + let rhs = T::from_blocks(blocks); + + self.add_assign_with_carry_parallelized(ct, &rhs, None); } pub fn unsigned_overflowing_scalar_sub_assign_parallelized( diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs index 83809612d2..63362124b9 100644 --- a/tfhe/src/integer/server_key/radix_parallel/sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs @@ -1,8 +1,5 @@ -use super::add::OutputCarry; use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::integer::{ - BooleanBlock, IntegerCiphertext, RadixCiphertext, ServerKey, SignedRadixCiphertext, -}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext}; use crate::shortint::ciphertext::Degree; use crate::shortint::Ciphertext; use rayon::prelude::*; @@ -20,13 +17,6 @@ enum BorrowGeneration { Propagated = 2, } -// see [ServerKey::generate_last_block_inner_propagation] -#[derive(Copy, Clone, PartialEq, Eq)] -pub(crate) enum SignedOperation { - Addition, - Subtraction, -} - impl ServerKey { /// Computes homomorphically the subtraction between ct_left and ct_right. 
/// @@ -235,56 +225,8 @@ impl ServerKey { } }; - if self.is_eligible_for_parallel_single_carry_propagation(lhs) { - let neg = self.unchecked_neg(rhs); - let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, &neg); - } else { - self.unchecked_sub_assign(lhs, rhs); - self.full_propagate_parallelized(lhs); - } - } - - pub fn sub_parallelized_work_efficient(&self, ctxt_left: &T, ctxt_right: &T) -> T - where - T: IntegerRadixCiphertext, - { - let mut ct_res = ctxt_left.clone(); - self.sub_assign_parallelized_work_efficient(&mut ct_res, ctxt_right); - ct_res - } - - pub fn sub_assign_parallelized_work_efficient(&self, ctxt_left: &mut T, ctxt_right: &T) - where - T: IntegerRadixCiphertext, - { - let mut tmp_rhs; - - let (lhs, rhs) = match ( - ctxt_left.block_carries_are_empty(), - ctxt_right.block_carries_are_empty(), - ) { - (true, true) => (ctxt_left, ctxt_right), - (true, false) => { - tmp_rhs = ctxt_right.clone(); - self.full_propagate_parallelized(&mut tmp_rhs); - (ctxt_left, &tmp_rhs) - } - (false, true) => { - self.full_propagate_parallelized(ctxt_left); - (ctxt_left, ctxt_right) - } - (false, false) => { - tmp_rhs = ctxt_right.clone(); - rayon::join( - || self.full_propagate_parallelized(ctxt_left), - || self.full_propagate_parallelized(&mut tmp_rhs), - ); - (ctxt_left, &tmp_rhs) - } - }; - let neg = self.unchecked_neg(rhs); - self.unchecked_add_assign_parallelized_work_efficient(lhs, &neg); + self.add_assign_with_carry_parallelized(lhs, &neg, None); } /// Computes the subtraction and returns an indicator of overflow @@ -369,7 +311,7 @@ impl ServerKey { ); // Here we have to use manual unchecked_sub on shortint blocks // rather than calling integer's unchecked_sub as we need each subtraction - // to be independent from other blocks. And we don't want to do subtraction by + // to be independent of other blocks. And we don't want to do subtraction by // adding negation let ct = lhs .blocks @@ -396,7 +338,7 @@ impl ServerKey { &self, ct: &mut RadixCiphertext, ) -> BooleanBlock { - if self.is_eligible_for_parallel_single_carry_propagation(ct) { + if self.is_eligible_for_parallel_single_carry_propagation(ct.blocks.len()) { let generates_or_propagates = self.generate_init_borrow_array(ct); let (input_borrows, mut output_borrow) = self.compute_borrow_propagation_parallelized_low_latency(generates_or_propagates); @@ -454,217 +396,6 @@ impl ServerKey { } } - // This is used in signed overflow detection - // see [unchecked_signed_overflowing_sub_parallelized] for more context - // - // This is to share the logic between the fully parallelized and - // semi parallelized algorithms. - // - // - last_lhs_block: last block of the lhs used in signed subtraction - // - last_rhs_block: last block the rhs used in signed subtraction - // - // Returns a block to be used as one of the inputs of [resolve_signed_overflow] - pub(crate) fn generate_last_block_inner_propagation( - &self, - last_lhs_block: &Ciphertext, - last_rhs_block: &Ciphertext, - op: SignedOperation, - ) -> Ciphertext { - let bits_of_message = self.key.message_modulus.0.ilog2(); - let message_bit_mask = (1 << bits_of_message) - 1; - - // This lut will generate a block that contains the information - // of how carry propagation happens in the last block, until the last bit. 
- let last_block_inner_propagation_lut = - self.key - .generate_lookup_table_bivariate(|lhs_block, rhs_block| { - let rhs_block = if op == SignedOperation::Subtraction { - // subtraction is done by doing addition of negation - // negation(x) = bit_flip(x) + 1 - // We only add the flipped value, the + 1 will be resolved by - // carry propagation computation - let flipped_rhs = !rhs_block; - - // We remove the last bit, its not interesting in this step - (flipped_rhs << 1) & message_bit_mask - } else { - (rhs_block << 1) & message_bit_mask - }; - - let lhs_block = (lhs_block << 1) & message_bit_mask; - - // whole_result contains the result of addition with - // the carry being in the first bit of carry space - // the message space contains the message, but with one 0 - // on the right (lsb) - let whole_result = lhs_block + rhs_block; - let carry = whole_result >> bits_of_message; - let result = (whole_result & message_bit_mask) >> 1; - let propagation_result = if carry == 1 { - // Addition of bits before last one generates a carry - OutputCarry::Generated - } else if result == ((self.key.message_modulus.0 as u64 - 1) >> 1) { - // Addition of bits before last one puts the bits - // in a state that makes it so that an input carry into last block - // gets propagated to last bit. - OutputCarry::Propagated - } else { - OutputCarry::None - }; - - // Shift the propagation result in carry part - // to have less noise growth later - (propagation_result as u64) << bits_of_message - }); - self.key.unchecked_apply_lookup_table_bivariate( - last_lhs_block, - last_rhs_block, - &last_block_inner_propagation_lut, - ) - } - - // - last_block_inner_propagation must be the result of generate_last_block_inner_propagation - // - last_block_input_carry: carry that the last pair of blocks (lhs, rhs) receives as input - // - last_block_output_carry: carry that the last pair of blocks (lhs, rhs) output - // - // Returns whether the subtraction overflowed - // - // See [unchecked_signed_overflowing_sub_parallelized] for more context - pub(crate) fn resolve_signed_overflow( - &self, - mut last_block_inner_propagation: Ciphertext, - last_block_input_carry: &BooleanBlock, - last_block_output_carry: &BooleanBlock, - ) -> BooleanBlock { - let bits_of_message = self.key.message_modulus.0.ilog2(); - - let resolve_overflow_lut = self.key.generate_lookup_table(|x| { - let carry_propagation = x >> bits_of_message; - let output_carry_of_block = (x >> 1) & 1; - let input_carry_of_block = x & 1; - - // Resolve the carry that the last bit actually receives as input - let input_carry_to_last_bit = if carry_propagation == OutputCarry::Propagated as u64 { - input_carry_of_block - } else if carry_propagation == OutputCarry::Generated as u64 { - 1 - } else { - 0 - }; - - u64::from(input_carry_to_last_bit != output_carry_of_block) - }); - - let x = self - .key - .unchecked_scalar_mul(last_block_output_carry.as_ref(), 2); - self.key - .unchecked_add_assign(&mut last_block_inner_propagation, &x); - self.key.unchecked_add_assign( - &mut last_block_inner_propagation, - last_block_input_carry.as_ref(), - ); - let result = self - .key - .apply_lookup_table(&last_block_inner_propagation, &resolve_overflow_lut); - BooleanBlock::new_unchecked(result) - } - - // This is the implementation of overflowing add/sub when we can use parallel carry - // propagation, as only a few things change between the two. 
- pub(crate) fn unchecked_signed_overflowing_add_or_sub_parallelized_impl( - &self, - lhs: &SignedRadixCiphertext, - rhs: &SignedRadixCiphertext, - signed_operation: SignedOperation, - ) -> (SignedRadixCiphertext, BooleanBlock) { - // This assert is here because this overflow computation requires these preconditions - // which is_eligible_for_parallel_single_carry_propagation, but it could change in the - // future - assert!(self.key.message_modulus.0 >= 4 && self.key.carry_modulus.0 >= 4); - - // In Two's complement arithmetic, overflow occurs when the output carry of the - // last bit is not the same as the input carry of the last bit. - // - // Here we have blocks, and we cannot just compare input and output carries of the last - // block as its not equivalent to checking what happens on the last bit. - // So we have to resolve that carry propagation that happens in the last block. - // - // So the carry propagation is done in 2 steps, first we compute the carry propagation - // in the last block to be able at the second step, to know the actual carry that - // the last bit receives. - // - // These are done in parallel to other stuff, and so no additional 'latency cost' - // should occur. - - let mut result = lhs.clone(); - - // Using parallel algorithms for unchecked_add/sub does not seem to bring - // measurable improvements - if signed_operation == SignedOperation::Subtraction { - self.unchecked_sub_assign(&mut result, rhs); - } else { - self.unchecked_add_assign(&mut result, rhs); - } - - let ((input_carries, output_carry), last_block_inner_propagation) = rayon::join( - || { - let generates_or_propagates = self.generate_init_carry_array(result.blocks()); - self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates) - }, - || { - self.generate_last_block_inner_propagation( - lhs.blocks.last().as_ref().unwrap(), - rhs.blocks.last().as_ref().unwrap(), - signed_operation, - ) - }, - ); - - let (_, overflowed) = rayon::join( - || { - result - .blocks - .par_iter_mut() - .zip(input_carries.par_iter()) - .for_each(|(block, input_carry)| { - self.key.unchecked_add_assign(block, input_carry); - self.key.message_extract_assign(block); - }); - }, - || { - let input_carry = input_carries - .last() - .cloned() - .map(BooleanBlock::new_unchecked) - .unwrap(); - let output_carry = BooleanBlock::new_unchecked(output_carry); - self.resolve_signed_overflow( - last_block_inner_propagation, - &input_carry, - &output_carry, - ) - }, - ); - - (result, overflowed) - } - - // It is in its own function so that it can be tested, as the main entry point - // unchecked_signed_overflowing_sub may select non parallel version if lhs - // does not have enough block. - pub(crate) fn unchecked_signed_overflowing_sub_parallelized_impl( - &self, - lhs: &SignedRadixCiphertext, - rhs: &SignedRadixCiphertext, - ) -> (SignedRadixCiphertext, BooleanBlock) { - self.unchecked_signed_overflowing_add_or_sub_parallelized_impl( - lhs, - rhs, - SignedOperation::Subtraction, - ) - } - pub fn unchecked_signed_overflowing_sub_parallelized( &self, lhs: &SignedRadixCiphertext, @@ -679,11 +410,20 @@ impl ServerKey { rhs.blocks.len() ); - if self.is_eligible_for_parallel_single_carry_propagation(lhs) { - self.unchecked_signed_overflowing_sub_parallelized_impl(lhs, rhs) - } else { - self.unchecked_signed_overflowing_sub(lhs, rhs) - } + // We are using two's complement for signed numbers, + // we do the subtraction by adding the negation of rhs. 
+ // But to be able to get the correct overflow flag, we need to + // comute (result, overflow) = (lhs + bitnot(rhs) + 1) instead of + // (result, overflow) = (lhs + (-rhs). We need the bitnot(rhs) and +1 + // 'separated' + // + // Remainder: in two's complement -rhs = bitnot(rhs) + 1 + let flipped_rhs = self.bitnot(rhs); + let input_carry = self.create_trivial_boolean_block(true); + let mut result = lhs.clone(); + let overflowed = + self.overflowing_add_assign_with_carry(&mut result, &flipped_rhs, Some(&input_carry)); + (result, overflowed) } pub(super) fn generate_init_borrow_array(&self, sum_ct: &RadixCiphertext) -> Vec { diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index 3d31198a09..9847cb44aa 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -1881,6 +1881,9 @@ where clear = (clear_0 + clear_1) % modulus; + let dec_res: u64 = cks.decrypt(&ct_res); + assert_eq!(clear, dec_res); + // Add multiple times to raise the degree for _ in 0..nb_tests_smaller { let tmp = executor.execute((&ct_res, clear_1)); @@ -2346,7 +2349,11 @@ where let ct_res = executor.execute((&ct, scalar)); let dec_res: u128 = cks.decrypt(&ct_res); - assert_eq!(clear.wrapping_mul(scalar as u128), dec_res); + assert_eq!( + clear.wrapping_mul(scalar as u128), + dec_res, + "Invalid result {clear} * {scalar}" + ); } pub(crate) fn default_scalar_bitand_test(param: P, mut executor: T) diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs index 341189592c..074c4971b6 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs @@ -813,9 +813,12 @@ fn integer_signed_default_scalar_div_rem(param: impl Into) { // Make the degree non-fresh let offset = random_non_zero_value(&mut rng, modulus); + println!("offset: {offset}"); sks.unchecked_scalar_add_assign(&mut ctxt_0, offset); clear_lhs = signed_add_under_modulus(clear_lhs, offset, modulus); assert!(!ctxt_0.block_carries_are_empty()); + let sanity_decryption: i64 = cks.decrypt_signed_radix(&ctxt_0); + assert_eq!(sanity_decryption, clear_lhs); let (q_res, r_res) = sks.signed_scalar_div_rem_parallelized(&ctxt_0, clear_rhs); let q: i64 = cks.decrypt_signed_radix(&q_res); diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs index 7de7ac574b..d44e9b1433 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs @@ -20,25 +20,6 @@ use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_sub); create_parametrized_test!(integer_signed_unchecked_overflowing_sub); -create_parametrized_test!( - integer_signed_unchecked_overflowing_sub_parallelized { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // Requires 4 bits, so 1_1 parameters are not supported - // until they get their own version of the algorithm - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - 
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS, - } - } -); create_parametrized_test!(integer_signed_default_sub); create_parametrized_test!(integer_signed_default_overflowing_sub); @@ -58,17 +39,6 @@ where signed_unchecked_overflowing_sub_test(param, executor); } -fn integer_signed_unchecked_overflowing_sub_parallelized
<P>
(param: P) -where - P: Into<PBSParameters>, -{ - // Call _impl so we are sure the parallel version is tested - // However this only supports param X_X where X >= 4 - let executor = - CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_sub_parallelized_impl); - signed_unchecked_overflowing_sub_test(param, executor); -} - fn integer_signed_default_sub
<P>
(param: P) where P: Into, diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index 3dd65b7855..ed23a5580d 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -255,18 +255,19 @@ where let cks = cks.as_ref(); let max_degree_acceptable = cks.key.parameters.message_modulus().0 - 1; + let num_blocks = ct.blocks.len(); for (i, block) in ct.blocks.iter().enumerate() { assert_eq!( block.noise_level, NoiseLevel::NOMINAL, - "Block at index {i} has a non nominal noise level: {:?}", + "Block at index {i} / {num_blocks} has a non nominal noise level: {:?}", block.noise_level ); assert!( block.degree.get() <= max_degree_acceptable, - "Block at index {i} has a degree {:?} that exceeds the maximum ({}) for a clean block", + "Block at index {i} / {num_blocks} has a degree {:?} that exceeds the maximum ({}) for a clean block", block.degree, max_degree_acceptable ); diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs index 3f048b7844..f4b39e2494 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs @@ -6,6 +6,7 @@ use super::{ }; use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::OutputFlag; use crate::integer::tests::create_parametrized_test; use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; #[cfg(tarpaulin)] @@ -14,29 +15,29 @@ use crate::shortint::parameters::*; use rand::Rng; use std::sync::Arc; +create_parametrized_test!(integer_unchecked_add); +create_parametrized_test!(integer_unchecked_add_assign); create_parametrized_test!(integer_smart_add); create_parametrized_test!(integer_default_add); create_parametrized_test!(integer_default_overflowing_add); -create_parametrized_test!(integer_unchecked_add); -create_parametrized_test!(integer_unchecked_add_assign); -create_parametrized_test!( - integer_default_add_work_efficient { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // This algorithm requires 3 bits - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS, - } +create_parametrized_test!(integer_advanced_add_assign_with_carry_at_least_4_bits { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS + }, + no_coverage => { + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + PARAM_MESSAGE_3_CARRY_3_KS_PBS, + PARAM_MESSAGE_4_CARRY_4_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS } -); +}); +create_parametrized_test!(integer_advanced_add_assign_with_carry_sequential); + +const MAX_NB_CTXT: usize = 8; fn integer_unchecked_add
<P>
(param: P) where @@ -70,11 +71,41 @@ where default_add_test(param, executor); } -fn integer_default_add_work_efficient
<P>
(param: P) +fn integer_advanced_add_assign_with_carry_at_least_4_bits
<P>
(param: P) +where + P: Into<PBSParameters>, +{ + // We explicitly call the 4 bit function to make sure it's being tested, + // no matter the number of blocks / threads available + let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| { + let mut result = lhs.clone(); + sks.advanced_add_assign_with_carry_at_least_4_bits( + &mut result.blocks, + &rhs.blocks, + None, + OutputFlag::None, + ); + result + }; + let executor = CpuFunctionExecutor::new(&func); + default_add_test(param, executor); +} + +fn integer_advanced_add_assign_with_carry_sequential
<P>
(param: P) where P: Into, { - let executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized_work_efficient); + let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| { + let mut result = lhs.clone(); + sks.advanced_add_assign_with_carry_sequential_parallelized( + &mut result.blocks, + &rhs.blocks, + None, + OutputFlag::None, + ); + result + }; + let executor = CpuFunctionExecutor::new(&func); default_add_test(param, executor); } @@ -301,18 +332,18 @@ where let mut rng = rand::thread_rng(); - let modulus = unsigned_modulus(cks.parameters().message_modulus(), NB_CTXT as u32); - executor.setup(&cks, sks); let mut clear; - for _ in 0..nb_tests_smaller { + for num_blocks in 1..MAX_NB_CTXT { + let modulus = unsigned_modulus(cks.parameters().message_modulus(), num_blocks as u32); + let clear_0 = rng.gen::() % modulus; let clear_1 = rng.gen::() % modulus; - let ctxt_0 = cks.encrypt(clear_0); - let ctxt_1 = cks.encrypt(clear_1); + let ctxt_0 = cks.as_ref().encrypt_radix(clear_0, num_blocks); + let ctxt_1 = cks.as_ref().encrypt_radix(clear_1, num_blocks); let mut ct_res = executor.execute((&ctxt_0, &ctxt_1)); let tmp_ct = executor.execute((&ctxt_0, &ctxt_1)); @@ -322,16 +353,25 @@ where clear = clear_0.wrapping_add(clear_1) % modulus; let dec_res: u64 = cks.decrypt(&ct_res); - assert_eq!(clear, dec_res); + assert_eq!( + clear, dec_res, + "Invalid result for {clear_0} + {clear_1}, expected: {clear}, got: {dec_res}\n\ + num_blocks={num_blocks}, modulus={modulus}" + ); for _ in 0..nb_tests_smaller { ct_res = executor.execute((&ct_res, &ctxt_0)); panic_if_any_block_is_not_clean(&ct_res, &cks); - clear = (clear + clear_0) % modulus; + let result = (clear + clear_0) % modulus; let dec_res: u64 = cks.decrypt(&ct_res); - assert_eq!(clear, dec_res); + assert_eq!( + result, dec_res, + "Invalid result for {clear} + {clear_0}, expected: {result}, got: {dec_res}\n\ + num_blocks={num_blocks}, modulus={modulus}" + ); + clear = result; } } } diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs index 1a0f568a7e..0034d68a6a 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs @@ -18,24 +18,6 @@ use std::sync::Arc; create_parametrized_test!(integer_unchecked_sub); create_parametrized_test!(integer_smart_sub); create_parametrized_test!(integer_default_sub); -create_parametrized_test!( - integer_default_sub_work_efficient { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // This algorithm requires 3 bits - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS, - } - } -); create_parametrized_test!(integer_default_overflowing_sub); fn integer_unchecked_sub
<P>
(param: P) @@ -62,14 +44,6 @@ where default_sub_test(param, executor); } -fn integer_default_sub_work_efficient
<P>
(param: P) -where - P: Into<PBSParameters>, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized_work_efficient); - default_sub_test(param, executor); -} - fn integer_default_overflowing_sub
<P>
(param: P) where P: Into<PBSParameters>,