Skip to content

Commit

Permalink
chore(data)!: breaking data changes for future compatibility
Browse files Browse the repository at this point in the history
- invert the LweKeyswitchKey level order and propagate change
- remove dependency on unsupported wopbs keys for the HL keys
  • Loading branch information
IceTDrinker committed Oct 16, 2024
1 parent c429db5 commit c017006
Show file tree
Hide file tree
Showing 25 changed files with 360 additions and 145 deletions.
6 changes: 4 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
level_count);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);

for (int j = 0; j < level_count; j++) {
for (int j = level_count - 1; j >= 0; j--) {
// Levels are stored in reverse order
auto ksk_block =
get_ith_block(ksk, i, j, lwe_dimension_out, level_count);
Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
Expand Down Expand Up @@ -208,7 +209,8 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(

// block of key for current lwe coefficient (cur_input_lwe[i])
auto ksk_block = &fp_ksk[i * ksk_block_size];
for (int j = 0; j < level_count; j++) {
for (int j = level_count - 1; j >= 0; j--) {
// Levels are stored in reverse order
auto ksk_glwe = &ksk_block[j * glwe_size * polynomial_size];
// Iterate through each level and multiply by the ksk piece
auto ksk_glwe_chunk = &ksk_glwe[poly_id * coef_per_block];
Expand Down
13 changes: 8 additions & 5 deletions tfhe/src/core_crypto/algorithms/lwe_keyswitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ pub fn keyswitch_lwe_ciphertext_native_mod_compatible<Scalar, KSKCont, InputCont
{
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
output_lwe_ciphertext.as_mut(),
Expand Down Expand Up @@ -304,7 +305,8 @@ pub fn keyswitch_lwe_ciphertext_other_mod<Scalar, KSKCont, InputCont, OutputCont
{
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign_custom_modulus(
output_lwe_ciphertext.as_mut(),
Expand Down Expand Up @@ -436,7 +438,8 @@ pub fn keyswitch_lwe_ciphertext_with_scalar_change<
{
let decomposition_iter = input_decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
output_lwe_ciphertext.as_mut(),
Expand Down Expand Up @@ -799,7 +802,7 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible<
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().zip(decomposition_iter)
keyswitch_key_block.iter().rev().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
buffer.as_mut(),
Expand Down Expand Up @@ -946,7 +949,7 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count_other_mod<
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().zip(decomposition_iter)
keyswitch_key_block.iter().rev().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign_custom_modulus(
buffer.as_mut(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
{
// We fill the buffer with the powers of the key elements
for (level, message) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.iter_mut())
{
Expand Down Expand Up @@ -234,7 +233,6 @@ pub fn generate_lwe_keyswitch_key_other_mod<
{
// We fill the buffer with the powers of the key elements
for (level, message) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.iter_mut())
{
Expand Down Expand Up @@ -416,7 +414,6 @@ pub fn generate_seeded_lwe_keyswitch_key<
{
// We fill the buffer with the powers of the key elmements
for (level, message) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.iter_mut())
{
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/algorithms/lwe_packing_keyswitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub fn keyswitch_lwe_ciphertext_into_glwe_ciphertext<Scalar, KeyCont, InputCont,
// Loop over the number of levels:
// We compute the multiplication of a ciphertext from the private functional
// keyswitching key with a piece of the decomposition and subtract it to the buffer
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().zip(decomp) {
for (level_key_cipher, decomposed) in keyswitch_key_block.iter().rev().zip(decomp) {
slice_wrapping_sub_scalar_mul_assign(
output_glwe_ciphertext.as_mut(),
level_key_cipher.as_ref(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ pub fn generate_lwe_packing_keyswitch_key<
{
// We fill the buffer with the powers of the key elements
for (level, mut messages) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.chunks_exact_mut(polynomial_size.0))
{
Expand Down Expand Up @@ -330,7 +329,6 @@ pub fn generate_seeded_lwe_packing_keyswitch_key<
{
// We fill the buffer with the powers of the key elements
for (level, mut messages) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.chunks_exact_mut(polynomial_size.0))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
use tfhe_versionable::VersionsDispatch;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::core_crypto::prelude::{Container, LweKeyswitchKey, UnsignedInteger};
use crate::core_crypto::prelude::{
CiphertextModulus, Container, ContainerMut, ContiguousEntityContainerMut, DecompositionBaseLog,
DecompositionLevelCount, LweKeyswitchKey, LweSize, UnsignedInteger,
};

#[derive(Version)]
pub struct LweKeyswitchKeyV0<C: Container>
where
C::Element: UnsignedInteger,
{
data: C,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
output_lwe_size: LweSize,
ciphertext_modulus: CiphertextModulus<C::Element>,
}

impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> Upgrade<LweKeyswitchKey<C>>
for LweKeyswitchKeyV0<C>
{
type Error = std::convert::Infallible;

fn upgrade(self) -> Result<LweKeyswitchKey<C>, Self::Error> {
let Self {
data,
decomp_base_log,
decomp_level_count,
output_lwe_size,
ciphertext_modulus,
} = self;
let mut new_ksk = LweKeyswitchKey::from_container(
data,
decomp_base_log,
decomp_level_count,
output_lwe_size,
ciphertext_modulus,
);

// Invert levels
for mut ksk_block in new_ksk.iter_mut() {
ksk_block.reverse();
}

Ok(new_ksk)
}
}

#[derive(VersionsDispatch)]
pub enum LweKeyswitchKeyVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(LweKeyswitchKey<C>),
V0(LweKeyswitchKeyV0<C>),
V1(LweKeyswitchKey<C>),
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,60 @@
use tfhe_versionable::VersionsDispatch;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::core_crypto::prelude::{Container, LwePackingKeyswitchKey, UnsignedInteger};
use crate::core_crypto::prelude::{
CiphertextModulus, Container, ContainerMut, ContiguousEntityContainerMut, DecompositionBaseLog,
DecompositionLevelCount, GlweSize, LwePackingKeyswitchKey, PolynomialSize, UnsignedInteger,
};

#[derive(Version)]
pub struct LwePackingKeyswitchKeyV0<C: Container>
where
C::Element: UnsignedInteger,
{
data: C,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
output_glwe_size: GlweSize,
output_polynomial_size: PolynomialSize,
ciphertext_modulus: CiphertextModulus<C::Element>,
}

impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> Upgrade<LwePackingKeyswitchKey<C>>
for LwePackingKeyswitchKeyV0<C>
{
type Error = std::convert::Infallible;

fn upgrade(self) -> Result<LwePackingKeyswitchKey<C>, Self::Error> {
let Self {
data,
decomp_base_log,
decomp_level_count,
output_glwe_size,
output_polynomial_size,
ciphertext_modulus,
} = self;
let mut new_pksk = LwePackingKeyswitchKey::from_container(
data,
decomp_base_log,
decomp_level_count,
output_glwe_size,
output_polynomial_size,
ciphertext_modulus,
);

// Invert levels
for mut pksk_block in new_pksk.iter_mut() {
pksk_block.reverse();
}

Ok(new_pksk)
}
}

#[derive(VersionsDispatch)]
pub enum LwePackingKeyswitchKeyVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(LwePackingKeyswitchKey<C>),
V0(LwePackingKeyswitchKeyV0<C>),
V1(LwePackingKeyswitchKey<C>),
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
use tfhe_versionable::VersionsDispatch;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::core_crypto::prelude::{Container, SeededLweKeyswitchKey, UnsignedInteger};

#[derive(Version)]
pub struct UnsupportedSeededLweKeyswitchKeyV0;

impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> Upgrade<SeededLweKeyswitchKey<C>>
for UnsupportedSeededLweKeyswitchKeyV0
{
type Error = crate::Error;

fn upgrade(self) -> Result<SeededLweKeyswitchKey<C>, Self::Error> {
Err(crate::Error::new(
"Unable to load SeededLweKeyswitchKey, \
this format is UnsupportedSeededLweKeyswitchKeyV0 by this TFHE-rs version."
.to_string(),
))
}
}

#[derive(VersionsDispatch)]
pub enum SeededLweKeyswitchKeyVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(SeededLweKeyswitchKey<C>),
V0(UnsupportedSeededLweKeyswitchKeyV0),
V1(SeededLweKeyswitchKey<C>),
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
use tfhe_versionable::VersionsDispatch;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::core_crypto::prelude::{Container, SeededLwePackingKeyswitchKey, UnsignedInteger};

#[derive(Version)]
pub struct UnsupportedSeededLwePackingKeyswitchKeyV0;

impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>>
Upgrade<SeededLwePackingKeyswitchKey<C>> for UnsupportedSeededLwePackingKeyswitchKeyV0
{
type Error = crate::Error;

fn upgrade(self) -> Result<SeededLwePackingKeyswitchKey<C>, Self::Error> {
Err(crate::Error::new(
"Unable to load SeededLwePackingKeyswitchKey, \
this format is unsupported by this TFHE-rs version."
.to_string(),
))
}
}

#[derive(VersionsDispatch)]
pub enum SeededLwePackingKeyswitchKeyVersions<C: Container>
where
C::Element: UnsignedInteger,
{
V0(SeededLwePackingKeyswitchKey<C>),
V0(UnsupportedSeededLwePackingKeyswitchKeyV0),
V1(SeededLwePackingKeyswitchKey<C>),
}
11 changes: 11 additions & 0 deletions tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,17 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self:
.map(|(elt, meta)| Self::SelfMutView::<'_>::create_from(elt, meta))
}

fn reverse(&mut self) {
let entity_view_pod_size = self.get_entity_view_pod_size();
let container = self.as_mut();

container.reverse();

for entity_slot in self.as_mut().chunks_exact_mut(entity_view_pod_size) {
entity_slot.reverse();
}
}

fn par_iter_mut<'this>(
&'this mut self,
) -> ParallelChunksExactWrappingLendingIteratorMut<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ pub fn shrinking_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont
{
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().rev().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
output_lwe_ciphertext.as_mut(),
Expand Down
Loading

0 comments on commit c017006

Please sign in to comment.