Skip to content

Commit

Permalink
fix(core): fix decomposition over 1 level to be balanced
Browse files Browse the repository at this point in the history
  • Loading branch information
IceTDrinker committed Oct 22, 2024
1 parent 7c29594 commit 2168a44
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 9 deletions.
33 changes: 32 additions & 1 deletion tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,37 @@ where
native_closest_representable(input, self.level_count, self.base_log)
}

#[inline(always)]
pub fn init_decomposer_state(&self, input: Scalar) -> Scalar {
// The closest number representable by the decomposition can be computed by performing
// the rounding at the appropriate bit.

// We compute the number of least significant bits which can not be represented by the
// decomposition
// Example with level_count = 3, base_log = 4 and BITS == 64 -> 52
let rep_bit_count = self.level_count * self.base_log;
let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count;
// Move the representable bits + 1 to the LSB, with our example :
// |-----| 64 - (64 - 12 - 1) == 13 bits
// 0....0XX...XX
let mut res = input >> (non_rep_bit_count - 1);
// Fetch the first bit value as we need it for a balanced rounding
let rounding_bit = (res & Scalar::ONE) << (rep_bit_count - 1);
// Add one to do the rounding by adding the half interval
res += Scalar::ONE;
// Discard the LSB which was the one deciding in which direction we round
res >>= 1;
// Keep the low base_log * level bits
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
res &= mod_mask;
// Control bit about whether we should balance the state
// This is equivalent to res >= 2^(base_log * l) || (res == 2^(base_log * l) && random == 1)
let need_balance =
(res.wrapping_sub(Scalar::ONE) | (res & rounding_bit)) >> (rep_bit_count - 1);
// Balance depending on the control bit
res.wrapping_sub(need_balance << rep_bit_count)
}

/// Generate an iterator over the terms of the decomposition of the input.
///
/// # Warning
Expand Down Expand Up @@ -161,7 +192,7 @@ where
// Note that there would be no sense of making the decomposition on an input which was
// not rounded to the closest representable first. We then perform it before decomposing.
SignedDecompositionIter::new(
self.closest_representable(input),
self.init_decomposer_state(input),
DecompositionBaseLog(self.base_log),
DecompositionLevelCount(self.level_count),
)
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/core_crypto/commons/math/decomposition/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ where
Self {
base_log: base_log.0,
level_count: level.0,
state: input >> (T::BITS - base_log.0 * level.0),
state: input,
current_level: level.0,
mod_b_mask: (T::ONE << base_log.0) - T::ONE,
fresh: true,
Expand Down Expand Up @@ -298,7 +298,7 @@ pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> {

impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> {
#[inline]
pub fn new(
pub(crate) fn new(
decomposer: &SignedDecomposerNonNative<u64>,
input: &[u64],
modulus: u64,
Expand Down
30 changes: 30 additions & 0 deletions tfhe/src/core_crypto/commons/math/decomposition/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,33 @@ fn test_decompose_recompose_non_native_solinas_u64() {
fn test_decompose_recompose_non_native_edge_mod_round_up_u64() {
test_decompose_recompose_non_native::<u64>(CiphertextModulus::try_new((1 << 48) + 1).unwrap());
}

#[test]
fn test_single_level_decompose_balanced() {
let decomposer = SignedDecomposer::new(DecompositionBaseLog(12), DecompositionLevelCount(1));

assert_eq!(
decomposer.level_count().0,
1,
"This test is only valid if the decomposition level count is 1"
);
use rand::prelude::*;
let mut rng = rand::thread_rng();
let mut mean = 0f64;
// Still runs fast, about 1 billion runs which is exactly representable in float
let runs = 1usize << 30;
for _ in 0..runs {
let val: u64 = rng.gen();
let decomp = decomposer.decompose(val).next().unwrap();
let value: i64 = decomp.value() as i64;
mean += value as f64;
}
mean /= runs as f64;

// To print with --nocapture to check in the terminal
println!("mean={mean}");

// This bound is not very tight or good, but as an unbalanced decomposition has a mean of about
// 0.5 this will do
assert!(mean.abs() < 0.1);
}
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ pub fn glwe_fast_keyswitch<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.closest_representable(*s)),
.map(|s| decomposer.init_decomposer_state(*s)),
DecompositionBaseLog(decomposer.base_log),
DecompositionLevelCount(decomposer.level_count),
substack0.rb_mut(),
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.closest_representable(*s)),
.map(|s| decomposer.init_decomposer_state(*s)),
DecompositionBaseLog(decomposer.base_log),
DecompositionLevelCount(decomposer.level_count),
substack0.rb_mut(),
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ pub fn add_external_product_assign<Scalar>(
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.closest_representable(*s)),
.map(|s| decomposer.init_decomposer_state(*s)),
DecompositionBaseLog(decomposer.base_log),
DecompositionLevelCount(decomposer.level_count),
substack0.rb_mut(),
Expand Down
4 changes: 1 addition & 3 deletions tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'bu
level: DecompositionLevelCount,
stack: PodStack<'buffers>,
) -> (Self, PodStack<'buffers>) {
let shift = Scalar::BITS - base_log.0 * level.0;
let (states, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input.map(|i| i >> shift));
let (states, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input);
(
TensorSignedDecompositionLendingIter {
base_log: base_log.0,
Expand Down

0 comments on commit 2168a44

Please sign in to comment.