diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs index 82ff9bf277..ddeb50c968 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -125,6 +125,37 @@ where native_closest_representable(input, self.level_count, self.base_log) } + #[inline(always)] + pub fn init_decomposer_state(&self, input: Scalar) -> Scalar { + // The closest number representable by the decomposition can be computed by performing + // the rounding at the appropriate bit. + + // We compute the number of least significant bits which can not be represented by the + // decomposition + // Example with level_count = 3, base_log = 4 and BITS == 64 -> 52 + let rep_bit_count = self.level_count * self.base_log; + let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count; + // Move the representable bits + 1 to the LSB, with our example : + // |-----| 64 - (64 - 12 - 1) == 13 bits + // 0....0XX...XX + let mut res = input >> (non_rep_bit_count - 1); + // Fetch the first bit value as we need it for a balanced rounding + let rounding_bit = (res & Scalar::ONE) << (rep_bit_count - 1); + // Add one to do the rounding by adding the half interval + res += Scalar::ONE; + // Discard the LSB which was the one deciding in which direction we round + res >>= 1; + // Keep the low base_log * level bits + let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count); + res &= mod_mask; + // Control bit about whether we should balance the state + // This is equivalent to res >= 2^(base_log * l) || (res == 2^(base_log * l) && random == 1) + let need_balance = + (res.wrapping_sub(Scalar::ONE) | (res & rounding_bit)) >> (rep_bit_count - 1); + // Balance depending on the control bit + res.wrapping_sub(need_balance << rep_bit_count) + } + /// Generate an iterator over the terms of the decomposition of the input. /// /// # Warning @@ -161,7 +192,7 @@ where // Note that there would be no sense of making the decomposition on an input which was // not rounded to the closest representable first. We then perform it before decomposing. SignedDecompositionIter::new( - self.closest_representable(input), + self.init_decomposer_state(input), DecompositionBaseLog(self.base_log), DecompositionLevelCount(self.level_count), ) diff --git a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs index 3d91888976..3d8a3d0211 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs @@ -43,7 +43,7 @@ where Self { base_log: base_log.0, level_count: level.0, - state: input >> (T::BITS - base_log.0 * level.0), + state: input, current_level: level.0, mod_b_mask: (T::ONE << base_log.0) - T::ONE, fresh: true, @@ -298,7 +298,7 @@ pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> { impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> { #[inline] - pub fn new( + pub(crate) fn new( decomposer: &SignedDecomposerNonNative, input: &[u64], modulus: u64, diff --git a/tfhe/src/core_crypto/commons/math/decomposition/tests.rs b/tfhe/src/core_crypto/commons/math/decomposition/tests.rs index 1f09dcc558..e38d24b935 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/tests.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/tests.rs @@ -200,3 +200,33 @@ fn test_decompose_recompose_non_native_solinas_u64() { fn test_decompose_recompose_non_native_edge_mod_round_up_u64() { test_decompose_recompose_non_native::(CiphertextModulus::try_new((1 << 48) + 1).unwrap()); } + +#[test] +fn test_single_level_decompose_balanced() { + let decomposer = SignedDecomposer::new(DecompositionBaseLog(12), DecompositionLevelCount(1)); + + assert_eq!( + decomposer.level_count().0, + 1, + "This test is only valid if the decomposition level count is 1" + ); + use rand::prelude::*; + let mut rng = rand::thread_rng(); + let mut mean = 0f64; + // Still runs fast, about 1 billion runs which is exactly representable in float + let runs = 1usize << 30; + for _ in 0..runs { + let val: u64 = rng.gen(); + let decomp = decomposer.decompose(val).next().unwrap(); + let value: i64 = decomp.value() as i64; + mean += value as f64; + } + mean /= runs as f64; + + // To print with --nocapture to check in the terminal + println!("mean={mean}"); + + // This bound is not very tight or good, but as an unbalanced decomposition has a mean of about + // 0.5 this will do + assert!(mean.abs() < 0.1); +} diff --git a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs index 10d9ea52a9..05229ae7a1 100644 --- a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs +++ b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs @@ -207,7 +207,7 @@ pub fn glwe_fast_keyswitch( let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() - .map(|s| decomposer.closest_representable(*s)), + .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), substack0.rb_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs index 1e8dc852ea..c221bdedee 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs @@ -418,7 +418,7 @@ pub fn add_external_product_assign( let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() - .map(|s| decomposer.closest_representable(*s)), + .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), substack0.rb_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs index a23bbd6661..6b8ec2961f 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs @@ -517,7 +517,7 @@ pub fn add_external_product_assign( let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() - .map(|s| decomposer.closest_representable(*s)), + .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), substack0.rb_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs index efc5212792..5eabd8f344 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs @@ -31,9 +31,7 @@ impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'bu level: DecompositionLevelCount, stack: PodStack<'buffers>, ) -> (Self, PodStack<'buffers>) { - let shift = Scalar::BITS - base_log.0 * level.0; - let (states, stack) = - stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input.map(|i| i >> shift)); + let (states, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input); ( TensorSignedDecompositionLendingIter { base_log: base_log.0,