Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(bench): measure key sizes and zk proof sizes #1379

Merged
merged 3 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tfhe/benches/integer/zk_pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) {
public_params
.serialize_with_mode(&mut crs_data, Compress::No)
.unwrap();

println!("CRS size: {}", crs_data.len());

let test_name = format!("zk::crs_sizes::{param_name}_{bits}_bits_packed");

write_result(&mut file, &test_name, crs_data.len());
Expand Down Expand Up @@ -178,6 +181,24 @@ fn pke_zk_verify(c: &mut Criterion, results_file: &Path) {
.build_with_proof_packed(public_params, compute_load)
.unwrap();

let proof_serialized = bincode::serialize(&ct1).unwrap();

println!("proof size: {}", proof_serialized.len());

let test_name =
format!("zk::proof_sizes::{param_name}_{bits}_bits_packed_{zk_load}");

write_result(&mut file, &test_name, proof_serialized.len());
write_to_json::<u64, _>(
&test_name,
shortint_params,
param_name,
"pke_zk_proof",
&OperatorType::Atomic,
0,
vec![],
);

bench_group.bench_function(&bench_id_verify, |b| {
b.iter(|| {
let _ret = ct1.verify(public_params, &pk);
Expand Down
23 changes: 21 additions & 2 deletions tfhe/benches/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ use tfhe::core_crypto::prelude::*;
#[cfg(feature = "boolean")]
pub mod boolean_utils {
use super::*;
#[cfg(feature = "boolean")]
use tfhe::boolean::parameters::BooleanParameters;

#[cfg(feature = "boolean")]
impl From<BooleanParameters> for CryptoParametersRecord<u32> {
fn from(params: BooleanParameters) -> Self {
CryptoParametersRecord {
Expand Down Expand Up @@ -38,6 +36,8 @@ pub mod shortint_utils {
use super::*;
use itertools::iproduct;
use std::vec::IntoIter;
use tfhe::shortint::parameters::compact_public_key_only::CompactPublicKeyEncryptionParameters;
use tfhe::shortint::parameters::list_compression::CompressionParameters;
#[cfg(feature = "gpu")]
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
#[cfg(not(feature = "gpu"))]
Expand Down Expand Up @@ -128,6 +128,25 @@ pub mod shortint_utils {
}
}
}

impl From<CompactPublicKeyEncryptionParameters> for CryptoParametersRecord<u64> {
fn from(params: CompactPublicKeyEncryptionParameters) -> Self {
CryptoParametersRecord {
message_modulus: Some(params.message_modulus.0),
carry_modulus: Some(params.carry_modulus.0),
ciphertext_modulus: Some(params.ciphertext_modulus),
..Default::default()
}
}
}

impl From<CompressionParameters> for CryptoParametersRecord<u64> {
fn from(_params: CompressionParameters) -> Self {
CryptoParametersRecord {
..Default::default()
}
}
}
}

#[allow(unused_imports)]
Expand Down
160 changes: 156 additions & 4 deletions tfhe/examples/utilities/shortint_key_sizes.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
#[path = "../../benches/utilities.rs"]
mod utilities;

use crate::utilities::{write_to_json, OperatorType};
use crate::utilities::{write_to_json, CryptoParametersRecord, OperatorType};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::parameters::compact_public_key_only::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use tfhe::shortint::parameters::key_switching::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use tfhe::shortint::parameters::{
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_4_KS_PBS, PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS,
};
use tfhe::shortint::{CompressedServerKey, PBSParameters};
use tfhe::shortint::{
ClassicPBSParameters, ClientKey, CompactPrivateKey, CompressedCompactPublicKey,
CompressedKeySwitchingKey, CompressedServerKey, PBSParameters,
};

fn write_result(file: &mut File, name: &str, value: usize) {
let line = format!("{name},{value}\n");
Expand Down Expand Up @@ -128,6 +135,150 @@ fn client_server_key_sizes(results_file: &Path) {
}
}

fn measure_serialized_size<T: serde::Serialize, P: Into<CryptoParametersRecord<u64>> + Clone>(
to_serialize: &T,
param: P,
param_name: &str,
test_name_suffix: &str,
display_name: &str,
file: &mut File,
) {
let serialized = bincode::serialize(to_serialize).unwrap();
let size = serialized.len();
let test_name = format!("shortint_key_sizes_{}_{}", param_name, test_name_suffix);
write_result(file, &test_name, size);
write_to_json::<u64, _>(
&test_name,
param.clone(),
param_name,
display_name,
&OperatorType::Atomic,
0,
vec![],
);

println!(
"{} {} -> size: {} bytes",
test_name_suffix, param_name, size,
);
}

fn tuniform_key_set_sizes(results_file: &Path) {
File::create(results_file).expect("create results file failed");
let mut file = OpenOptions::new()
.append(true)
.open(results_file)
.expect("cannot open results file");

println!("Measuring shortint key sizes:");

let param_fhe = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
let param_fhe_name = param_fhe.name();
let cks = ClientKey::new(param_fhe);
let compressed_sks = CompressedServerKey::new(&cks);
let sks = compressed_sks.decompress();

measure_serialized_size(
&sks.key_switching_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"ksk",
"KSK",
&mut file,
);
measure_serialized_size(
&compressed_sks.key_switching_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"ksk_compressed",
"KSK",
&mut file,
);

measure_serialized_size(
&sks.bootstrapping_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"bsk",
"BSK",
&mut file,
);
measure_serialized_size(
&compressed_sks.bootstrapping_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"bsk_compressed",
"BSK",
&mut file,
);

let param_pke = PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
let param_pke_name = stringify!(PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
let compact_private_key = CompactPrivateKey::new(param_pke);
let compressed_pk = CompressedCompactPublicKey::new(&compact_private_key);
let pk = compressed_pk.decompress();

measure_serialized_size(&pk, param_pke, param_pke_name, "cpk", "CPK", &mut file);
measure_serialized_size(
&compressed_pk,
param_pke,
param_pke_name,
"cpk_compressed",
"CPK",
&mut file,
);

let param_compression = COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
let param_compression_name = stringify!(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let private_compression_key = cks.new_compression_private_key(param_compression);
let (compression_key, decompression_key) =
cks.new_compression_decompression_keys(&private_compression_key);

measure_serialized_size(
&compression_key,
param_compression,
param_compression_name,
"compression_key",
"CompressionKey",
&mut file,
);
measure_serialized_size(
&decompression_key,
param_compression,
param_compression_name,
"decompression_key",
"CompressionKey",
&mut file,
);

let param_casting = PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
let param_casting_name = stringify!(PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
let compressed_casting_key = CompressedKeySwitchingKey::new(
(&compact_private_key, None),
(&cks, &compressed_sks),
param_casting,
);
let casting_key = compressed_casting_key.decompress();

measure_serialized_size(
&casting_key.into_raw_parts().0,
param_casting,
param_casting_name,
"casting_key",
"CastKey",
&mut file,
);
measure_serialized_size(
&compressed_casting_key.into_raw_parts().0,
param_casting,
param_casting_name,
"casting_key_compressed",
"CastKey",
&mut file,
);
}

fn main() {
let work_dir = std::env::current_dir().unwrap();
println!("work_dir: {}", std::env::current_dir().unwrap().display());
Expand All @@ -137,5 +288,6 @@ fn main() {
std::env::set_current_dir(new_work_dir).unwrap();

let results_file = Path::new("shortint_key_sizes.csv");
client_server_key_sizes(results_file)
client_server_key_sizes(results_file);
tuniform_key_set_sizes(results_file);
}
104 changes: 104 additions & 0 deletions tfhe/src/shortint/key_switching_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,4 +733,108 @@ impl CompressedKeySwitchingKey {
.map(CompressedServerKey::decompress),
}
}

/// Deconstruct a [`CompressedKeySwitchingKey`] into its constituents.
pub fn into_raw_parts(
self,
) -> (
CompressedKeySwitchingKeyMaterial,
CompressedServerKey,
Option<CompressedServerKey>,
) {
let Self {
key_switching_key_material,
dest_server_key,
src_server_key,
} = self;

(key_switching_key_material, dest_server_key, src_server_key)
}

/// Construct a [`CompressedKeySwitchingKey`] from its constituents.
///
/// # Panics
///
/// Panics if the provided raw parts are not compatible with each other, i.e.:
///
/// if the provided source [`CompressedServerKey`] ciphertext
/// [`LweDimension`](`crate::core_crypto::commons::parameters::LweDimension`) does not match the
/// input [`LweDimension`](`crate::core_crypto::commons::parameters::LweDimension`) of the
/// [`SeededLweKeyswitchKeyOwned`] in the provided [`CompressedKeySwitchingKeyMaterial`] or if
/// the provided destination [`CompressedServerKey`] ciphertext
/// [`LweDimension`](`crate::core_crypto::commons::parameters::LweDimension`) does not match
/// the output [`LweDimension`](`crate::core_crypto::commons::parameters::LweDimension`) of
/// the [`SeededLweKeyswitchKeyOwned`] in the provided [`CompressedKeySwitchingKeyMaterial`].
pub fn from_raw_parts(
key_switching_key_material: CompressedKeySwitchingKeyMaterial,
dest_server_key: CompressedServerKey,
src_server_key: Option<CompressedServerKey>,
) -> Self {
match src_server_key {
Some(ref src_server_key) => {
let src_lwe_dimension = src_server_key.ciphertext_lwe_dimension();

assert_eq!(
src_lwe_dimension,
key_switching_key_material
.key_switching_key
.input_key_lwe_dimension(),
"Mismatch between the source CompressedServerKey ciphertext LweDimension ({:?}) \
and the SeededLweKeyswitchKey input LweDimension ({:?})",
src_lwe_dimension,
key_switching_key_material
.key_switching_key
.input_key_lwe_dimension(),
);

assert_eq!(
src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
"Mismatch between the source CompressedServerKey CiphertextModulus ({:?}) \
and the destination CompressedServerKey CiphertextModulus ({:?})",
src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
);
}
None => assert!(
key_switching_key_material.cast_rshift >= 0,
"Trying to build a shortint::CompressedKeySwitchingKey with a negative cast_rshift \
without providing a source CompressedServerKey, this is not supported"
),
}

let dst_lwe_dimension = match key_switching_key_material.destination_key {
EncryptionKeyChoice::Big => dest_server_key.bootstrapping_key.output_lwe_dimension(),
EncryptionKeyChoice::Small => dest_server_key.bootstrapping_key.input_lwe_dimension(),
};

assert_eq!(
dst_lwe_dimension,
key_switching_key_material
.key_switching_key
.output_key_lwe_dimension(),
"Mismatch between the destination CompressedServerKey ciphertext LweDimension ({:?}) \
and the SeededLweKeyswitchKey output LweDimension ({:?})",
dst_lwe_dimension,
key_switching_key_material
.key_switching_key
.output_key_lwe_dimension(),
);
assert_eq!(
key_switching_key_material
.key_switching_key
.ciphertext_modulus(),
dest_server_key.ciphertext_modulus,
"Mismatch between the SeededLweKeyswitchKey CiphertextModulus ({:?}) \
and the destination CompressedServerKey CiphertextModulus ({:?})",
key_switching_key_material
.key_switching_key
.ciphertext_modulus(),
dest_server_key.ciphertext_modulus,
);

Self {
key_switching_key_material,
dest_server_key,
src_server_key,
}
}
}
7 changes: 7 additions & 0 deletions tfhe/src/shortint/server_key/compressed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,4 +395,11 @@ impl CompressedServerKey {
engine.new_compressed_server_key_with_max_degree(cks, max_degree)
})
}

pub fn ciphertext_lwe_dimension(&self) -> LweDimension {
match self.pbs_order {
PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_dimension(),
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_dimension(),
}
}
}
Loading