From 72ad76b5e7918ac3abe4ce3706972e96edcb6e14 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Thu, 12 Sep 2024 12:52:08 +0200 Subject: [PATCH] fix(integer): do sum by safe chunk sizes Parameters are chosen with assumptions on the number of leveled add/sub/scalar_mul operations that are performed, so that the noise level before doing a PBS is at a correct level and everything is safe, secure and correct. The lib implementation therefore has to uphold these assumptions in order to keep the probability of failure correct. In the comparisons, at some point we had a vector of ciphertexts with a degree == 1, so we greedily summed them (e.g. with 2_2 params we summed them by chunks of 15). While this is correct with regard to the carry and message space, it is however less correct with regard to the noise level. Noise-wise, doing this huge sum is correct as long as the noise of each ciphertext is independent from the others in the same chunk. While this is generally the case we are in, it is not guaranteed, and since we do not track that information we have to take the safer approach of assuming the worst case: all noises are dependent. So to fix the issue we compute the correct sum chunk size by also taking into account the max noise level. 
--- .../tfhe-cuda-backend/cuda/include/integer.h | 2 +- .../cuda/src/integer/comparison.cuh | 4 +-- tfhe/src/integer/server_key/mod.rs | 18 +++++++++- .../integer/server_key/radix/comparison.rs | 25 ++++++------- .../server_key/radix_parallel/comparison.rs | 34 +++--------------- .../radix_parallel/scalar_comparison.rs | 36 ++++++++----------- 6 files changed, 49 insertions(+), 70 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index 3e35cc8e73..872ec7810b 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -2095,7 +2095,7 @@ template struct int_are_all_block_true_buffer { if (allocate_gpu_memory) { Torus total_modulus = params.message_modulus * params.carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (params.message_modulus - 1); int max_chunks = (num_radix_blocks + max_value - 1) / max_value; tmp_block_accumulated = (Torus *)cuda_malloc_async( diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh index ed2a3bbef5..3b288f2283 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh @@ -74,7 +74,7 @@ __host__ void are_all_comparisons_block_true( auto tmp_out = are_all_block_true_buffer->tmp_out; uint32_t total_modulus = message_modulus * carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = (total_modulus - 1) / (message_modulus - 1); cuda_memcpy_async_gpu_to_gpu(tmp_out, lwe_array_in, num_radix_blocks * (big_lwe_dimension + 1) * @@ -173,7 +173,7 @@ __host__ void is_at_least_one_comparisons_block_true( auto buffer = mem_ptr->eq_buffer->are_all_block_true_buffer; uint32_t total_modulus = message_modulus * carry_modulus; - uint32_t max_value = total_modulus - 1; + uint32_t max_value = 
(total_modulus - 1) / (message_modulus - 1); cuda_memcpy_async_gpu_to_gpu(mem_ptr->tmp_lwe_array_out, lwe_array_in, num_radix_blocks * (big_lwe_dimension + 1) * diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs index cdc3fd4a36..dd382d9e33 100644 --- a/tfhe/src/integer/server_key/mod.rs +++ b/tfhe/src/integer/server_key/mod.rs @@ -9,7 +9,7 @@ pub(crate) mod radix; pub(crate) mod radix_parallel; use crate::integer::client_key::ClientKey; -use crate::shortint::ciphertext::MaxDegree; +use crate::shortint::ciphertext::{Degree, MaxDegree}; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; @@ -227,6 +227,22 @@ impl ServerKey { num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize) } + + /// Returns how many ciphertext can be summed at once + /// + /// The number of ciphertext that can be added together depends on the degree + /// (in order not to go beyond the carry space and keep results correct) but also + /// on the noise level (in order to have the correct error probability and so correctness and + /// security) + /// + /// - `degree` is expected degree of all elements to be summed + pub(crate) fn max_sum_size(&self, degree: Degree) -> usize { + let max_degree = + MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus()); + let max_sum_to_full_carry = max_degree.get() / degree.get(); + + max_sum_to_full_carry.min(self.key.max_noise_level.get()) + } } impl AsRef for ServerKey { diff --git a/tfhe/src/integer/server_key/radix/comparison.rs b/tfhe/src/integer/server_key/radix/comparison.rs index bf570567d1..a6e82b4e6c 100644 --- a/tfhe/src/integer/server_key/radix/comparison.rs +++ b/tfhe/src/integer/server_key/radix/comparison.rs @@ -2,6 +2,7 @@ use super::ServerKey; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::comparator::Comparator; +use 
crate::shortint::ciphertext::Degree; impl ServerKey { /// Compares for equality 2 ciphertexts @@ -53,30 +54,27 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_max_value = self .key - .generate_lookup_table(|x| u64::from((x & max_value as u64) == max_value as u64)); + .generate_lookup_table(|x| u64::from(x == max_sum_size as u64)); while block_comparisons.len() > 1 { block_comparisons = block_comparisons - .chunks(max_value) + .chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - if blocks.len() == max_value { + if blocks.len() == max_sum_size { self.key.apply_lookup_table(&sum, &is_max_value) } else { - let is_equal_to_num_blocks = self.key.generate_lookup_table(|x| { - u64::from((x & max_value as u64) == blocks.len() as u64) - }); + let is_equal_to_num_blocks = self + .key + .generate_lookup_table(|x| u64::from(x == blocks.len() as u64)); self.key.apply_lookup_table(&sum, &is_equal_to_num_blocks) } }) @@ -112,15 +110,12 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); while block_comparisons.len() > 1 { block_comparisons = block_comparisons - .chunks(max_value) + .chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] 
{ diff --git a/tfhe/src/integer/server_key/radix_parallel/comparison.rs b/tfhe/src/integer/server_key/radix_parallel/comparison.rs index 2f45fcd592..b23ce60562 100644 --- a/tfhe/src/integer/server_key/radix_parallel/comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/comparison.rs @@ -51,7 +51,7 @@ impl ServerKey { { // Even though the corresponding function // may already exist in self.key - // we generate our own lut to do less allocations + // we generate our own lut to do fewer allocations // one for all the threads as opposed to one per thread let lut = self .key @@ -76,7 +76,7 @@ impl ServerKey { { // Even though the corresponding function // may already exist in self.key - // we generate our own lut to do less allocations + // we generate our own lut to do fewer allocations // one for all the threads as opposed to one per thread let lut = self .key @@ -90,34 +90,8 @@ impl ServerKey { .unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut); }); - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - - let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2); - let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); - - while block_comparisons.len() > 1 { - block_comparisons - .par_chunks(max_value) - .map(|blocks| { - let mut sum = blocks[0].clone(); - for other_block in &blocks[1..] 
{ - self.key.unchecked_add_assign(&mut sum, other_block); - } - self.key.apply_lookup_table(&sum, &is_non_zero) - }) - .collect_into_vec(&mut block_comparisons_2); - std::mem::swap(&mut block_comparisons_2, &mut block_comparisons); - } - - BooleanBlock::new_unchecked( - block_comparisons - .into_iter() - .next() - .unwrap_or_else(|| self.key.create_trivial(0)), - ) + let result = self.is_at_least_one_comparisons_block_true(block_comparisons); + BooleanBlock::new_unchecked(result) } /// This implements all comparisons (<, <=, >, >=) for both signed and unsigned diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs index 34cd2ff05a..4d712f53f2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs @@ -3,6 +3,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::ciphertext::boolean_value::BooleanBlock; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::comparator::{Comparator, ZeroComparisonType}; +use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::LookupTableOwned; use crate::shortint::Ciphertext; use rayon::prelude::*; @@ -160,27 +161,23 @@ impl ServerKey { return self.key.create_trivial(1); } - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - + let max_sum_size = self.max_sum_size(Degree::new(1)); let is_max_value = self .key - .generate_lookup_table(|x| u64::from(x == max_value as u64)); + .generate_lookup_table(|x| u64::from(x == max_sum_size as u64)); while block_comparisons.len() > 1 { // Since all blocks encrypt either 0 or 1, we can sum max_value of them // as in the worst case we will be adding `max_value` ones block_comparisons = 
block_comparisons - .par_chunks(max_value) + .par_chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] { self.key.unchecked_add_assign(&mut sum, other_block); } - if blocks.len() == max_value { + if blocks.len() == max_sum_size { self.key.apply_lookup_table(&sum, &is_max_value) } else { let is_equal_to_num_blocks = self @@ -213,25 +210,22 @@ impl ServerKey { return self.key.create_trivial(1); } - let message_modulus = self.key.message_modulus.0; - let carry_modulus = self.key.carry_modulus.0; - let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; - let is_not_zero = self.key.generate_lookup_table(|x| u64::from(x != 0)); + let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2); + let max_sum_size = self.max_sum_size(Degree::new(1)); while block_comparisons.len() > 1 { - block_comparisons = block_comparisons - .par_chunks(max_value) + block_comparisons + .par_chunks(max_sum_size) .map(|blocks| { let mut sum = blocks[0].clone(); for other_block in &blocks[1..] 
{ self.key.unchecked_add_assign(&mut sum, other_block); } - self.key.apply_lookup_table(&sum, &is_not_zero) }) - .collect::>(); + .collect_into_vec(&mut block_comparisons_2); + std::mem::swap(&mut block_comparisons_2, &mut block_comparisons); } block_comparisons @@ -423,10 +417,10 @@ impl ServerKey { let message_modulus = self.key.message_modulus.0; let carry_modulus = self.key.carry_modulus.0; let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); assert!(carry_modulus >= message_modulus); - u8::try_from(max_value).unwrap(); + u8::try_from(max_sum_size).unwrap(); let num_blocks = lhs.blocks().len(); let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2); @@ -516,10 +510,10 @@ impl ServerKey { let message_modulus = self.key.message_modulus.0; let carry_modulus = self.key.carry_modulus.0; let total_modulus = message_modulus * carry_modulus; - let max_value = total_modulus - 1; + let max_sum_size = self.max_sum_size(Degree::new(1)); assert!(carry_modulus >= message_modulus); - u8::try_from(max_value).unwrap(); + u8::try_from(max_sum_size).unwrap(); let num_blocks = lhs.blocks().len(); let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2);