Skip to content

Commit

Permalink
refactor(tfhe)!: update key level order for better performance
Browse files Browse the repository at this point in the history
- use natural order for decomposition levels in bsk
- contains a copy of tfhe-versionable 0.3.2
- updates zk-pok to 0.3.2
- updates the usage to avoid warnings in tfhe-zk-pok

co-authored-by: Agnes Leroy <agnes.leroy@zama.ai>
  • Loading branch information
IceTDrinker and agnesLeroy committed Nov 6, 2024
1 parent f1a354d commit 35fdcdf
Show file tree
Hide file tree
Showing 77 changed files with 2,334 additions and 1,396 deletions.
6 changes: 2 additions & 4 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
level_count);
Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);

for (int j = level_count - 1; j >= 0; j--) {
// Levels are stored in reverse order
for (int j = 0; j < level_count; j++) {
auto ksk_block =
get_ith_block(ksk, i, j, lwe_dimension_out, level_count);
Torus decomposed = decompose_one<Torus>(state, mask_mod_b, base_log);
Expand Down Expand Up @@ -208,8 +207,7 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(

// block of key for current lwe coefficient (cur_input_lwe[i])
auto ksk_block = &fp_ksk[i * ksk_block_size];
for (int j = level_count - 1; j >= 0; j--) {
// Levels are stored in reverse order
for (int j = 0; j < level_count; j++) {
auto ksk_glwe = &ksk_block[j * glwe_size * polynomial_size];
// Iterate through each level and multiply by the ksk piece
auto ksk_glwe_chunk = &ksk_glwe[poly_id * coef_per_block];
Expand Down
12 changes: 6 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/pbs/bootstrapping_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ __device__ const T *get_ith_mask_kth_block(const T *ptr, int i, int k,
uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
(level_count - level - 1) * polynomial_size / 2 *
(glwe_dimension + 1) * (glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1)];
}

Expand All @@ -35,8 +35,8 @@ __device__ T *get_ith_mask_kth_block(T *ptr, int i, int k, int level,
int glwe_dimension, uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
(level_count - level - 1) * polynomial_size / 2 *
(glwe_dimension + 1) * (glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1)];
}
template <typename T>
Expand All @@ -45,8 +45,8 @@ __device__ T *get_ith_body_kth_block(T *ptr, int i, int k, int level,
int glwe_dimension, uint32_t level_count) {
return &ptr[get_start_ith_ggsw(i, polynomial_size, glwe_dimension,
level_count) +
level * polynomial_size / 2 * (glwe_dimension + 1) *
(glwe_dimension + 1) +
(level_count - level - 1) * polynomial_size / 2 *
(glwe_dimension + 1) * (glwe_dimension + 1) +
k * polynomial_size / 2 * (glwe_dimension + 1) +
glwe_dimension * polynomial_size / 2];
}
Expand Down
4 changes: 2 additions & 2 deletions tfhe-zk-pok/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tfhe-zk-pok"
version = "0.3.1"
version = "0.3.2"
edition = "2021"
keywords = ["zero", "knowledge", "proof", "vector-commitments"]
homepage = "https://zama.ai/"
Expand All @@ -24,7 +24,7 @@ sha3 = "0.10.8"
serde = { version = "~1.0", features = ["derive"] }
zeroize = "1.7.0"
num-bigint = "0.4.5"
tfhe-versionable = { version = "0.3.0", path = "../utils/tfhe-versionable" }
tfhe-versionable = { version = "0.3.2", path = "../utils/tfhe-versionable" }

[dev-dependencies]
serde_json = "~1.0"
Expand Down
40 changes: 6 additions & 34 deletions tfhe-zk-pok/src/curve_api/bls12_381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,13 @@ fn bigint_to_le_bytes(x: [u64; 6]) -> [u8; 6 * 8] {
mod g1 {
use tfhe_versionable::Versionize;

use crate::backward_compatibility::SerializableG1AffineVersions;
use crate::serialization::{InvalidSerializedAffineError, SerializableG1Affine};

use super::*;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[versionize(
SerializableG1AffineVersions,
try_from = "SerializableG1Affine",
into = "SerializableG1Affine"
)]
#[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[repr(transparent)]
pub struct G1Affine {
pub(crate) inner: ark_bls12_381::g1::G1Affine,
Expand Down Expand Up @@ -99,11 +94,7 @@ mod g1 {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[versionize(
SerializableG1AffineVersions,
try_from = "SerializableG1Affine",
into = "SerializableG1Affine"
)]
#[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[repr(transparent)]
pub struct G1 {
pub(crate) inner: ark_bls12_381::G1Projective,
Expand Down Expand Up @@ -264,18 +255,13 @@ mod g1 {
mod g2 {
use tfhe_versionable::Versionize;

use crate::backward_compatibility::SerializableG2AffineVersions;
use crate::serialization::{InvalidSerializedAffineError, SerializableG2Affine};

use super::*;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[versionize(
SerializableG2AffineVersions,
try_from = "SerializableG2Affine",
into = "SerializableG2Affine"
)]
#[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[repr(transparent)]
pub struct G2Affine {
pub(crate) inner: ark_bls12_381::g2::G2Affine,
Expand Down Expand Up @@ -328,11 +314,7 @@ mod g2 {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[versionize(
SerializableG2AffineVersions,
try_from = "SerializableG2Affine",
into = "SerializableG2Affine"
)]
#[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[repr(transparent)]
pub struct G2 {
pub(crate) inner: ark_bls12_381::G2Projective,
Expand Down Expand Up @@ -539,7 +521,6 @@ mod g2 {
}

mod gt {
use crate::backward_compatibility::SerializableFp12Versions;
use crate::serialization::InvalidArraySizeError;

use super::*;
Expand All @@ -548,11 +529,7 @@ mod gt {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash)]
#[serde(try_from = "SerializableFp12", into = "SerializableFp12")]
#[versionize(
SerializableFp12Versions,
try_from = "SerializableFp12",
into = "SerializableFp12"
)]
#[versionize(try_from = "SerializableFp12", into = "SerializableFp12")]
#[repr(transparent)]
pub struct Gt {
inner: ark_ec::pairing::PairingOutput<ark_bls12_381::Bls12_381>,
Expand Down Expand Up @@ -697,7 +674,6 @@ mod gt {
}

mod zp {
use crate::backward_compatibility::SerializableFpVersions;
use crate::serialization::InvalidArraySizeError;

use super::*;
Expand Down Expand Up @@ -741,11 +717,7 @@ mod zp {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)]
#[serde(try_from = "SerializableFp", into = "SerializableFp")]
#[versionize(
SerializableFpVersions,
try_from = "SerializableFp",
into = "SerializableFp"
)]
#[versionize(try_from = "SerializableFp", into = "SerializableFp")]
#[repr(transparent)]
pub struct Zp {
pub(crate) inner: ark_bls12_381::Fr,
Expand Down
41 changes: 6 additions & 35 deletions tfhe-zk-pok/src/curve_api/bls12_446.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,13 @@ fn bigint_to_le_bytes(x: [u64; 7]) -> [u8; 7 * 8] {
mod g1 {
use tfhe_versionable::Versionize;

use crate::backward_compatibility::SerializableG1AffineVersions;
use crate::serialization::{InvalidSerializedAffineError, SerializableG1Affine};

use super::*;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[versionize(
SerializableG1AffineVersions,
try_from = "SerializableG1Affine",
into = "SerializableG1Affine"
)]
#[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[repr(transparent)]
pub struct G1Affine {
pub(crate) inner: crate::curve_446::g1::G1Affine,
Expand Down Expand Up @@ -101,11 +96,7 @@ mod g1 {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[versionize(
SerializableG1AffineVersions,
try_from = "SerializableG1Affine",
into = "SerializableG1Affine"
)]
#[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")]
#[repr(transparent)]
pub struct G1 {
pub(crate) inner: crate::curve_446::g1::G1Projective,
Expand Down Expand Up @@ -267,19 +258,14 @@ mod g1 {
mod g2 {
use tfhe_versionable::Versionize;

use crate::backward_compatibility::SerializableG2AffineVersions;
use crate::serialization::SerializableG2Affine;

use super::*;
use crate::serialization::InvalidSerializedAffineError;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[versionize(
SerializableG2AffineVersions,
try_from = "SerializableG2Affine",
into = "SerializableG2Affine"
)]
#[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[repr(transparent)]
pub struct G2Affine {
pub(crate) inner: crate::curve_446::g2::G2Affine,
Expand Down Expand Up @@ -423,11 +409,7 @@ mod g2 {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
#[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[versionize(
SerializableG2AffineVersions,
try_from = "SerializableG2Affine",
into = "SerializableG2Affine"
)]
#[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")]
#[repr(transparent)]
pub struct G2 {
pub(crate) inner: crate::curve_446::g2::G2Projective,
Expand Down Expand Up @@ -633,7 +615,6 @@ mod g2 {
}

mod gt {
use crate::backward_compatibility::SerializableFp12Versions;
use crate::curve_446::{Fq, Fq12, Fq2};
use crate::serialization::InvalidSerializedAffineError;

Expand Down Expand Up @@ -812,11 +793,7 @@ mod gt {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash)]
#[serde(try_from = "SerializableFp12", into = "SerializableFp12")]
#[versionize(
SerializableFp12Versions,
try_from = "SerializableFp12",
into = "SerializableFp12"
)]
#[versionize(try_from = "SerializableFp12", into = "SerializableFp12")]
#[repr(transparent)]
pub struct Gt {
pub(crate) inner: ark_ec::pairing::PairingOutput<crate::curve_446::Bls12_446>,
Expand Down Expand Up @@ -959,8 +936,6 @@ mod gt {
}

mod zp {
use crate::backward_compatibility::SerializableFpVersions;

use super::*;
use crate::serialization::InvalidArraySizeError;
use ark_ff::Fp;
Expand Down Expand Up @@ -1003,11 +978,7 @@ mod zp {

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)]
#[serde(try_from = "SerializableFp", into = "SerializableFp")]
#[versionize(
SerializableFpVersions,
try_from = "SerializableFp",
into = "SerializableFp"
)]
#[versionize(try_from = "SerializableFp", into = "SerializableFp")]
#[repr(transparent)]
pub struct Zp {
pub inner: crate::curve_446::Fr,
Expand Down
4 changes: 2 additions & 2 deletions tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ sha3 = { version = "0.10", optional = true }
# While we wait for repeat_n in rust standard library
itertools = "0.11.0"
rand_core = { version = "0.6.4", features = ["std"] }
tfhe-zk-pok = { version = "0.3.1", path = "../tfhe-zk-pok", optional = true }
tfhe-versionable = { version = "0.3.0", path = "../utils/tfhe-versionable" }
tfhe-zk-pok = { version = "0.3.2", path = "../tfhe-zk-pok", optional = true }
tfhe-versionable = { version = "0.3.2", path = "../utils/tfhe-versionable" }

# wasm deps
wasm-bindgen = { version = ">=0.2.86,<0.2.94", features = [
Expand Down
2 changes: 1 addition & 1 deletion tfhe/docs/guides/data_versioning.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ You can load serialized data with the `unversionize` function, even in newer ver
[dependencies]
# ...
tfhe = { version = "0.10.0", features = ["integer", "x86_64-unix"] }
tfhe-versionable = "0.2.0"
tfhe-versionable = "0.3.2"
bincode = "1.3.3"
```

Expand Down
22 changes: 13 additions & 9 deletions tfhe/src/core_crypto/algorithms/ggsw_encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, NoiseDistribution, KeyCont, Outp
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

for (level_index, (mut level_matrix, mut generator)) in
for (output_index, (mut level_matrix, mut generator)) in
output.iter_mut().zip(gen_iter).enumerate()
{
let decomp_level = DecompositionLevel(level_index + 1);
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -269,11 +270,12 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, NoiseDistribution, KeyCont,
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
|(output_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -401,12 +403,13 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

for (level_index, (mut level_matrix, mut loop_generator)) in
for (output_index, (mut level_matrix, mut loop_generator)) in
output.iter_mut().zip(gen_iter).enumerate()
{
let decomp_level = DecompositionLevel(level_index + 1);
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -581,11 +584,12 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
.expect("Failed to split generator into ggsw levels");

let decomp_base_log = output.decomposition_base_log();
let decomp_level_count = output.decomposition_level_count();
let ciphertext_modulus = output.ciphertext_modulus();

output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
|(output_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(decomp_level_count.0 - output_index);
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
Expand Down Expand Up @@ -881,7 +885,7 @@ where
glwe_secret_key.glwe_dimension()
);

let level_matrix = ggsw_ciphertext.last().unwrap();
let level_matrix = ggsw_ciphertext.first().unwrap();
let level_matrix_as_glwe_list = level_matrix.as_glwe_list();
let last_row = level_matrix_as_glwe_list.last().unwrap();
let decomp_level = ggsw_ciphertext.decomposition_level_count();
Expand Down
Loading

0 comments on commit 35fdcdf

Please sign in to comment.