diff --git a/tfhe/src/core_crypto/algorithms/glwe_keyswitch.rs b/tfhe/src/core_crypto/algorithms/glwe_keyswitch.rs new file mode 100644 index 0000000000..19508d4e2b --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/glwe_keyswitch.rs @@ -0,0 +1,304 @@ +//! Module containing primitives pertaining to [`GLWE ciphertext +//! keyswitch`](`GlweKeyswitchKey#glwe-keyswitch`). + +use crate::core_crypto::algorithms::polynomial_algorithms::*; +use crate::core_crypto::commons::math::decomposition::{ + SignedDecomposer, SignedDecomposerNonNative, +}; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// Keyswitch a [`GLWE ciphertext`](`GlweCiphertext`) encrypted under a +/// [`GLWE secret key`](`GlweSecretKey`) to another [`GLWE secret key`](`GlweSecretKey`). +/// +/// # Formal Definition +/// +/// See [`GLWE keyswitch key`](`GlweKeyswitchKey#glwe-keyswitch`). +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for GlweKeyswitchKey creation +/// let input_glwe_dimension = GlweDimension(2); +/// let poly_size = PolynomialSize(512); +/// let glwe_noise_distribution = Gaussian::from_dispersion_parameter( +/// StandardDev(0.00000000000000000000007069849454709433), +/// 0.0, +/// ); +/// let output_glwe_dimension = GlweDimension(1); +/// let decomp_base_log = DecompositionBaseLog(21); +/// let decomp_level_count = DecompositionLevelCount(2); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// let delta = 1 << 59; +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// input_glwe_dimension, +/// poly_size, +/// &mut secret_generator, +/// ); +/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// output_glwe_dimension, +/// poly_size, +/// &mut secret_generator, +/// ); +/// +/// let ksk = allocate_and_generate_new_glwe_keyswitch_key( +/// &input_glwe_secret_key, +/// &output_glwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// glwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Create the plaintext +/// let msg = 3u64; +/// let plaintext_list = PlaintextList::new(msg * delta, PlaintextCount(poly_size.0)); +/// +/// // Create a new GlweCiphertext +/// let mut input_glwe = GlweCiphertext::new( +/// 0u64, +/// input_glwe_dimension.to_glwe_size(), +/// poly_size, +/// ciphertext_modulus, +/// ); +/// +/// encrypt_glwe_ciphertext( +/// &input_glwe_secret_key, +/// &mut input_glwe, +/// &plaintext_list, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// let mut output_glwe = GlweCiphertext::new( +/// 0u64, +/// output_glwe_secret_key.glwe_dimension().to_glwe_size(), +/// output_glwe_secret_key.polynomial_size(), +/// ciphertext_modulus, +/// ); +/// +/// keyswitch_glwe_ciphertext(&ksk, &input_glwe, &mut output_glwe); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 5 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1)); +/// +/// let mut output_plaintext_list = PlaintextList::new(0u64, plaintext_list.plaintext_count()); +/// +/// decrypt_glwe_ciphertext( +/// &output_glwe_secret_key, +/// &output_glwe, +/// &mut output_plaintext_list, +/// ); +/// +/// // Get the raw vector +/// let mut cleartext_list = output_plaintext_list.into_container(); +/// // Remove the encoding +/// cleartext_list +/// .iter_mut() +/// .for_each(|elt| *elt = decomposer.decode_plaintext(*elt)); +/// // Get the list immutably +/// let cleartext_list = cleartext_list; +/// +/// // Check we recovered the original message for each plaintext we encrypted +/// cleartext_list.iter().for_each(|&elt| assert_eq!(elt, msg)); +/// ``` +pub fn keyswitch_glwe_ciphertext( + glwe_keyswitch_key: &GlweKeyswitchKey, + input_glwe_ciphertext: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + if glwe_keyswitch_key + .ciphertext_modulus() + .is_compatible_with_native_modulus() + { + keyswitch_glwe_ciphertext_native_mod_compatible( + glwe_keyswitch_key, + input_glwe_ciphertext, + output_glwe_ciphertext, + ) + } else { + keyswitch_glwe_ciphertext_other_mod( + glwe_keyswitch_key, + input_glwe_ciphertext, + output_glwe_ciphertext, + ) + } +} + +pub fn keyswitch_glwe_ciphertext_native_mod_compatible( + glwe_keyswitch_key: &GlweKeyswitchKey, + input_glwe_ciphertext: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() + == input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched input GlweDimension. \ + GlweKeyswitchKey input GlweDimension: {:?}, input GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() + == output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched output GlweDimension. \ + GlweKeyswitchKey output GlweDimension: {:?}, output GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.output_key_glwe_dimension(), + output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_ciphertext.polynomial_size(), + "Mismatched input PolynomialSize. \ + GlweKeyswithcKey input PolynomialSize: {:?}, input GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + input_glwe_ciphertext.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_ciphertext.polynomial_size(), + "Mismatched output PolynomialSize. \ + GlweKeyswitchKey output PolynomialSize: {:?}, output GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + output_glwe_ciphertext.polynomial_size(), + ); + assert!(glwe_keyswitch_key + .ciphertext_modulus() + .is_compatible_with_native_modulus()); + + // Clear the output ciphertext, as it will get updated gradually + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + // Copy the input body to the output ciphertext + polynomial_wrapping_add_assign( + &mut output_glwe_ciphertext.get_mut_body().as_mut_polynomial(), + &input_glwe_ciphertext.get_body().as_polynomial(), + ); + + // We instantiate a decomposer + let decomposer = SignedDecomposer::new( + glwe_keyswitch_key.decomposition_base_log(), + glwe_keyswitch_key.decomposition_level_count(), + ); + + for (keyswitch_key_block, input_mask_element) in glwe_keyswitch_key + .iter() + .zip(input_glwe_ciphertext.get_mask().as_polynomial_list().iter()) + { + let mut decomposition_iter = decomposer.decompose_slice(input_mask_element.as_ref()); + // loop over the number of levels + for level_key_ciphertext in keyswitch_key_block.iter() { + let decomposed = decomposition_iter.next_term().unwrap(); + polynomial_list_wrapping_sub_scalar_mul_assign( + &mut output_glwe_ciphertext.as_mut_polynomial_list(), + &level_key_ciphertext.as_polynomial_list(), + &Polynomial::from_container(decomposed.as_slice()), + ); + } + } +} + +pub fn keyswitch_glwe_ciphertext_other_mod( + glwe_keyswitch_key: &GlweKeyswitchKey, + input_glwe_ciphertext: &GlweCiphertext, + output_glwe_ciphertext: &mut GlweCiphertext, +) where + Scalar: UnsignedInteger, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() + == input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched input GlweDimension. \ + GlweKeyswitchKey input GlweDimension: {:?}, input GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() + == output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + "Mismatched output GlweDimension. \ + GlweKeyswitchKey output GlweDimension: {:?}, output GlweCiphertext GlweDimension {:?}.", + glwe_keyswitch_key.output_key_glwe_dimension(), + output_glwe_ciphertext.glwe_size().to_glwe_dimension(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_ciphertext.polynomial_size(), + "Mismatched input PolynomialSize. \ + GlweKeyswithcKey input PolynomialSize: {:?}, input GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + input_glwe_ciphertext.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_ciphertext.polynomial_size(), + "Mismatched output PolynomialSize. \ + GlweKeyswitchKey output PolynomialSize: {:?}, output GlweCiphertext PolynomialSize {:?}.", + glwe_keyswitch_key.polynomial_size(), + output_glwe_ciphertext.polynomial_size(), + ); + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + assert!(!ciphertext_modulus.is_compatible_with_native_modulus()); + + // Clear the output ciphertext, as it will get updated gradually + output_glwe_ciphertext.as_mut().fill(Scalar::ZERO); + + // Copy the input body to the output ciphertext (no need to use non native addition here) + polynomial_wrapping_add_assign( + &mut output_glwe_ciphertext.get_mut_body().as_mut_polynomial(), + &input_glwe_ciphertext.get_body().as_polynomial(), + ); + + // We instantiate a decomposer + let decomposer = SignedDecomposerNonNative::new( + glwe_keyswitch_key.decomposition_base_log(), + glwe_keyswitch_key.decomposition_level_count(), + ciphertext_modulus, + ); + + let mut scalar_poly = Polynomial::new(Scalar::ZERO, input_glwe_ciphertext.polynomial_size()); + + for (keyswitch_key_block, input_mask_element) in glwe_keyswitch_key + .iter() + .zip(input_glwe_ciphertext.get_mask().as_polynomial_list().iter()) + { + let mut decomposition_iter = decomposer.decompose_slice(input_mask_element.as_ref()); + // loop over the number of levels + for level_key_ciphertext in keyswitch_key_block.iter() { + let decomposed = decomposition_iter.next_term().unwrap(); + decomposed.modular_value(scalar_poly.as_mut()); + polynomial_list_wrapping_sub_scalar_mul_assign_custom_mod( + &mut output_glwe_ciphertext.as_mut_polynomial_list(), + &level_key_ciphertext.as_polynomial_list(), + &scalar_poly, + ciphertext_modulus.get_custom_modulus().cast_into(), + ); + } + } +} diff --git a/tfhe/src/core_crypto/algorithms/glwe_keyswitch_key_generation.rs b/tfhe/src/core_crypto/algorithms/glwe_keyswitch_key_generation.rs new file mode 100644 index 0000000000..96dbeb455b --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/glwe_keyswitch_key_generation.rs @@ -0,0 +1,352 @@ +//! Module containing primitives pertaining to [`GLWE keyswitch key generation`](`GlweKeyswitchKey`) + +use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_div_assign; +use crate::core_crypto::algorithms::*; +use crate::core_crypto::commons::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::math::decomposition::{ + DecompositionLevel, DecompositionTermSlice, DecompositionTermSliceNonNative, +}; +use crate::core_crypto::commons::math::random::{Distribution, Uniform}; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// Fill a [`GLWE keyswitch key`](`GlweKeyswitchKey`) with an actual keyswitching key constructed +/// from an input and an output key [`GLWE secret key`](`GlweSecretKey`). +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for GlweKeyswitchKey creation +/// let input_glwe_dimension = GlweDimension(2); +/// let polynomial_size = PolynomialSize(1024); +/// let glwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let output_glwe_dimension = GlweDimension(1); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(5); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the GlweSecretKey +/// let input_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// input_glwe_dimension, +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// output_glwe_dimension, +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// +/// let mut ksk = GlweKeyswitchKey::new( +/// 0u64, +/// decomp_base_log, +/// decomp_level_count, +/// input_glwe_dimension, +/// output_glwe_dimension, +/// polynomial_size, +/// ciphertext_modulus, +/// ); +/// +/// generate_glwe_keyswitch_key( +/// &input_glwe_secret_key, +/// &output_glwe_secret_key, +/// &mut ksk, +/// glwe_noise_distribution, +/// &mut encryption_generator, +/// ); +/// +/// assert!(!ksk.as_ref().iter().all(|&x| x == 0)); +/// ``` +pub fn generate_glwe_keyswitch_key< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + glwe_keyswitch_key: &mut GlweKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + + if ciphertext_modulus.is_compatible_with_native_modulus() { + generate_glwe_keyswitch_key_native_mod_compatible( + input_glwe_sk, + output_glwe_sk, + glwe_keyswitch_key, + noise_distribution, + generator, + ) + } else { + generate_glwe_keyswitch_key_other_mod( + input_glwe_sk, + output_glwe_sk, + glwe_keyswitch_key, + noise_distribution, + generator, + ) + } +} + +pub fn generate_glwe_keyswitch_key_native_mod_compatible< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + glwe_keyswitch_key: &mut GlweKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() == input_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey input GlweDimension is not equal \ + to the input GlweSecretKey GlweDimension. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() == output_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey output GlweDimension is not equal \ + to the output GlweSecretKey GlweDimension. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.output_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey input PolynomialSize is not equal \ + to the input GlweSecretKey PolynomialSize. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.polynomial_size(), + input_glwe_sk.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey output PolynomialSize is not equal \ + to the output GlweSecretKey PolynomialSize. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.polynomial_size(), + output_glwe_sk.polynomial_size(), + ); + + let decomp_base_log = glwe_keyswitch_key.decomposition_base_log(); + let decomp_level_count = glwe_keyswitch_key.decomposition_level_count(); + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + assert!(ciphertext_modulus.is_compatible_with_native_modulus()); + + // Iterate over the input key elements and the destination glwe_keyswitch_key memory + for (input_key_polynomial, mut keyswitch_key_block) in input_glwe_sk + .as_polynomial_list() + .iter() + .zip(glwe_keyswitch_key.iter_mut()) + { + // The plaintexts used to encrypt a key element will be stored in this buffer + let mut decomposition_polynomials_buffer = PolynomialList::new( + Scalar::ZERO, + input_glwe_sk.polynomial_size(), + PolynomialCount(decomp_level_count.0), + ); + + // We fill the buffer with the powers of the key elmements + for (level, mut message_polynomial) in (1..=decomp_level_count.0) + .rev() + .map(DecompositionLevel) + .zip(decomposition_polynomials_buffer.as_mut_view().iter_mut()) + { + let term = + DecompositionTermSlice::new(level, decomp_base_log, input_key_polynomial.as_ref()); + term.fill_slice_with_recomposition_summand(message_polynomial.as_mut()); + slice_wrapping_scalar_div_assign( + message_polynomial.as_mut(), + ciphertext_modulus.get_power_of_two_scaling_to_native_torus(), + ); + } + + let decomposition_plaintexts_buffer = + PlaintextList::from_container(decomposition_polynomials_buffer.into_container()); + + encrypt_glwe_ciphertext_list( + output_glwe_sk, + &mut keyswitch_key_block, + &decomposition_plaintexts_buffer, + noise_distribution, + generator, + ); + } +} + +pub fn generate_glwe_keyswitch_key_other_mod< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + KSKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + glwe_keyswitch_key: &mut GlweKeyswitchKey, + noise_distribution: NoiseDistribution, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + KSKeyCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert!( + glwe_keyswitch_key.input_key_glwe_dimension() == input_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey input GlweDimension is not equal \ + to the input GlweSecretKey GlweDimension. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.input_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.output_key_glwe_dimension() == output_glwe_sk.glwe_dimension(), + "The destination GlweKeyswitchKey output GlweDimension is not equal \ + to the output GlweSecretKey GlweDimension. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.output_key_glwe_dimension(), + input_glwe_sk.glwe_dimension() + ); + assert!( + glwe_keyswitch_key.polynomial_size() == input_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey input PolynomialSize is not equal \ + to the input GlweSecretKey PolynomialSize. Destination: {:?}, input: {:?}", + glwe_keyswitch_key.polynomial_size(), + input_glwe_sk.polynomial_size(), + ); + assert!( + glwe_keyswitch_key.polynomial_size() == output_glwe_sk.polynomial_size(), + "The destination GlweKeyswitchKey output PolynomialSize is not equal \ + to the output GlweSecretKey PolynomialSize. Destination: {:?}, output: {:?}", + glwe_keyswitch_key.polynomial_size(), + output_glwe_sk.polynomial_size(), + ); + + let decomp_base_log = glwe_keyswitch_key.decomposition_base_log(); + let decomp_level_count = glwe_keyswitch_key.decomposition_level_count(); + let ciphertext_modulus = glwe_keyswitch_key.ciphertext_modulus(); + assert!(!ciphertext_modulus.is_compatible_with_native_modulus()); + + // Iterate over the input key elements and the destination glwe_keyswitch_key memory + for (input_key_polynomial, mut keyswitch_key_block) in input_glwe_sk + .as_polynomial_list() + .iter() + .zip(glwe_keyswitch_key.iter_mut()) + { + // The plaintexts used to encrypt a key element will be stored in this buffer + let mut decomposition_polynomials_buffer = PolynomialList::new( + Scalar::ZERO, + input_glwe_sk.polynomial_size(), + PolynomialCount(decomp_level_count.0), + ); + + // We fill the buffer with the powers of the key elmements + for (level, mut message_polynomial) in (1..=decomp_level_count.0) + .rev() + .map(DecompositionLevel) + .zip(decomposition_polynomials_buffer.as_mut_view().iter_mut()) + { + let term = DecompositionTermSliceNonNative::new( + level, + decomp_base_log, + input_key_polynomial.as_ref(), + ciphertext_modulus, + ); + term.to_approximate_recomposition_summand(message_polynomial.as_mut()); + } + + let decomposition_plaintexts_buffer = + PlaintextList::from_container(decomposition_polynomials_buffer.into_container()); + + encrypt_glwe_ciphertext_list( + output_glwe_sk, + &mut keyswitch_key_block, + &decomposition_plaintexts_buffer, + noise_distribution, + generator, + ); + } +} + +/// Allocate a new [`GLWE keyswitch key`](`GlweKeyswitchKey`) and fill it with an actual +/// keyswitching key constructed from an input and an output +/// [`GLWE secret key`](`GlweSecretKey`). +/// +/// See [`keyswitch_glwe_ciphertext`] for usage. +pub fn allocate_and_generate_new_glwe_keyswitch_key< + Scalar, + NoiseDistribution, + InputKeyCont, + OutputKeyCont, + Gen, +>( + input_glwe_sk: &GlweSecretKey, + output_glwe_sk: &GlweSecretKey, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + noise_distribution: NoiseDistribution, + ciphertext_modulus: CiphertextModulus, + generator: &mut EncryptionRandomGenerator, +) -> GlweKeyswitchKeyOwned +where + Scalar: Encryptable, + NoiseDistribution: Distribution, + InputKeyCont: Container, + OutputKeyCont: Container, + Gen: ByteRandomGenerator, +{ + let mut new_glwe_keyswitch_key = GlweKeyswitchKeyOwned::new( + Scalar::ZERO, + decomp_base_log, + decomp_level_count, + input_glwe_sk.glwe_dimension(), + output_glwe_sk.glwe_dimension(), + output_glwe_sk.polynomial_size(), + ciphertext_modulus, + ); + + generate_glwe_keyswitch_key( + input_glwe_sk, + output_glwe_sk, + &mut new_glwe_keyswitch_key, + noise_distribution, + generator, + ); + + new_glwe_keyswitch_key +} diff --git a/tfhe/src/core_crypto/algorithms/mod.rs b/tfhe/src/core_crypto/algorithms/mod.rs index a94edac4ce..80ed9fe82a 100644 --- a/tfhe/src/core_crypto/algorithms/mod.rs +++ b/tfhe/src/core_crypto/algorithms/mod.rs @@ -5,6 +5,8 @@ pub mod ggsw_conversion; pub mod ggsw_encryption; pub mod glwe_encryption; +pub mod glwe_keyswitch; +pub mod glwe_keyswitch_key_generation; pub mod glwe_linear_algebra; pub mod glwe_sample_extraction; pub mod glwe_secret_key_generation; @@ -53,6 +55,8 @@ pub(crate) mod test; pub use ggsw_conversion::*; pub use ggsw_encryption::*; pub use glwe_encryption::*; +pub use glwe_keyswitch::*; +pub use glwe_keyswitch_key_generation::*; pub use glwe_linear_algebra::*; pub use glwe_sample_extraction::*; pub use glwe_secret_key_generation::*; diff --git a/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs b/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs index 43e4c39304..5fbb7ba2db 100644 --- a/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs +++ b/tfhe/src/core_crypto/algorithms/polynomial_algorithms.rs @@ -334,6 +334,30 @@ pub fn polynomial_wrapping_add_mul_assign_custom_mod( + output: &mut Polynomial, + scalar: Scalar, + custom_modulus: Scalar, +) where + Scalar: UnsignedInteger, + PolyCont: ContainerMut, +{ + slice_wrapping_scalar_mul_assign_custom_mod(output.as_mut(), scalar, custom_modulus) +} + /// Divides (mod $(X^{N}+1)$), the output polynomial with a monic monomial of a given degree i.e. /// $X^{degree}$. /// @@ -919,6 +943,63 @@ pub fn polynomial_wrapping_sub_mul_assign_custom_mod( + output_poly_list: &mut PolynomialList, + input_poly_list: &PolynomialList, + scalar_poly: &Polynomial, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + InputCont: Container, + PolyCont: Container, +{ + assert_eq!( + output_poly_list.polynomial_size(), + input_poly_list.polynomial_size() + ); + assert_eq!( + output_poly_list.polynomial_count(), + input_poly_list.polynomial_count() + ); + for (mut output_poly, input_poly) in output_poly_list.iter_mut().zip(input_poly_list.iter()) { + polynomial_wrapping_sub_mul_assign(&mut output_poly, &input_poly, scalar_poly) + } +} + +pub fn polynomial_list_wrapping_sub_scalar_mul_assign_custom_mod< + Scalar, + InputCont, + OutputCont, + PolyCont, +>( + output_poly_list: &mut PolynomialList, + input_poly_list: &PolynomialList, + scalar_poly: &Polynomial, + custom_modulus: Scalar, +) where + Scalar: UnsignedInteger, + OutputCont: ContainerMut, + InputCont: Container, + PolyCont: Container, +{ + assert_eq!( + output_poly_list.polynomial_size(), + input_poly_list.polynomial_size() + ); + assert_eq!( + output_poly_list.polynomial_count(), + input_poly_list.polynomial_count() + ); + for (mut output_poly, input_poly) in output_poly_list.iter_mut().zip(input_poly_list.iter()) { + polynomial_wrapping_sub_mul_assign_custom_mod( + &mut output_poly, + &input_poly, + scalar_poly, + custom_modulus, + ) + } +} + /// Fill the output polynomial, with the result of the product of two polynomials, reduced modulo /// $(X^{N} + 1)$ with the schoolbook algorithm Complexity: $O(N^{2})$ /// diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs index 05baf35518..cb082c4036 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -1,6 +1,7 @@ use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus; use crate::core_crypto::commons::math::decomposition::{ - SignedDecompositionIter, SignedDecompositionNonNativeIter, ValueSign, + SignedDecompositionIter, SignedDecompositionNonNativeIter, SliceSignedDecompositionIter, + SliceSignedDecompositionNonNativeIter, ValueSign, }; use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger}; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; @@ -174,6 +175,56 @@ where res.wrapping_sub(need_balance << rep_bit_count) } + pub fn init_decomposer_state_slice(&self, input: &[Scalar], output: &mut [Scalar]) { + assert_eq!(input.len(), output.len()); + let rep_bit_count = self.level_count * self.base_log; + let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count; + let mod_mask = Scalar::MAX >> non_rep_bit_count; + input + .iter() + .zip(output.iter_mut()) + .for_each(|(input, output)| { + *output = *input >> (non_rep_bit_count - 1); + let rounding_bit = *output & Scalar::ONE; + *output += Scalar::ONE; + *output >>= 1; + *output &= mod_mask; + let need_balance = + balanced_rounding_condition_bit_trick(*output, rep_bit_count, rounding_bit); + *output = output.wrapping_sub(need_balance << rep_bit_count) + }); + } + + /// Decode a plaintext value using the decoder to compute the closest representable. + pub fn decode_plaintext(&self, input: Scalar) -> Scalar { + let shift = Scalar::BITS - self.level_count * self.base_log; + self.closest_representable(input) >> shift + } + + /// Fills a mutable tensor-like objects with the closest representable values from another + /// tensor-like object. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// + /// let input = vec![1_340_987_234_u32; 2]; + /// let mut closest = vec![0u32; 2]; + /// decomposer.fill_slice_with_closest_representable(&mut closest, &input); + /// assert!(closest.iter().all(|&x| x == 1_341_128_704_u32)); + /// ``` + pub fn fill_slice_with_closest_representable(&self, output: &mut [Scalar], input: &[Scalar]) { + assert_eq!(output.len(), input.len()); + output + .iter_mut() + .zip(input.iter()) + .for_each(|(dst, &src)| *dst = self.closest_representable(src)); + } + /// Generate an iterator over the terms of the decomposition of the input. /// /// # Warning @@ -242,6 +293,89 @@ where None } } + + /// Generates an iterator-like object over tensors of terms of the decomposition of the input + /// tensor. + /// + /// # Warning + /// + /// The returned iterator yields the terms $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ in + /// order of decreasing $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::numeric::UnsignedInteger; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32, 1_340_987_234_u32]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// + /// let mut count = 0; + /// while let Some(term) = decomp.next_term() { + /// assert!(1 <= term.level().0); + /// assert!(term.level().0 <= 3); + /// for elmt in term.as_slice().iter() { + /// let signed_term = elmt.into_signed(); + /// let half_basis = 2i32.pow(4) / 2i32; + /// assert!(-half_basis <= signed_term); + /// assert!(signed_term < half_basis); + /// } + /// count += 1; + /// } + /// assert_eq!(count, 3); + /// ``` + pub fn decompose_slice(&self, input: &[Scalar]) -> SliceSignedDecompositionIter { + // Note that there would be no sense of making the decomposition on an input which was + // not rounded to the closest representable first. We then perform it before decomposing. + let mut closest = vec![Scalar::ZERO; input.len()]; + self.init_decomposer_state_slice(input, &mut closest); + SliceSignedDecompositionIter::new( + &closest, + DecompositionBaseLog(self.base_log), + DecompositionLevelCount(self.level_count), + ) + } + + /// Fills the output tensor with the recomposition of another tensor. + /// + /// Returns `Some(())` if the decomposition was fresh, and the output was filled with a + /// recomposition, and `None`, if not. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut rounded = vec![0u32; 2]; + /// decomposer.fill_slice_with_closest_representable(&mut rounded, &decomposable); + /// let decomp = decomposer.decompose_slice(&rounded); + /// let mut recomposition = vec![0u32; 2]; + /// decomposer + /// .fill_slice_with_recompose(decomp, &mut recomposition) + /// .unwrap(); + /// assert_eq!(recomposition, rounded); + /// ``` + pub fn fill_slice_with_recompose( + &self, + decomp: SliceSignedDecompositionIter, + output: &mut [Scalar], + ) -> Option<()> { + let mut decomp = decomp; + if decomp.is_fresh() { + while let Some(term) = decomp.next_term() { + term.update_slice_with_recomposition_summand_wrapping_addition(output); + } + Some(()) + } else { + None + } + } } /// A structure which allows to decompose unsigned integers into a set of smaller terms for moduli @@ -437,6 +571,26 @@ where } } + /// Decode a plaintext value using the decoder modulo a custom modulus. + pub fn decode_plaintext(&self, input: Scalar) -> Scalar { + let ciphertext_modulus_as_scalar: Scalar = + self.ciphertext_modulus.get_custom_modulus().cast_into(); + let mut negate_input = false; + let mut ptxt = input; + if input > ciphertext_modulus_as_scalar >> 1 { + negate_input = true; + ptxt = ptxt.wrapping_neg_custom_mod(ciphertext_modulus_as_scalar); + } + let number_of_message_bits = self.base_log().0 * self.level_count().0; + let delta = ciphertext_modulus_as_scalar >> number_of_message_bits; + let half_delta = delta >> 1; + let mut decoded = (ptxt + half_delta) / delta; + if negate_input { + decoded = decoded.wrapping_neg_custom_mod(ciphertext_modulus_as_scalar); + } + decoded + } + #[inline(always)] pub fn init_decomposer_state(&self, input: Scalar) -> (Scalar, ValueSign) { let ciphertext_modulus_as_scalar: Scalar = @@ -468,6 +622,36 @@ where (abs_closest_representable, input_sign) } + pub fn init_decomposer_state_slice( + &self, + input: &[Scalar], + output: &mut [Scalar], + signs: &mut [ValueSign], + ) { + assert_eq!(input.len(), output.len()); + assert_eq!(input.len(), signs.len()); + let ciphertext_modulus_as_scalar: Scalar = + self.ciphertext_modulus.get_custom_modulus().cast_into(); + let shift_to_native = Scalar::BITS - self.ciphertext_modulus_bit_count() as usize; + + input + .iter() + .zip(output.iter_mut()) + .zip(signs.iter_mut()) + .for_each(|((input, output), sign)| { + if *input < ciphertext_modulus_as_scalar.div_ceil(Scalar::TWO) { + (*output, *sign) = (*input, ValueSign::Positive) + } else { + (*output, *sign) = (ciphertext_modulus_as_scalar - *input, ValueSign::Negative) + }; + *output = native_closest_representable( + *output << shift_to_native, + self.level_count, + self.base_log, + ) >> shift_to_native + }); + } + /// Generate an iterator over the terms of the decomposition of the input. /// /// # Warning @@ -565,6 +749,153 @@ where None } } + + /// Fills a mutable tensor-like objects with the closest representable values from another + /// tensor-like object. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{SignedDecomposerNonNative, ValueSign}; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 48) + 1).unwrap(), + /// ); + /// + /// let input = vec![249280154129830u64; 2]; + /// let mut closest = vec![0u64; 2]; + /// let mut signs = vec![ValueSign::Positive; 2]; + /// decomposer.init_decomposer_state_slice(&input, &mut closest, &mut signs); + /// assert!(closest.iter().all(|&x| x == 32160715112448u64)); + /// decomposer.fill_slice_with_closest_representable(&mut closest, &input); + /// assert!(closest.iter().all(|&x| x == 249314261598209u64)); + /// ``` + pub fn fill_slice_with_closest_representable(&self, output: &mut [Scalar], input: &[Scalar]) { + assert_eq!(output.len(), input.len()); + let mut signs = vec![ValueSign::Positive; input.len()]; + self.init_decomposer_state_slice(input, output, &mut signs); + + let modulus_as_scalar: Scalar = self.ciphertext_modulus.get_custom_modulus().cast_into(); + output + .iter_mut() + .zip(signs.iter()) + .for_each(|(output, sign)| match sign { + ValueSign::Positive => (), + ValueSign::Negative => *output = output.wrapping_neg_custom_mod(modulus_as_scalar), + }); + } + + /// Generates an iterator-like object over tensors of terms of the decomposition of the input + /// tensor. + /// + /// # Warning + /// + /// The returned iterator yields the terms $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ in + /// order of decreasing $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::commons::numeric::UnsignedInteger; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// + /// let decomposition_base_log = DecompositionBaseLog(4); + /// let decomposition_level_count = DecompositionLevelCount(3); + /// let ciphertext_modulus = CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(); + /// + /// let decomposer = SignedDecomposerNonNative::new( + /// decomposition_base_log, + /// decomposition_level_count, + /// ciphertext_modulus, + /// ); + /// + /// let basis = 2i64.pow(decomposition_base_log.0.try_into().unwrap()); + /// let half_basis = basis / 2; + /// + /// let decomposable = [9223372032559808513u64, 1u64 << 63]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// + /// let mut count = 0; + /// while let Some(term) = decomp.next_term() { + /// assert!(1 <= term.level().0); + /// assert!(term.level().0 <= 3); + /// for elmt in term.as_slice().iter() { + /// let signed_term = elmt.into_signed(); + /// assert!(-half_basis <= signed_term); + /// assert!(signed_term <= half_basis); + /// } + /// count += 1; + /// } + /// assert_eq!(count, 3); + /// ``` + pub fn decompose_slice( + &self, + input: &[Scalar], + ) -> SliceSignedDecompositionNonNativeIter { + let mut abs_closest_representables = vec![Scalar::ZERO; input.len()]; + let mut signs = vec![ValueSign::Positive; input.len()]; + self.init_decomposer_state_slice(input, &mut abs_closest_representables, &mut signs); + + SliceSignedDecompositionNonNativeIter::new( + &abs_closest_representables, + &signs, + DecompositionBaseLog(self.base_log), + DecompositionLevelCount(self.level_count), + self.ciphertext_modulus, + ) + } + + /// Fills the output tensor with the recomposition of another tensor. + /// + /// Returns `Some(())` if the decomposition was fresh, and the output was filled with a + /// recomposition, and `None`, if not. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// + /// let ciphertext_modulus = CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(); + /// let decomposer = SignedDecomposerNonNative::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// ciphertext_modulus, + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut rounded = vec![0u32; 2]; + /// decomposer.fill_slice_with_closest_representable(&mut rounded, &decomposable); + /// let decomp = decomposer.decompose_slice(&rounded); + /// let mut recomposition = vec![0u32; 2]; + /// decomposer + /// .fill_slice_with_recompose(decomp, &mut recomposition) + /// .unwrap(); + /// assert_eq!(recomposition, rounded); + /// ``` + pub fn fill_slice_with_recompose( + &self, + decomp: SliceSignedDecompositionNonNativeIter, + output: &mut [Scalar], + ) -> Option<()> { + let mut decomp = decomp; + if decomp.is_fresh() { + while let Some(term) = decomp.next_term() { + term.update_slice_with_recomposition_summand_wrapping_addition(output); + } + Some(()) + } else { + None + } + } } #[cfg(test)] diff --git a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs index 865d51e794..7d44e08f73 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs @@ -1,6 +1,7 @@ use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus; use crate::core_crypto::commons::math::decomposition::{ - DecompositionLevel, DecompositionTerm, DecompositionTermNonNative, SignedDecomposerNonNative, + DecompositionLevel, DecompositionTerm, DecompositionTermNonNative, DecompositionTermSlice, + DecompositionTermSliceNonNative, SignedDecomposerNonNative, }; use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; @@ -148,6 +149,149 @@ pub(crate) fn decompose_one_level( res.wrapping_sub(carry << base_log) } +/// An iterator-like object that yields the terms of the signed decomposition of a tensor of values. +/// +/// # Note +/// +/// On each call to [`SliceSignedDecompositionIter::next_term`], this structure yields a new +/// [`DecompositionTermSlice`], backed by a `Vec` owned by the structure. This vec is mutated at +/// each call of the `next_term` method, and as such the term must be dropped before `next_term` is +/// called again. +/// +/// Such a pattern can not be implemented with iterators yet (without GATs), which is why this +/// iterator must be explicitly called. +/// +/// # Warning +/// +/// This iterator yields the decomposition in reverse order. That means that the highest level +/// will be yielded first. +pub struct SliceSignedDecompositionIter +where + T: UnsignedInteger, +{ + // The base log of the decomposition + base_log: usize, + // The number of levels of the decomposition + level_count: usize, + // The current level + current_level: usize, + // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form: + // ...0001111 + mod_b_mask: T, + // The internal states of each decomposition + states: Vec, + // In order to avoid allocating a new Vec every time we yield a decomposition term, we store + // a Vec inside the structure and yield slices pointing to it. + outputs: Vec, + // A flag which stores whether the iterator is a fresh one (for the recompose method). + fresh: bool, +} + +impl SliceSignedDecompositionIter +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition iterator. + pub(crate) fn new( + input: &[T], + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + ) -> Self { + let len = input.len(); + Self { + base_log: base_log.0, + level_count: level.0, + current_level: level.0, + mod_b_mask: (T::ONE << base_log.0) - T::ONE, + outputs: vec![T::ZERO; len], + states: input.to_vec(), + fresh: true, + } + } + + pub(crate) fn is_fresh(&self) -> bool { + self.fresh + } + + /// Returns the logarithm in base two of the base of this decomposition. + /// + /// If the decomposition uses a base $B=2^b$, this returns $b$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + DecompositionBaseLog(self.base_log) + } + + /// Returns the number of levels of this decomposition. + /// + /// If the decomposition uses $l$ levels, this returns $l$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + DecompositionLevelCount(self.level_count) + } + + /// Yield the next term of the decomposition, if any. + /// + /// # Note + /// + /// Because this function returns a borrowed tensor, owned by the iterator, the term must be + /// dropped before `next_term` is called again. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// assert_eq!(term.as_slice()[0], 4294967295); + /// ``` + pub fn next_term(&mut self) -> Option> { + // The iterator is not fresh anymore. + self.fresh = false; + // We check if the decomposition is over + if self.current_level == 0 { + return None; + } + // We iterate over the elements of the outputs and decompose + for (output_i, state_i) in self.outputs.iter_mut().zip(self.states.iter_mut()) { + *output_i = decompose_one_level(self.base_log, state_i, self.mod_b_mask); + } + self.current_level -= 1; + // We return the term tensor. + Some(DecompositionTermSlice::new( + DecompositionLevel(self.current_level + 1), + DecompositionBaseLog(self.base_log), + &self.outputs, + )) + } +} + /// An iterator that yields the terms of the signed decomposition of an integer. /// /// # Warning @@ -293,6 +437,191 @@ where } } +/// An iterator-like object that yields the terms of the signed decomposition of a tensor of values. +/// +/// # Note +/// +/// On each call to [`SliceSignedDecompositionNonNativeIter::next_term`], this structure yields a +/// new +/// [`DecompositionTermSlice`], backed by a `Vec` owned by the structure. This vec is mutated at +/// each call of the `next_term` method, and as such the term must be dropped before `next_term` is +/// called again. +/// +/// Such a pattern can not be implemented with iterators yet (without GATs), which is why this +/// iterator must be explicitly called. +/// +/// # Warning +/// +/// This iterator yields the decomposition in reverse order. That means that the highest level +/// will be yielded first. +pub struct SliceSignedDecompositionNonNativeIter +where + T: UnsignedInteger, +{ + // The base log of the decomposition + base_log: usize, + // The number of levels of the decomposition + level_count: usize, + // The current level + current_level: usize, + // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form: + // ...0001111 + mod_b_mask: T, + // Ciphertext modulus + ciphertext_modulus: CiphertextModulus, + // The internal states of each decomposition + states: Vec, + // In order to avoid allocating a new Vec every time we yield a decomposition term, we store + // a Vec inside the structure and yield slices pointing to it. + outputs: Vec, + // A flag which stores whether the iterator is a fresh one (for the recompose method). + fresh: bool, + // The signs of the input values, for the algorithm we use, returned values require adaption + // depending on the sign of the input + signs: Vec, +} + +impl SliceSignedDecompositionNonNativeIter +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition iterator. + pub(crate) fn new( + input: &[T], + input_signs: &[ValueSign], + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self { + base_log: base_log.0, + level_count: level.0, + current_level: level.0, + mod_b_mask: (T::ONE << base_log.0) - T::ONE, + ciphertext_modulus, + outputs: vec![T::ZERO; input.len()], + states: input + .iter() + .map(|i| { + *i >> (ciphertext_modulus.get_custom_modulus().ceil_ilog2() as usize + - base_log.0 * level.0) + }) + .collect(), + fresh: true, + signs: input_signs.to_vec(), + } + } + + pub(crate) fn is_fresh(&self) -> bool { + self.fresh + } + + /// Returns the logarithm in base two of the base of this decomposition. + /// + /// If the decomposition uses a base $B=2^b$, this returns $b$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(), + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + DecompositionBaseLog(self.base_log) + } + + /// Returns the number of levels of this decomposition. + /// + /// If the decomposition uses $l$ levels, this returns $l$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(), + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let decomp = decomposer.decompose_slice(&decomposable); + /// assert_eq!(decomp.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + DecompositionLevelCount(self.level_count) + } + + /// Yield the next term of the decomposition, if any. + /// + /// # Note + /// + /// Because this function returns a borrowed tensor, owned by the iterator, the term must be + /// dropped before `next_term` is called again. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{ + /// DecompositionLevel, SignedDecomposerNonNative, + /// }; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - (1 << 16) + 1).unwrap(), + /// ); + /// let decomposable = vec![1_340_987_234_u32; 2]; + /// let mut decomp = decomposer.decompose_slice(&decomposable); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// assert_eq!(term.as_slice()[0], u32::MAX); + /// ``` + pub fn next_term(&mut self) -> Option> { + // The iterator is not fresh anymore. + self.fresh = false; + // We check if the decomposition is over + if self.current_level == 0 { + return None; + } + // We iterate over the elements of the outputs and decompose + for ((output_i, state_i), sign_i) in self + .outputs + .iter_mut() + .zip(self.states.iter_mut()) + .zip(self.signs.iter()) + { + *output_i = decompose_one_level(self.base_log, state_i, self.mod_b_mask); + *output_i = match sign_i { + ValueSign::Positive => *output_i, + ValueSign::Negative => output_i.wrapping_neg(), + }; + } + self.current_level -= 1; + // We return the term tensor. + Some(DecompositionTermSliceNonNative::new( + DecompositionLevel(self.current_level + 1), + DecompositionBaseLog(self.base_log), + &self.outputs, + self.ciphertext_modulus, + )) + } +} + /// Specialized high performance implementation of a non native decomposer over a collection of /// elements, used notably in the PBS. pub struct TensorSignedDecompositionLendingIterNonNative<'buffers> { diff --git a/tfhe/src/core_crypto/commons/math/decomposition/term.rs b/tfhe/src/core_crypto/commons/math/decomposition/term.rs index b5146859f0..e7e5443e5f 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/term.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/term.rs @@ -223,3 +223,284 @@ where DecompositionLevel(self.level) } } + +/// A tensor whose elements are the terms of the decomposition of another tensor. +/// +/// If we decompose each elements of a set of values $(\theta^{(a)})\_{a\in\mathbb{N}}$ as a set of +/// sums $(\sum\_{i=1}^l\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$, this represents a +/// set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DecompositionTermSlice<'a, T> +where + T: UnsignedInteger, +{ + level: usize, + base_log: usize, + slice: &'a [T], +} + +impl<'a, T> DecompositionTermSlice<'a, T> +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition term. + pub(crate) fn new( + level: DecompositionLevel, + base_log: DecompositionBaseLog, + slice: &'a [T], + ) -> Self { + Self { + level: level.0, + base_log: base_log.0, + slice, + } + } + + /// Fills the output tensor with the terms turned to summands. + /// + /// If our term tensor represents a set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ of the + /// decomposition, this method fills the output tensor with a set of + /// $(\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// let mut output = vec![0u32; 2]; + /// term.fill_slice_with_recomposition_summand(&mut output); + /// assert!(output.iter().all(|&x| x == 1048576)); + /// ``` + pub fn fill_slice_with_recomposition_summand(&self, output: &mut [T]) { + assert_eq!(self.slice.len(), output.len()); + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(dst, &value)| { + let shift: usize = ::BITS - self.base_log * self.level; + *dst = value << shift + }); + } + + pub(crate) fn update_slice_with_recomposition_summand_wrapping_addition( + &self, + output: &mut [T], + ) { + assert_eq!(self.slice.len(), output.len()); + let shift: usize = ::BITS - self.base_log * self.level; + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(out, &value)| { + *out = (*out).wrapping_add(value << shift); + }); + } + + /// Returns a tensor with the values of term. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.as_slice()[0], 1); + /// ``` + pub fn as_slice(&self) -> &'a [T] { + self.slice + } + + /// Returns the level of this decomposition term tensor. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// ``` + pub fn level(&self) -> DecompositionLevel { + DecompositionLevel(self.level) + } +} + +/// A tensor whose elements are the terms of the decomposition of another tensor. +/// +/// If we decompose each elements of a set of values $(\theta^{(a)})\_{a\in\mathbb{N}}$ as a set of +/// sums $(\sum\_{i=1}^l\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$, this represents a +/// set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DecompositionTermSliceNonNative<'a, T> +where + T: UnsignedInteger, +{ + level: usize, + base_log: usize, + slice: &'a [T], + ciphertext_modulus: CiphertextModulus, +} + +impl<'a, T> DecompositionTermSliceNonNative<'a, T> +where + T: UnsignedInteger, +{ + // Creates a new tensor decomposition term. + pub(crate) fn new( + level: DecompositionLevel, + base_log: DecompositionBaseLog, + slice: &'a [T], + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self { + level: level.0, + base_log: base_log.0, + slice, + ciphertext_modulus, + } + } + + /// Fills the output tensor with the terms turned to summands. + /// + /// If our term tensor represents a set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ of the + /// decomposition, this method fills the output tensor with a set of + /// $(\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - 1).unwrap(), + /// ); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// let mut output = vec![0; 2]; + /// term.to_approximate_recomposition_summand(&mut output); + /// assert!(output.iter().all(|&x| x == 1048576)); + /// ``` + pub fn to_approximate_recomposition_summand(&self, output: &mut [T]) { + assert_eq!(self.slice.len(), output.len()); + let modulus_as_t = T::cast_from(self.ciphertext_modulus.get_custom_modulus()); + let ciphertext_modulus_bit_count: usize = modulus_as_t.ceil_ilog2().try_into().unwrap(); + let shift: usize = ciphertext_modulus_bit_count - self.base_log * self.level; + + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(dst, &value)| { + if value.into_signed() >= T::Signed::ZERO { + *dst = value << shift + } else { + *dst = modulus_as_t.wrapping_add(value << shift) + } + }); + } + + /// Compute the value of the term modulo the modulus given when building the + /// [`DecompositionTermSliceNonNative`] + pub fn modular_value(&self, output: &mut [T]) { + assert_eq!(self.slice.len(), output.len()); + let modulus_as_t = T::cast_from(self.ciphertext_modulus.get_custom_modulus()); + self.slice + .iter() + .zip(output.iter_mut()) + .for_each(|(&value, output)| { + if value.into_signed() >= T::Signed::ZERO { + *output = value + } else { + *output = modulus_as_t.wrapping_add(value) + } + }); + } + + pub(crate) fn update_slice_with_recomposition_summand_wrapping_addition( + &self, + output: &mut [T], + ) { + assert_eq!(self.slice.len(), output.len()); + let modulus_as_t = T::cast_from(self.ciphertext_modulus.get_custom_modulus()); + let ciphertext_modulus_bit_count: usize = modulus_as_t.ceil_ilog2().try_into().unwrap(); + let shift: usize = ciphertext_modulus_bit_count - self.base_log * self.level; + output + .iter_mut() + .zip(self.slice.iter()) + .for_each(|(out, &value)| { + if value.into_signed() >= T::Signed::ZERO { + *out = (*out).wrapping_add_custom_mod(value << shift, modulus_as_t) + } else { + *out = (*out).wrapping_add_custom_mod( + modulus_as_t.wrapping_add(value << shift), + modulus_as_t, + ) + } + }); + } + + /// Returns a tensor with the values of term. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - 1).unwrap(), + /// ); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.as_slice()[0], 1); + /// ``` + pub fn as_slice(&self) -> &'a [T] { + self.slice + } + + /// Returns the level of this decomposition term tensor. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{ + /// DecompositionLevel, SignedDecomposerNonNative, + /// }; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, + /// }; + /// let decomposer = SignedDecomposerNonNative::::new( + /// DecompositionBaseLog(4), + /// DecompositionLevelCount(3), + /// CiphertextModulus::try_new((1 << 32) - 1).unwrap(), + /// ); + /// let input = vec![2u32.pow(19); 2]; + /// let mut decomp = decomposer.decompose_slice(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// ``` + pub fn level(&self) -> DecompositionLevel { + DecompositionLevel(self.level) + } +} diff --git a/tfhe/src/core_crypto/entities/glwe_keyswitch_key.rs b/tfhe/src/core_crypto/entities/glwe_keyswitch_key.rs new file mode 100644 index 0000000000..cc80d9e9a6 --- /dev/null +++ b/tfhe/src/core_crypto/entities/glwe_keyswitch_key.rs @@ -0,0 +1,505 @@ +//! Module containing the definition of the [`GlweKeyswitchKey`]. + +use crate::conformance::ParameterSetConformant; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; + +/// A [`GLWE keyswitch key`](`GlweKeyswitchKey`). +/// +/// # Formal Definition +/// +/// ## Key Switching Key +/// +/// A key switching key is a vector of GLev ciphertexts (described on the bottom of +/// [`this page`](`crate::core_crypto::entities::GgswCiphertext#Glev-ciphertext`)). +/// It encrypts the coefficient of +/// the [`GLWE secret key`](`crate::core_crypto::entities::GlweSecretKey`) +/// $\vec{S}\_{\mathsf{in}}$ under the +/// [`GLWE secret key`](`crate::core_crypto::entities::GlweSecretKey`) +/// $\vec{S}\_{\mathsf{out}}$. +/// +/// $$\mathsf{KSK}\_{\vec{S}\_{\mathsf{in}}\rightarrow \vec{S}\_{\mathsf{out}}} = \left( +/// \overline{\mathsf{CT}\_0}, \cdots , \overline{\mathsf{CT}\_{k\_{\mathsf{in}}-1}}\right) +/// \subseteq R\_q^{(k\_{\mathsf{out}}+1)\cdot k\_{\mathsf{in}}\cdot \ell}$$ +/// +/// where $\vec{S}\_{\mathsf{in}} = \left( S\_0 , \cdots , S\_{\mathsf{in}-1} \right)$ and for all +/// $0\le i +where + C::Element: UnsignedInteger, +{ + data: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + poly_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, +} + +impl> AsRef<[T]> for GlweKeyswitchKey { + fn as_ref(&self) -> &[T] { + self.data.as_ref() + } +} + +impl> AsMut<[T]> for GlweKeyswitchKey { + fn as_mut(&mut self) -> &mut [T] { + self.data.as_mut() + } +} + +/// Return the number of elements in an encryption of an input [`GlweSecretKey`] element for a +/// [`GlweKeyswitchKey`] given a [`DecompositionLevelCount`] and output [`GlweSize`] and +/// [`PolynomialSize`]. +pub fn glwe_keyswitch_key_input_key_element_encrypted_size( + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + poly_size: PolynomialSize, +) -> usize { + // One ciphertext per level encrypted under the output key + decomp_level_count.0 * output_glwe_size.0 * poly_size.0 +} + +impl> GlweKeyswitchKey { + /// Create an [`GlweKeyswitchKey`] from an existing container. + /// + /// # Note + /// + /// This function only wraps a container in the appropriate type. If you want to generate an + /// [`GlweKeyswitchKey`] you need to call + /// [`crate::core_crypto::algorithms::generate_glwe_keyswitch_key`] using this key as output. + /// + /// This docstring exhibits [`GlweKeyswitchKey`] primitives usage. + /// + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// + /// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + /// // computations + /// // Define parameters for LweKeyswitchKey creation + /// let input_glwe_dimension = GlweDimension(1); + /// let output_glwe_dimension = GlweDimension(2); + /// let poly_size = PolynomialSize(1024); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let decomp_level_count = DecompositionLevelCount(5); + /// let ciphertext_modulus = CiphertextModulus::new_native(); + /// + /// // Create a new LweKeyswitchKey + /// let glwe_ksk = GlweKeyswitchKey::new( + /// 0u64, + /// decomp_base_log, + /// decomp_level_count, + /// input_glwe_dimension, + /// output_glwe_dimension, + /// poly_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!(glwe_ksk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(glwe_ksk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(glwe_ksk.input_key_glwe_dimension(), input_glwe_dimension); + /// assert_eq!(glwe_ksk.output_key_glwe_dimension(), output_glwe_dimension); + /// assert_eq!(glwe_ksk.polynomial_size(), poly_size); + /// assert_eq!( + /// glwe_ksk.output_glwe_size(), + /// output_glwe_dimension.to_glwe_size() + /// ); + /// assert_eq!(glwe_ksk.ciphertext_modulus(), ciphertext_modulus); + /// + /// // Demonstrate how to recover the allocated container + /// let underlying_container: Vec = glwe_ksk.into_container(); + /// + /// // Recreate a keyswitch key using from_container + /// let glwe_ksk = GlweKeyswitchKey::from_container( + /// underlying_container, + /// decomp_base_log, + /// decomp_level_count, + /// output_glwe_dimension.to_glwe_size(), + /// poly_size, + /// ciphertext_modulus, + /// ); + /// + /// assert_eq!(glwe_ksk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(glwe_ksk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(glwe_ksk.input_key_glwe_dimension(), input_glwe_dimension); + /// assert_eq!(glwe_ksk.output_key_glwe_dimension(), output_glwe_dimension); + /// assert_eq!( + /// glwe_ksk.output_glwe_size(), + /// output_glwe_dimension.to_glwe_size() + /// ); + /// assert_eq!(glwe_ksk.ciphertext_modulus(), ciphertext_modulus); + /// ``` + pub fn from_container( + container: C, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + poly_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + assert!( + container.container_len() > 0, + "Got an empty container to create a GlweKeyswitchKey" + ); + assert!( + container.container_len() % (decomp_level_count.0 * output_glwe_size.0 * poly_size.0) + == 0, + "The provided container length is not valid. \ + It needs to be dividable by decomp_level_count * output_glwe_size * output_poly_size: {}. \ + Got container length: {} and decomp_level_count: {decomp_level_count:?}, \ + output_glwe_size: {output_glwe_size:?}, poly_size: {poly_size:?}.", + decomp_level_count.0 * output_glwe_size.0 * poly_size.0, + container.container_len() + ); + + Self { + data: container, + decomp_base_log, + decomp_level_count, + output_glwe_size, + poly_size, + ciphertext_modulus, + } + } + + /// Return the [`DecompositionBaseLog`] of the [`LweKeyswitchKey`]. + /// + /// See [`LweKeyswitchKey::from_container`] for usage. + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Return the [`DecompositionLevelCount`] of the [`LweKeyswitchKey`]. + /// + /// See [`LweKeyswitchKey::from_container`] for usage. + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomp_level_count + } + + /// Return the input [`GlweDimension`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn input_key_glwe_dimension(&self) -> GlweDimension { + GlweDimension(self.data.container_len() / self.input_key_element_encrypted_size()) + } + + /// Return the input [`PolynomialSize`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Return the output [`GlweDimension`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn output_key_glwe_dimension(&self) -> GlweDimension { + self.output_glwe_size.to_glwe_dimension() + } + + /// Return the output [`GlweSize`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn output_glwe_size(&self) -> GlweSize { + self.output_glwe_size + } + + /// Return the number of elements in an encryption of an input [`GlweSecretKey`] element of the + /// current [`GlweKeyswitchKey`]. + pub fn input_key_element_encrypted_size(&self) -> usize { + glwe_keyswitch_key_input_key_element_encrypted_size( + self.decomp_level_count, + self.output_glwe_size, + self.poly_size, + ) + } + + /// Return a view of the [`GlweKeyswitchKey`]. This is useful if an algorithm takes a view by + /// value. + pub fn as_view(&self) -> GlweKeyswitchKey<&'_ [Scalar]> { + GlweKeyswitchKey::from_container( + self.as_ref(), + self.decomp_base_log, + self.decomp_level_count, + self.output_glwe_size, + self.poly_size, + self.ciphertext_modulus, + ) + } + + /// Consume the entity and return its underlying container. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn into_container(self) -> C { + self.data + } + + pub fn as_glwe_ciphertext_list(&self) -> GlweCiphertextListView<'_, Scalar> { + GlweCiphertextListView::from_container( + self.as_ref(), + self.output_glwe_size(), + self.polynomial_size(), + self.ciphertext_modulus(), + ) + } + + /// Return the [`CiphertextModulus`] of the [`GlweKeyswitchKey`]. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn ciphertext_modulus(&self) -> CiphertextModulus { + self.ciphertext_modulus + } +} + +impl> GlweKeyswitchKey { + /// Mutable variant of [`GlweKeyswitchKey::as_view`]. + pub fn as_mut_view(&mut self) -> GlweKeyswitchKey<&'_ mut [Scalar]> { + let decomp_base_log = self.decomp_base_log; + let decomp_level_count = self.decomp_level_count; + let output_glwe_size = self.output_glwe_size; + let poly_size = self.poly_size; + let ciphertext_modulus = self.ciphertext_modulus; + GlweKeyswitchKey::from_container( + self.as_mut(), + decomp_base_log, + decomp_level_count, + output_glwe_size, + poly_size, + ciphertext_modulus, + ) + } + + pub fn as_mut_glwe_ciphertext_list(&mut self) -> GlweCiphertextListMutView<'_, Scalar> { + let output_glwe_size = self.output_glwe_size(); + let poly_size = self.polynomial_size(); + let ciphertext_modulus = self.ciphertext_modulus(); + GlweCiphertextListMutView::from_container( + self.as_mut(), + output_glwe_size, + poly_size, + ciphertext_modulus, + ) + } +} + +/// A [`GlweKeyswitchKey`] owning the memory for its own storage. +pub type GlweKeyswitchKeyOwned = GlweKeyswitchKey>; +/// A [`GlweKeyswitchKey`] immutably borrowing memory for its own storage. +pub type GlweKeyswitchKeyView<'data, Scalar> = GlweKeyswitchKey<&'data [Scalar]>; +/// A [`GlweKeyswitchKey`] mutably borrowing memory for its own storage. +pub type GlweKeyswitchKeyMutView<'data, Scalar> = GlweKeyswitchKey<&'data mut [Scalar]>; + +impl GlweKeyswitchKeyOwned { + /// Allocate memory and create a new owned [`GlweKeyswitchKey`]. + /// + /// # Note + /// + /// This function allocates a vector of the appropriate size and wraps it in the appropriate + /// type. If you want to generate an [`GlweKeyswitchKey`] you need to call + /// [`crate::core_crypto::algorithms::generate_glwe_keyswitch_key`] using this key as output. + /// + /// See [`GlweKeyswitchKey::from_container`] for usage. + pub fn new( + fill_with: Scalar, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + input_key_glwe_dimension: GlweDimension, + output_key_glwe_dimension: GlweDimension, + poly_size: PolynomialSize, + ciphertext_modulus: CiphertextModulus, + ) -> Self { + Self::from_container( + vec![ + fill_with; + input_key_glwe_dimension.0 + * glwe_keyswitch_key_input_key_element_encrypted_size( + decomp_level_count, + output_key_glwe_dimension.to_glwe_size(), + poly_size, + ) + ], + decomp_base_log, + decomp_level_count, + output_key_glwe_dimension.to_glwe_size(), + poly_size, + ciphertext_modulus, + ) + } +} + +#[derive(Clone, Copy)] +pub struct GlweKeyswitchKeyCreationMetadata { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub output_glwe_size: GlweSize, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> CreateFrom + for GlweKeyswitchKey +{ + type Metadata = GlweKeyswitchKeyCreationMetadata; + + #[inline] + fn create_from(from: C, meta: Self::Metadata) -> Self { + let GlweKeyswitchKeyCreationMetadata { + decomp_base_log, + decomp_level_count, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + } = meta; + Self::from_container( + from, + decomp_base_log, + decomp_level_count, + output_glwe_size, + polynomial_size, + ciphertext_modulus, + ) + } +} + +impl> ContiguousEntityContainer + for GlweKeyswitchKey +{ + type Element = C::Element; + + type EntityViewMetadata = GlweCiphertextListCreationMetadata; + + type EntityView<'this> + = GlweCiphertextListView<'this, Self::Element> + where + Self: 'this; + + type SelfViewMetadata = GlweKeyswitchKeyCreationMetadata; + + type SelfView<'this> + = GlweKeyswitchKeyView<'this, Self::Element> + where + Self: 'this; + + fn get_entity_view_creation_metadata(&self) -> Self::EntityViewMetadata { + GlweCiphertextListCreationMetadata { + glwe_size: self.output_glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } + + fn get_entity_view_pod_size(&self) -> usize { + self.input_key_element_encrypted_size() + } + + fn get_self_view_creation_metadata(&self) -> Self::SelfViewMetadata { + GlweKeyswitchKeyCreationMetadata { + decomp_base_log: self.decomposition_base_log(), + decomp_level_count: self.decomposition_level_count(), + output_glwe_size: self.output_glwe_size(), + polynomial_size: self.polynomial_size(), + ciphertext_modulus: self.ciphertext_modulus(), + } + } +} + +impl> ContiguousEntityContainerMut + for GlweKeyswitchKey +{ + type EntityMutView<'this> + = GlweCiphertextListMutView<'this, Self::Element> + where + Self: 'this; + + type SelfMutView<'this> + = GlweKeyswitchKeyMutView<'this, Self::Element> + where + Self: 'this; +} + +pub struct GlweKeyswitchKeyConformanceParams { + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub output_glwe_size: GlweSize, + pub input_glwe_dimension: GlweDimension, + pub polynomial_size: PolynomialSize, + pub ciphertext_modulus: CiphertextModulus, +} + +impl> ParameterSetConformant for GlweKeyswitchKey { + type ParameterSet = GlweKeyswitchKeyConformanceParams; + + fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { + let Self { + data, + decomp_base_log, + decomp_level_count, + output_glwe_size, + poly_size, + ciphertext_modulus, + } = self; + + *ciphertext_modulus == parameter_set.ciphertext_modulus + && data.container_len() + == parameter_set.input_glwe_dimension.0 + * glwe_keyswitch_key_input_key_element_encrypted_size( + parameter_set.decomp_level_count, + parameter_set.output_glwe_size, + parameter_set.polynomial_size, + ) + && *decomp_base_log == parameter_set.decomp_base_log + && *decomp_level_count == parameter_set.decomp_level_count + && *output_glwe_size == parameter_set.output_glwe_size + && *poly_size == parameter_set.polynomial_size + } +} diff --git a/tfhe/src/core_crypto/entities/mod.rs b/tfhe/src/core_crypto/entities/mod.rs index f951c7b1c9..a0a5ec329a 100644 --- a/tfhe/src/core_crypto/entities/mod.rs +++ b/tfhe/src/core_crypto/entities/mod.rs @@ -11,6 +11,7 @@ pub mod ggsw_ciphertext; pub mod ggsw_ciphertext_list; pub mod glwe_ciphertext; pub mod glwe_ciphertext_list; +pub mod glwe_keyswitch_key; pub mod glwe_secret_key; pub mod gsw_ciphertext; pub mod lwe_bootstrap_key; @@ -68,6 +69,7 @@ pub use ggsw_ciphertext::*; pub use ggsw_ciphertext_list::*; pub use glwe_ciphertext::*; pub use glwe_ciphertext_list::*; +pub use glwe_keyswitch_key::*; pub use glwe_secret_key::*; pub use gsw_ciphertext::*; pub use lwe_bootstrap_key::*;