Skip to content

Commit

Permalink
fix(core): fix decomposition over 1 level to be balanced
Browse files Browse the repository at this point in the history
  • Loading branch information
IceTDrinker committed Oct 24, 2024
1 parent 7c29594 commit d38d46f
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 11 deletions.
89 changes: 88 additions & 1 deletion tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ pub fn native_closest_representable<Scalar: UnsignedInteger>(
res << shift
}

/// With
///
/// B = 2^bit_count
/// val < B
/// random € [0, 1]
///
/// returns 1 if the following if condition is true otherwise 0
///
/// (val > B / 2) || ((val == B / 2) && (random == 1))
#[inline(always)]
fn balanced_rounding_condition_bit_trick<Scalar: UnsignedInteger>(
val: Scalar,
bit_count: usize,
random: Scalar,
) -> Scalar {
let shifted_random = random << (bit_count - 1);
((val.wrapping_sub(Scalar::ONE) | shifted_random) & val) >> (bit_count - 1)
}

impl<Scalar> SignedDecomposer<Scalar>
where
Scalar: UnsignedInteger,
Expand Down Expand Up @@ -125,6 +144,36 @@ where
native_closest_representable(input, self.level_count, self.base_log)
}

#[inline(always)]
pub fn init_decomposer_state(&self, input: Scalar) -> Scalar {
// The closest number representable by the decomposition can be computed by performing
// the rounding at the appropriate bit.

// We compute the number of least significant bits which can not be represented by the
// decomposition
// Example with level_count = 3, base_log = 4 and BITS == 64 -> 52
let rep_bit_count = self.level_count * self.base_log;
let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count;
// Move the representable bits + 1 to the LSB, with our example :
// |-----| 64 - (64 - 12 - 1) == 13 bits
// 0....0XX...XX
let mut res = input >> (non_rep_bit_count - 1);
// Fetch the first bit value as we need it for a balanced rounding
let rounding_bit = res & Scalar::ONE;
// Add one to do the rounding by adding the half interval
res += Scalar::ONE;
// Discard the LSB which was the one deciding in which direction we round
res >>= 1;
// Keep the low base_log * level bits
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
res &= mod_mask;
// Control bit about whether we should balance the state
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == 1)
let need_balance = balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit);
// Balance depending on the control bit
res.wrapping_sub(need_balance << rep_bit_count)
}

/// Generate an iterator over the terms of the decomposition of the input.
///
/// # Warning
Expand Down Expand Up @@ -161,7 +210,7 @@ where
// Note that there would be no sense of making the decomposition on an input which was
// not rounded to the closest representable first. We then perform it before decomposing.
SignedDecompositionIter::new(
self.closest_representable(input),
self.init_decomposer_state(input),
DecompositionBaseLog(self.base_log),
DecompositionLevelCount(self.level_count),
)
Expand Down Expand Up @@ -517,3 +566,41 @@ where
}
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_balanced_rounding_condition_as_bit_trick() {
for rep_bit_count in 1..13 {
println!("{rep_bit_count}");
let b = 1u64 << rep_bit_count;
let b_over_2 = b / 2;

for val in 0..b {
for random in [0, 1] {
let test_val = (val > b_over_2) || ((val == b_over_2) && (random == 1));
let bit_trick =
balanced_rounding_condition_bit_trick(val, rep_bit_count, random);
let bit_trick_as_bool = if bit_trick == 1 {
true
} else if bit_trick == 0 {
false
} else {
panic!("Bit trick result was not a bit.");
};

assert_eq!(
test_val, bit_trick_as_bool,
"val ={val}\n\
val_b ={val:064b}\n\
random ={random}\n\
expected: {test_val}\n\
got : {bit_trick_as_bool}"
);
}
}
}
}
}
66 changes: 62 additions & 4 deletions tfhe/src/core_crypto/commons/math/decomposition/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ where
Self {
base_log: base_log.0,
level_count: level.0,
state: input >> (T::BITS - base_log.0 * level.0),
state: input,
current_level: level.0,
mod_b_mask: (T::ONE << base_log.0) - T::ONE,
fresh: true,
Expand Down Expand Up @@ -118,6 +118,23 @@ where
}
}

/// With
///
/// B = 2^base_log
/// res < B
///
/// returns 1 if the following condition is true otherwise 0
///
/// (res > B / 2) || ((res == B / 2) && ((state % B) >= B / 2));
#[inline(always)]
fn decomposition_bit_trick<Scalar: UnsignedInteger>(
res: Scalar,
state: Scalar,
base_log: usize,
) -> Scalar {
((res.wrapping_sub(Scalar::ONE) | state) & res) >> (base_log - 1)
}

#[inline]
pub(crate) fn decompose_one_level<S: UnsignedInteger>(
base_log: usize,
Expand All @@ -126,8 +143,7 @@ pub(crate) fn decompose_one_level<S: UnsignedInteger>(
) -> S {
let res = *state & mod_b_mask;
*state >>= base_log;
let mut carry = (res.wrapping_sub(S::ONE) | *state) & res;
carry >>= base_log - 1;
let carry = decomposition_bit_trick(res, *state, base_log);
*state += carry;
res.wrapping_sub(carry << base_log)
}
Expand Down Expand Up @@ -298,7 +314,7 @@ pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> {

impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> {
#[inline]
pub fn new(
pub(crate) fn new(
decomposer: &SignedDecomposerNonNative<u64>,
input: &[u64],
modulus: u64,
Expand Down Expand Up @@ -400,3 +416,45 @@ impl<'buffers> TensorSignedDecompositionLendingIterNonNative<'buffers> {
(glwe_level, glwe_decomp_term, substack2)
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_decomp_bit_trick() {
for rep_bit_count in 1..13 {
println!("{rep_bit_count}");
let b = 1u64 << rep_bit_count;
let b_over_2 = b / 2;

for val in 0..b {
// Have a chance to sample all values in 0..b at least once, here we expect on
// average about 10 occurrence for each value in the range
for _ in 0..10 * b {
let state: u64 = rand::random();
let test_val =
(val > b_over_2) || ((val == b_over_2) && ((state % b) >= b_over_2));
let bit_trick = decomposition_bit_trick(val, state, rep_bit_count);
let bit_trick_as_bool = if bit_trick == 1 {
true
} else if bit_trick == 0 {
false
} else {
panic!("Bit trick result was not a bit.");
};

assert_eq!(
test_val, bit_trick_as_bool,
"\nval ={val}\n\
val_b ={val:064b}\n\
state ={state}\n\
state_b={state:064b}\n\
expected: {test_val}\n\
got : {bit_trick_as_bool}"
);
}
}
}
}
}
30 changes: 30 additions & 0 deletions tfhe/src/core_crypto/commons/math/decomposition/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,33 @@ fn test_decompose_recompose_non_native_solinas_u64() {
fn test_decompose_recompose_non_native_edge_mod_round_up_u64() {
test_decompose_recompose_non_native::<u64>(CiphertextModulus::try_new((1 << 48) + 1).unwrap());
}

#[test]
fn test_single_level_decompose_balanced() {
let decomposer = SignedDecomposer::new(DecompositionBaseLog(12), DecompositionLevelCount(1));

assert_eq!(
decomposer.level_count().0,
1,
"This test is only valid if the decomposition level count is 1"
);
use rand::prelude::*;
let mut rng = rand::thread_rng();
let mut mean = 0f64;
// Still runs fast, about 1 billion runs which is exactly representable in float
let runs = 1usize << 30;
for _ in 0..runs {
let val: u64 = rng.gen();
let decomp = decomposer.decompose(val).next().unwrap();
let value: i64 = decomp.value() as i64;
mean += value as f64;
}
mean /= runs as f64;

// To print with --nocapture to check in the terminal
println!("mean={mean}");

// This bound is not very tight or good, but as an unbalanced decomposition has a mean of about
// 0.5 this will do
assert!(mean.abs() < 0.2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ pub fn glwe_fast_keyswitch<Scalar, OutputGlweCont, InputGlweCont, GgswCont>(
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.closest_representable(*s)),
.map(|s| decomposer.init_decomposer_state(*s)),
DecompositionBaseLog(decomposer.base_log),
DecompositionLevelCount(decomposer.level_count),
substack0.rb_mut(),
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/fft_impl/fft128/crypto/ggsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ pub fn add_external_product_assign<Scalar, ContOut, ContGgsw, ContGlwe>(
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.closest_representable(*s)),
.map(|s| decomposer.init_decomposer_state(*s)),
DecompositionBaseLog(decomposer.base_log),
DecompositionLevelCount(decomposer.level_count),
substack0.rb_mut(),
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ pub fn add_external_product_assign<Scalar>(
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.closest_representable(*s)),
.map(|s| decomposer.init_decomposer_state(*s)),
DecompositionBaseLog(decomposer.base_log),
DecompositionLevelCount(decomposer.level_count),
substack0.rb_mut(),
Expand Down
4 changes: 1 addition & 3 deletions tfhe/src/core_crypto/fft_impl/fft64/math/decomposition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'bu
level: DecompositionLevelCount,
stack: PodStack<'buffers>,
) -> (Self, PodStack<'buffers>) {
let shift = Scalar::BITS - base_log.0 * level.0;
let (states, stack) =
stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input.map(|i| i >> shift));
let (states, stack) = stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input);
(
TensorSignedDecompositionLendingIter {
base_log: base_log.0,
Expand Down

0 comments on commit d38d46f

Please sign in to comment.