diff --git a/tfhe/src/shortint/key_switching_key/mod.rs b/tfhe/src/shortint/key_switching_key/mod.rs index d66494cb23..256472d490 100644 --- a/tfhe/src/shortint/key_switching_key/mod.rs +++ b/tfhe/src/shortint/key_switching_key/mod.rs @@ -500,36 +500,17 @@ impl<'keys> KeySwitchingKeyView<'keys> { .dest_server_key .unchecked_create_trivial_with_lwe_size(0, output_lwe_size); + // TODO: We are outside the standard AP, if we chain keyswitches, we will refresh, which is + // safer for now. We can likely add an additional flag in shortint to indicate if we + // want to refresh or not, for now refresh anyways. + keyswitched.set_noise_level(NoiseLevel::UNKNOWN); + let cast_rshift = self.key_switching_key_material.cast_rshift; - match cast_rshift.cmp(&0) { - // Same bit size: only key switch - Ordering::Equal => { - keyswitch_lwe_ciphertext( - self.key_switching_key_material.key_switching_key, - &input_ct.ct, - &mut keyswitched.ct, - ); - keyswitched.degree = input_ct.degree; - // We don't really know where we stand in terms of noise here - keyswitched.set_noise_level(NoiseLevel::UNKNOWN); - } - // Cast to bigger bit length: keyswitch, then right shift - Ordering::Greater => { - keyswitch_lwe_ciphertext( - self.key_switching_key_material.key_switching_key, - &input_ct.ct, - &mut keyswitched.ct, - ); + // First pre process + let tmp_preprocessed: Ciphertext; - let acc = self - .dest_server_key - .generate_lookup_table(|n| n >> cast_rshift); - self.dest_server_key - .apply_lookup_table_assign(&mut keyswitched, &acc); - // degree updated by the apply lookup table - keyswitched.set_noise_level(NoiseLevel::NOMINAL); - } + let pre_processed = match cast_rshift.cmp(&0) { // Cast to smaller bit length: left shift, then keyswitch Ordering::Less => { let src_server_key = self.src_server_key.as_ref().expect( @@ -541,38 +522,32 @@ impl<'keys> KeySwitchingKeyView<'keys> { (n << -cast_rshift) % (input_ct.carry_modulus.0 * input_ct.message_modulus.0) as u64 }); - let shifted_cipher = src_server_key.apply_lookup_table(input_ct, &acc); - - keyswitch_lwe_ciphertext( - self.key_switching_key_material.key_switching_key, - &shifted_cipher.ct, - &mut keyswitched.ct, - ); - // The degree is high in the source plaintext modulus, but smaller in the arriving - // one. - // - // src 4 bits: - // 0 | XX | 11 - // shifted: - // 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12 - // dst 2 bits : - // 0 | 11 -> 11 = 3 - keyswitched.degree = Degree::new(shifted_cipher.degree.get() >> -cast_rshift); - // We don't really know where we stand in terms of noise here - keyswitched.set_noise_level(NoiseLevel::UNKNOWN); + tmp_preprocessed = src_server_key.apply_lookup_table(input_ct, &acc); + &tmp_preprocessed } - } + // No pre-processing + Ordering::Equal | Ordering::Greater => input_ct, + }; + + // The keyswitch + keyswitch_lwe_ciphertext( + self.key_switching_key_material.key_switching_key, + &pre_processed.ct, + &mut keyswitched.ct, + ); + keyswitched.degree = pre_processed.degree; - let ret = { + // Manage the destination key adjustment + let mut res = { let destination_pbs_order: PBSOrder = self.key_switching_key_material.destination_key.into(); if destination_pbs_order == self.dest_server_key.pbs_order { keyswitched } else { let wrong_key_ct = keyswitched; - let mut output = self.dest_server_key.create_trivial(0); - output.degree = wrong_key_ct.degree; - output.set_noise_level(wrong_key_ct.noise_level()); + let mut correct_key_ct = self.dest_server_key.create_trivial(0); + correct_key_ct.degree = wrong_key_ct.degree; + correct_key_ct.set_noise_level(wrong_key_ct.noise_level()); // We are arriving under the wrong key for the dest_server_key match self.key_switching_key_material.destination_key { @@ -581,9 +556,8 @@ impl<'keys> KeySwitchingKeyView<'keys> { keyswitch_lwe_ciphertext( &self.dest_server_key.key_switching_key, &wrong_key_ct.ct, - &mut output.ct, + &mut correct_key_ct.ct, ); - // TODO refresh ? } // Small to Big == PBS EncryptionKeyChoice::Small => { @@ -593,20 +567,70 @@ impl<'keys> KeySwitchingKeyView<'keys> { apply_programmable_bootstrap( &self.dest_server_key.bootstrapping_key, &wrong_key_ct.ct, - &mut output.ct, + &mut correct_key_ct.ct, &acc.acc, buffers, ); }); - output.set_noise_level(NoiseLevel::NOMINAL); + // Degree does not need to be updated as we apply an Identity LUT and we + // apply only the bootstrap directly on the underlying ciphertext, we have + // to update the noise however. + correct_key_ct.set_noise_level(NoiseLevel::NOMINAL); } } - output + correct_key_ct } }; - ret + let degree_after_keyswitch = res.degree; + match cast_rshift.cmp(&0) { + // Same bit size: only key switch + Ordering::Equal => { + // Refresh if we haven't applied a PBS yet + if res.noise_level() == NoiseLevel::UNKNOWN { + let acc = self.dest_server_key.generate_lookup_table(|x| x); + self.dest_server_key + .apply_lookup_table_assign(&mut res, &acc); + // We apply an Identity LUT so we know a tighter bound than the worst case LUT + // value + res.degree = degree_after_keyswitch; + } + } + // Cast to bigger bit length: keyswitch, then right shift + Ordering::Greater => { + let acc = self + .dest_server_key + .generate_lookup_table(|n| n >> cast_rshift); + self.dest_server_key + .apply_lookup_table_assign(&mut res, &acc); + // degree and noise are updated by the apply lookup table + } + // Cast to smaller bit length: left shift, then keyswitch + Ordering::Less => { + // The degree is high in the source plaintext modulus, but smaller in the arriving + // one. + // + // src 4 bits: + // 0 | XX | 11 + // shifted: + // 0 | 11 | 00 -> Applied lut will have max degree 1100 = 12 + // dst 2 bits : + // 0 | 11 -> 11 = 3 + let new_degree = Degree::new(degree_after_keyswitch.get() >> -cast_rshift); + // Refresh if we haven't applied a PBS yet + if res.noise_level() == NoiseLevel::UNKNOWN { + let acc = self.dest_server_key.generate_lookup_table(|x| x); + self.dest_server_key + .apply_lookup_table_assign(&mut res, &acc); + } + // Apply the degree correction, even if we bootstrapped as the Identity LUT would + // not change this correction + res.degree = new_degree; + } + } + + res } } diff --git a/tfhe/src/shortint/key_switching_key/test.rs b/tfhe/src/shortint/key_switching_key/test.rs index 8a32f86868..3da0028690 100644 --- a/tfhe/src/shortint/key_switching_key/test.rs +++ b/tfhe/src/shortint/key_switching_key/test.rs @@ -181,6 +181,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() { // Message 0 Carry 0 let cipher = ck1.unchecked_encrypt(0); let output_of_cast = ksk.cast(&cipher); + assert_eq!(output_of_cast.degree.get(), 3); let clear = ck2.decrypt(&output_of_cast); assert_eq!(clear, 0); let ct_carry = sk2.carry_extract(&output_of_cast); @@ -190,6 +191,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() { // Message 1 Carry 0 let cipher = ck1.unchecked_encrypt(1); let output_of_cast = ksk.cast(&cipher); + assert_eq!(output_of_cast.degree.get(), 3); let clear = ck2.decrypt(&output_of_cast); assert_eq!(clear, 1); let ct_carry = sk2.carry_extract(&output_of_cast); @@ -199,6 +201,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() { // Message 0 Carry 1 let cipher = ck1.unchecked_encrypt(2); let output_of_cast = ksk.cast(&cipher); + assert_eq!(output_of_cast.degree.get(), 3); let clear = ck2.decrypt(&output_of_cast); assert_eq!(clear, 0); let ct_carry = sk2.carry_extract(&output_of_cast); @@ -208,6 +211,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() { // Message 1 Carry 1 let cipher = ck1.unchecked_encrypt(3); let output_of_cast = ksk.cast(&cipher); + assert_eq!(output_of_cast.degree.get(), 3); let clear = ck2.decrypt(&output_of_cast); assert_eq!(clear, 1); let ct_carry = sk2.carry_extract(&output_of_cast); @@ -222,6 +226,7 @@ fn gen_multi_keys_test_truncate_ci_run_filter() { assert_eq!((clear, carry), (0, 3)); let output_of_cast = ksk.cast(&cipher); + assert_eq!(output_of_cast.degree.get(), 3); let clear = ck2.decrypt(&output_of_cast); assert_eq!(clear, 0); let ct_carry = sk2.carry_extract(&output_of_cast);