diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs index 82ff9bf277..05baf35518 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -46,6 +46,25 @@ pub fn native_closest_representable( res << shift } +/// With +/// +/// B = 2^bit_count +/// val < B +/// random € [0, 1] +/// +/// returns 1 if the following if condition is true otherwise 0 +/// +/// (val > B / 2) || ((val == B / 2) && (random == 1)) +#[inline(always)] +fn balanced_rounding_condition_bit_trick( + val: Scalar, + bit_count: usize, + random: Scalar, +) -> Scalar { + let shifted_random = random << (bit_count - 1); + ((val.wrapping_sub(Scalar::ONE) | shifted_random) & val) >> (bit_count - 1) +} + impl SignedDecomposer where Scalar: UnsignedInteger, @@ -125,6 +144,36 @@ where native_closest_representable(input, self.level_count, self.base_log) } + #[inline(always)] + pub fn init_decomposer_state(&self, input: Scalar) -> Scalar { + // The closest number representable by the decomposition can be computed by performing + // the rounding at the appropriate bit. + + // We compute the number of least significant bits which can not be represented by the + // decomposition + // Example with level_count = 3, base_log = 4 and BITS == 64 -> 52 + let rep_bit_count = self.level_count * self.base_log; + let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count; + // Move the representable bits + 1 to the LSB, with our example : + // |-----| 64 - (64 - 12 - 1) == 13 bits + // 0....0XX...XX + let mut res = input >> (non_rep_bit_count - 1); + // Fetch the first bit value as we need it for a balanced rounding + let rounding_bit = res & Scalar::ONE; + // Add one to do the rounding by adding the half interval + res += Scalar::ONE; + // Discard the LSB which was the one deciding in which direction we round + res >>= 1; + // Keep the low base_log * level bits + let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count); + res &= mod_mask; + // Control bit about whether we should balance the state + // This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == 1) + let need_balance = balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit); + // Balance depending on the control bit + res.wrapping_sub(need_balance << rep_bit_count) + } + /// Generate an iterator over the terms of the decomposition of the input. /// /// # Warning @@ -161,7 +210,7 @@ where // Note that there would be no sense of making the decomposition on an input which was // not rounded to the closest representable first. We then perform it before decomposing. SignedDecompositionIter::new( - self.closest_representable(input), + self.init_decomposer_state(input), DecompositionBaseLog(self.base_log), DecompositionLevelCount(self.level_count), ) @@ -517,3 +566,41 @@ where } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_balanced_rounding_condition_as_bit_trick() { + for rep_bit_count in 1..13 { + println!("{rep_bit_count}"); + let b = 1u64 << rep_bit_count; + let b_over_2 = b / 2; + + for val in 0..b { + for random in [0, 1] { + let test_val = (val > b_over_2) || ((val == b_over_2) && (random == 1)); + let bit_trick = + balanced_rounding_condition_bit_trick(val, rep_bit_count, random); + let bit_trick_as_bool = if bit_trick == 1 { + true + } else if bit_trick == 0 { + false + } else { + panic!("Bit trick result was not a bit."); + }; + + assert_eq!( + test_val, bit_trick_as_bool, + "val ={val}\n\ + val_b ={val:064b}\n\ + random ={random}\n\ + expected: {test_val}\n\ + got : {bit_trick_as_bool}" + ); + } + } + } + } +} diff --git a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs index 3d91888976..865d51e794 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs @@ -43,7 +43,7 @@ where Self { base_log: base_log.0, level_count: level.0, - state: input >> (T::BITS - base_log.0 * level.0), + state: input, current_level: level.0, mod_b_mask: (T::ONE << base_log.0) - T::ONE, fresh: true, @@ -118,6 +118,23 @@ where } } +/// With +/// +/// B = 2^base_log +/// res < B +/// +/// returns 1 if the following condition is true otherwise 0 +/// +/// (res > B / 2) || ((res == B / 2) && ((state % B) >= B / 2)); +#[inline(always)] +fn decomposition_bit_trick( + res: Scalar, + state: Scalar, + base_log: usize, +) -> Scalar { + ((res.wrapping_sub(Scalar::ONE) | state) & res) >> (base_log - 1) +} + #[inline] pub(crate) fn decompose_one_level( base_log: usize, @@ -126,8 +143,7 @@ pub(crate) fn decompose_one_level( ) -> S { let res = *state & mod_b_mask; *state >>= base_log; - let mut carry = (res.wrapping_sub(S::ONE) | *state) & res; - carry >>= base_log - 1; + let carry = decomposition_bit_trick(res, *state, base_log); *state += carry; res.wrapping_sub(carry << base_log) } @@ -298,7 +314,7 @@ pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> { impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> { #[inline] - pub fn new( + pub(crate) fn new( decomposer: &SignedDecomposerNonNative, input: &[u64], modulus: u64, @@ -400,3 +416,45 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> { (glwe_level, glwe_decomp_term, substack2) } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_decomp_bit_trick() { + for rep_bit_count in 1..13 { + println!("{rep_bit_count}"); + let b = 1u64 << rep_bit_count; + let b_over_2 = b / 2; + + for val in 0..b { + // Have a chance to sample all values in 0..b at least once, here we expect on + // average about 10 occurrence for each value in the range + for _ in 0..10 * b { + let state: u64 = rand::random(); + let test_val = + (val > b_over_2) || ((val == b_over_2) && ((state % b) >= b_over_2)); + let bit_trick = decomposition_bit_trick(val, state, rep_bit_count); + let bit_trick_as_bool = if bit_trick == 1 { + true + } else if bit_trick == 0 { + false + } else { + panic!("Bit trick result was not a bit."); + }; + + assert_eq!( + test_val, bit_trick_as_bool, + "\nval ={val}\n\ + val_b ={val:064b}\n\ + state ={state}\n\ + state_b={state:064b}\n\ + expected: {test_val}\n\ + got : {bit_trick_as_bool}" + ); + } + } + } + } +} diff --git a/tfhe/src/core_crypto/commons/math/decomposition/tests.rs b/tfhe/src/core_crypto/commons/math/decomposition/tests.rs index 1f09dcc558..da5f1c48cf 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/tests.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/tests.rs @@ -200,3 +200,33 @@ fn test_decompose_recompose_non_native_solinas_u64() { fn test_decompose_recompose_non_native_edge_mod_round_up_u64() { test_decompose_recompose_non_native::(CiphertextModulus::try_new((1 << 48) + 1).unwrap()); } + +#[test] +fn test_single_level_decompose_balanced() { + let decomposer = SignedDecomposer::new(DecompositionBaseLog(12), DecompositionLevelCount(1)); + + assert_eq!( + decomposer.level_count().0, + 1, + "This test is only valid if the decomposition level count is 1" + ); + use rand::prelude::*; + let mut rng = rand::thread_rng(); + let mut mean = 0f64; + // Still runs fast, about 1 billion runs which is exactly representable in float + let runs = 1usize << 30; + for _ in 0..runs { + let val: u64 = rng.gen(); + let decomp = decomposer.decompose(val).next().unwrap(); + let value: i64 = decomp.value() as i64; + mean += value as f64; + } + mean /= runs as f64; + + // To print with --nocapture to check in the terminal + println!("mean={mean}"); + + // This bound is not very tight or good, but as an unbalanced decomposition has a mean of about + // 0.5 this will do + assert!(mean.abs() < 0.2); +} diff --git a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs index 10d9ea52a9..05229ae7a1 100644 --- a/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs +++ b/tfhe/src/core_crypto/experimental/algorithms/glwe_fast_keyswitch.rs @@ -207,7 +207,7 @@ pub fn glwe_fast_keyswitch( let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() - .map(|s| decomposer.closest_representable(*s)), + .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), substack0.rb_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs index 1e8dc852ea..c221bdedee 100644 --- a/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs @@ -418,7 +418,7 @@ pub fn add_external_product_assign( let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() - .map(|s| decomposer.closest_representable(*s)), + .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), substack0.rb_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs index a23bbd6661..6b8ec2961f 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs @@ -517,7 +517,7 @@ pub fn add_external_product_assign( let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( glwe.as_ref() .iter() - .map(|s| decomposer.closest_representable(*s)), + .map(|s| decomposer.init_decomposer_state(*s)), DecompositionBaseLog(decomposer.base_log), DecompositionLevelCount(decomposer.level_count), substack0.rb_mut(), diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs index efc5212792..5eabd8f344 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs @@ -31,9 +31,7 @@ impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'bu level: DecompositionLevelCount, stack: PodStack<'buffers>, ) -> (Self, PodStack<'buffers>) { - let shift = Scalar::BITS - base_log.0 * level.0; - let (states, stack) = - stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input.map(|i| i >> shift)); + let (states, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input); ( TensorSignedDecompositionLendingIter { base_log: base_log.0,