Skip to content

Commit

Permalink
refactor(shortint): refactor the shortint keyswitching code
Browse files Browse the repository at this point in the history
- this manages better all the cases we encouter, we force a refresh PBS in
all cases for now which is less optimal in certain cases but allows to be
safe in cases where keyswitches might be chained
  • Loading branch information
IceTDrinker committed Jul 25, 2024
1 parent 2004333 commit 2588947
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 57 deletions.
138 changes: 81 additions & 57 deletions tfhe/src/shortint/key_switching_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,36 +500,17 @@ impl<'keys> KeySwitchingKeyView<'keys> {
.dest_server_key
.unchecked_create_trivial_with_lwe_size(0, output_lwe_size);

// TODO: We are outside the standard AP, if we chain keyswitches, we will refresh, which is
// safer for now. We can likely add an additional flag in shortint to indicate if we
// want to refresh or not, for now refresh anyways.
keyswitched.set_noise_level(NoiseLevel::UNKNOWN);

let cast_rshift = self.key_switching_key_material.cast_rshift;

match cast_rshift.cmp(&0) {
// Same bit size: only key switch
Ordering::Equal => {
keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&input_ct.ct,
&mut keyswitched.ct,
);
keyswitched.degree = input_ct.degree;
// We don't really know where we stand in terms of noise here
keyswitched.set_noise_level(NoiseLevel::UNKNOWN);
}
// Cast to bigger bit length: keyswitch, then right shift
Ordering::Greater => {
keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&input_ct.ct,
&mut keyswitched.ct,
);
// First pre process
let tmp_preprocessed: Ciphertext;

let acc = self
.dest_server_key
.generate_lookup_table(|n| n >> cast_rshift);
self.dest_server_key
.apply_lookup_table_assign(&mut keyswitched, &acc);
// degree updated by the apply lookup table
keyswitched.set_noise_level(NoiseLevel::NOMINAL);
}
let pre_processed = match cast_rshift.cmp(&0) {
// Cast to smaller bit length: left shift, then keyswitch
Ordering::Less => {
let src_server_key = self.src_server_key.as_ref().expect(
Expand All @@ -541,38 +522,32 @@ impl<'keys> KeySwitchingKeyView<'keys> {
(n << -cast_rshift)
% (input_ct.carry_modulus.0 * input_ct.message_modulus.0) as u64
});
let shifted_cipher = src_server_key.apply_lookup_table(input_ct, &acc);

keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&shifted_cipher.ct,
&mut keyswitched.ct,
);
// The degree is high in the source plaintext modulus, but smaller in the arriving
// one.
//
// src 4 bits:
// 0 | XX | 11
// shifted:
// 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12
// dst 2 bits :
// 0 | 11 -> 11 = 3
keyswitched.degree = Degree::new(shifted_cipher.degree.get() >> -cast_rshift);
// We don't really know where we stand in terms of noise here
keyswitched.set_noise_level(NoiseLevel::UNKNOWN);
tmp_preprocessed = src_server_key.apply_lookup_table(input_ct, &acc);
&tmp_preprocessed
}
}
// No pre-processing
Ordering::Equal | Ordering::Greater => input_ct,
};

// The keyswitch
keyswitch_lwe_ciphertext(
self.key_switching_key_material.key_switching_key,
&pre_processed.ct,
&mut keyswitched.ct,
);
keyswitched.degree = pre_processed.degree;

let ret = {
// Manage the destination key adjustment
let mut res = {
let destination_pbs_order: PBSOrder =
self.key_switching_key_material.destination_key.into();
if destination_pbs_order == self.dest_server_key.pbs_order {
keyswitched
} else {
let wrong_key_ct = keyswitched;
let mut output = self.dest_server_key.create_trivial(0);
output.degree = wrong_key_ct.degree;
output.set_noise_level(wrong_key_ct.noise_level());
let mut correct_key_ct = self.dest_server_key.create_trivial(0);
correct_key_ct.degree = wrong_key_ct.degree;
correct_key_ct.set_noise_level(wrong_key_ct.noise_level());

// We are arriving under the wrong key for the dest_server_key
match self.key_switching_key_material.destination_key {
Expand All @@ -581,9 +556,8 @@ impl<'keys> KeySwitchingKeyView<'keys> {
keyswitch_lwe_ciphertext(
&self.dest_server_key.key_switching_key,
&wrong_key_ct.ct,
&mut output.ct,
&mut correct_key_ct.ct,
);
// TODO refresh ?
}
// Small to Big == PBS
EncryptionKeyChoice::Small => {
Expand All @@ -593,20 +567,70 @@ impl<'keys> KeySwitchingKeyView<'keys> {
apply_programmable_bootstrap(
&self.dest_server_key.bootstrapping_key,
&wrong_key_ct.ct,
&mut output.ct,
&mut correct_key_ct.ct,
&acc.acc,
buffers,
);
});
output.set_noise_level(NoiseLevel::NOMINAL);
// Degree does not need to be updated as we apply an Identity LUT and we
// apply only the bootstrap directly on the underlying ciphertext, we have
// to update the noise however.
correct_key_ct.set_noise_level(NoiseLevel::NOMINAL);
}
}

output
correct_key_ct
}
};

ret
let degree_after_keyswitch = res.degree;
match cast_rshift.cmp(&0) {
// Same bit size: only key switch
Ordering::Equal => {
// Refresh if we haven't applied a PBS yet
if res.noise_level() == NoiseLevel::UNKNOWN {
let acc = self.dest_server_key.generate_lookup_table(|x| x);
self.dest_server_key
.apply_lookup_table_assign(&mut res, &acc);
// We apply an Identity LUT so we know a tighter bound than the worst case LUT
// value
res.degree = degree_after_keyswitch;
}
}
// Cast to bigger bit length: keyswitch, then right shift
Ordering::Greater => {
let acc = self
.dest_server_key
.generate_lookup_table(|n| n >> cast_rshift);
self.dest_server_key
.apply_lookup_table_assign(&mut res, &acc);
// degree and noise are updated by the apply lookup table
}
// Cast to smaller bit length: left shift, then keyswitch
Ordering::Less => {
// The degree is high in the source plaintext modulus, but smaller in the arriving
// one.
//
// src 4 bits:
// 0 | XX | 11
// shifted:
// 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12
// dst 2 bits :
// 0 | 11 -> 11 = 3
let new_degree = Degree::new(degree_after_keyswitch.get() >> -cast_rshift);
// Refresh if we haven't applied a PBS yet
if res.noise_level() == NoiseLevel::UNKNOWN {
let acc = self.dest_server_key.generate_lookup_table(|x| x);
self.dest_server_key
.apply_lookup_table_assign(&mut res, &acc);
}
// Apply the degree correction, even if we bootstrapped as the Identity LUT would
// not change this correction
res.degree = new_degree;
}
}

res
}
}

Expand Down
5 changes: 5 additions & 0 deletions tfhe/src/shortint/key_switching_key/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 0 Carry 0
let cipher = ck1.unchecked_encrypt(0);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 0);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -190,6 +191,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 1 Carry 0
let cipher = ck1.unchecked_encrypt(1);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 1);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -199,6 +201,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 0 Carry 1
let cipher = ck1.unchecked_encrypt(2);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 0);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -208,6 +211,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
// Message 1 Carry 1
let cipher = ck1.unchecked_encrypt(3);
let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 1);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand All @@ -222,6 +226,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() {
assert_eq!((clear, carry), (0, 3));

let output_of_cast = ksk.cast(&cipher);
assert_eq!(output_of_cast.degree.get(), 3);
let clear = ck2.decrypt(&output_of_cast);
assert_eq!(clear, 0);
let ct_carry = sk2.carry_extract(&output_of_cast);
Expand Down

0 comments on commit 2588947

Please sign in to comment.