From e9d3e21b9317741f4893db96cc3a9fc0c13a52f7 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Fri, 27 Sep 2024 15:58:31 +0200 Subject: [PATCH] chore(all)!: use a builder pattern for safe serialization API BREAKING CHANGES: - `safe_serialize` and `safe_deserialize` are replaced by `SerializationConfig::serialize_into` and `DeserializationConfig::deserialize_from`. - C API: the `XXX_safe_serialize_versioned` is deprecated, `XXX_safe_serialize` is now versioned by default - JS API: the `safe_serialize` method now versionize the data before serialization. This is *NOT* a serialization breaking change for data serialized in previous versions with `safe_serialize_versioned`. --- tfhe/c_api_tests/test_high_level_integers.c | 51 +- tfhe/docs/fundamentals/serialization.md | 36 +- tfhe/src/c_api/high_level_api/booleans.rs | 8 +- tfhe/src/c_api/high_level_api/integers.rs | 14 +- tfhe/src/c_api/high_level_api/keys.rs | 13 + tfhe/src/c_api/high_level_api/utils.rs | 130 ++-- tfhe/src/c_api/high_level_api/zk.rs | 64 +- tfhe/src/high_level_api/booleans/base.rs | 1 + tfhe/src/high_level_api/booleans/tests.rs | 25 +- .../high_level_api/integers/signed/base.rs | 1 + .../high_level_api/integers/signed/tests.rs | 21 +- .../high_level_api/integers/unsigned/base.rs | 1 + .../integers/unsigned/tests/cpu.rs | 34 +- tfhe/src/high_level_api/mod.rs | 32 +- tfhe/src/integer/parameters/mod.rs | 2 + .../js_high_level_api/integers.rs | 42 +- .../js_on_wasm_api/js_high_level_api/keys.rs | 156 +++++ tfhe/src/lib.rs | 2 +- tfhe/src/safe_deserialization.rs | 513 -------------- tfhe/src/safe_serialization.rs | 662 ++++++++++++++++++ 20 files changed, 1041 insertions(+), 767 deletions(-) delete mode 100644 tfhe/src/safe_deserialization.rs create mode 100644 tfhe/src/safe_serialization.rs diff --git a/tfhe/c_api_tests/test_high_level_integers.c b/tfhe/c_api_tests/test_high_level_integers.c index 80183b2a85..17c176548f 100644 --- a/tfhe/c_api_tests/test_high_level_integers.c +++ b/tfhe/c_api_tests/test_high_level_integers.c @@ -399,54 +399,7 @@ int uint8_safe_serialization(const ClientKey *client_key, const ServerKey *serve deser_view.pointer = value_buffer.pointer; deser_view.length = value_buffer.length; ok = fhe_uint8_safe_deserialize_conformant(deser_view, max_serialization_size, server_key, - &deserialized_lhs); - assert(ok == 0); - - uint8_t clear; - ok = fhe_uint8_decrypt(deserialized_lhs, deserialized_client_key, &clear); - assert(ok == 0); - - assert(clear == lhs_clear); - - if (value_buffer.pointer != NULL) { - destroy_dynamic_buffer(&value_buffer); - } - fhe_uint8_destroy(lhs); - fhe_uint8_destroy(deserialized_lhs); - return ok; -} - -int uint8_safe_serialization_versioned(const ClientKey *client_key, const ServerKey *server_key) { - int ok; - FheUint8 *lhs = NULL; - FheUint8 *deserialized_lhs = NULL; - DynamicBuffer value_buffer = {.pointer = NULL, .length = 0, .destructor = NULL}; - DynamicBuffer cks_buffer = {.pointer = NULL, .length = 0, .destructor = NULL}; - DynamicBufferView deser_view = {.pointer = NULL, .length = 0}; - ClientKey *deserialized_client_key = NULL; - - const uint64_t max_serialization_size = UINT64_C(1) << UINT64_C(20); - - uint8_t lhs_clear = 123; - - ok = client_key_serialize(client_key, &cks_buffer); - assert(ok == 0); - - deser_view.pointer = cks_buffer.pointer; - deser_view.length = cks_buffer.length; - ok = client_key_deserialize(deser_view, &deserialized_client_key); - assert(ok == 0); - - ok = fhe_uint8_try_encrypt_with_client_key_u8(lhs_clear, client_key, &lhs); - assert(ok == 0); - - ok = fhe_uint8_safe_serialize_versioned(lhs, &value_buffer, max_serialization_size); - assert(ok == 0); - - deser_view.pointer = value_buffer.pointer; - deser_view.length = value_buffer.length; - ok = fhe_uint8_safe_deserialize_conformant_versioned(deser_view, max_serialization_size, - server_key, &deserialized_lhs); + &deserialized_lhs); assert(ok == 0); uint8_t clear; @@ -657,8 +610,6 @@ int main(void) { assert(ok == 0); ok = uint8_safe_serialization(client_key, server_key); assert(ok == 0); - ok = uint8_safe_serialization_versioned(client_key, server_key); - assert(ok == 0); ok = uint8_compressed(client_key); assert(ok == 0); diff --git a/tfhe/docs/fundamentals/serialization.md b/tfhe/docs/fundamentals/serialization.md index 6c9f8d0563..4e2ab0e94f 100644 --- a/tfhe/docs/fundamentals/serialization.md +++ b/tfhe/docs/fundamentals/serialization.md @@ -78,7 +78,7 @@ When dealing with sensitive types, it's important to implement safe serializatio The safe deserialization must take the output of a safe-serialization as input. During the process, the following validation occurs: * **Type match**: deserializing `type A` from a serialized `type B` raises an error indicating "On deserialization, expected type A, got type B". -* **Version compatibility**: deserializing `type A` of a newer version (for example, version 0.2) from a serialized `type A` of an older version (for example, version 0.1) raises an error indicating "On deserialization, expected serialization version 0.2, got version 0.1". +* **Version compatibility**: data serialized in previous versions of **TFHE-rs** are automatically upgraded to the latest version using the [data versioning](../guides/data\_versioning.md) feature. * **Parameter compatibility**: deserializing an object of `type A` with one set of crypto parameters from an object of `type A` with another set of crypto parameters raises an error indicating "Deserialized object of type A not conformant with given parameter set" * If both parameter sets have the same LWE dimension for ciphertexts, a ciphertext from param 1 may not fail this deserialization check with param 2. * This check can't distinguish ciphertexts/server keys from independent client keys with the same parameters. @@ -97,7 +97,7 @@ Here is an example: use tfhe::conformance::ParameterSetConformant; use tfhe::prelude::*; -use tfhe::safe_deserialization::{safe_deserialize_conformant, safe_serialize}; +use tfhe::safe_serialization::{SerializationConfig, DeserializationConfig}; use tfhe::shortint::parameters::{PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_2_CARRY_2_PBS_KS}; use tfhe::conformance::ListSizeConstraint; use tfhe::{ @@ -127,19 +127,15 @@ fn main() { let mut buffer = vec![]; - safe_serialize(&ct, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1 << 20).serialize_into(&ct, &mut buffer).unwrap(); - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - &conformance_params_2 - ).is_err()); - - let ct2 = safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - &conformance_params_1 - ).unwrap(); + assert!(DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), &conformance_params_2) + .is_err()); + + let ct2 = DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), &conformance_params_1) + .unwrap(); let dec: u8 = ct2.decrypt(&client_key); assert_eq!(msg, dec); @@ -152,18 +148,14 @@ fn main() { let compact_list = builder.build(); let mut buffer = vec![]; - safe_serialize(&compact_list, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1 << 20).serialize_into(&compact_list, &mut buffer).unwrap(); let conformance_params = CompactCiphertextListConformanceParams { shortint_params: params_1.to_shortint_conformance_param(), num_elements_constraint: ListSizeConstraint::exact_size(2), }; - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - &conformance_params - ).is_ok()); + DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), &conformance_params) + .unwrap(); } ``` - -You can combine this serialization/deserialization feature with the [data versioning](../guides/data\_versioning.md) feature by using the `safe_serialize_versioned` and `safe_deserialize_conformant_versioned` functions. diff --git a/tfhe/src/c_api/high_level_api/booleans.rs b/tfhe/src/c_api/high_level_api/booleans.rs index 865007baa4..2241067d6b 100644 --- a/tfhe/src/c_api/high_level_api/booleans.rs +++ b/tfhe/src/c_api/high_level_api/booleans.rs @@ -8,9 +8,7 @@ impl_destroy_on_type!(FheBool); impl_clone_on_type!(FheBool); impl_serialize_deserialize_on_type!(FheBool); impl_safe_serialize_on_type!(FheBool); -impl_safe_serialize_versioned_on_type!(FheBool); -impl_safe_deserialize_conformant_integer!(FheBool, FheBoolConformanceParams); -impl_safe_deserialize_conformant_versioned_integer!(FheBool, FheBoolConformanceParams); +impl_safe_deserialize_conformant_on_type!(FheBool, FheBoolConformanceParams); impl_binary_fn_on_type!(FheBool => bitand, bitor, bitxor); impl_binary_assign_fn_on_type!(FheBool => bitand_assign, bitor_assign, bitxor_assign); @@ -48,9 +46,7 @@ impl_destroy_on_type!(CompressedFheBool); impl_clone_on_type!(CompressedFheBool); impl_serialize_deserialize_on_type!(CompressedFheBool); impl_safe_serialize_on_type!(CompressedFheBool); -impl_safe_serialize_versioned_on_type!(CompressedFheBool); -impl_safe_deserialize_conformant_integer!(CompressedFheBool, FheBoolConformanceParams); -impl_safe_deserialize_conformant_versioned_integer!(CompressedFheBool, FheBoolConformanceParams); +impl_safe_deserialize_conformant_on_type!(CompressedFheBool, FheBoolConformanceParams); impl_try_encrypt_with_client_key_on_type!(CompressedFheBool{crate::high_level_api::CompressedFheBool}, bool); #[no_mangle] diff --git a/tfhe/src/c_api/high_level_api/integers.rs b/tfhe/src/c_api/high_level_api/integers.rs index a6d73ce144..9555f0ffaa 100644 --- a/tfhe/src/c_api/high_level_api/integers.rs +++ b/tfhe/src/c_api/high_level_api/integers.rs @@ -306,14 +306,8 @@ macro_rules! create_integer_wrapper_type { impl_safe_serialize_on_type!($name); - impl_safe_serialize_versioned_on_type!($name); - - ::paste::paste! { - impl_safe_deserialize_conformant_integer!($name, [<$name ConformanceParams>]); - } - ::paste::paste! { - impl_safe_deserialize_conformant_versioned_integer!($name, [<$name ConformanceParams>]); + impl_safe_deserialize_conformant_on_type!($name, [<$name ConformanceParams>]); } define_all_cast_into_for_integer_type!($name); @@ -332,11 +326,7 @@ macro_rules! create_integer_wrapper_type { impl_safe_serialize_on_type!([]); - impl_safe_serialize_versioned_on_type!([]); - - impl_safe_deserialize_conformant_integer!([], [<$name ConformanceParams>]); - - impl_safe_deserialize_conformant_versioned_integer!([], [<$name ConformanceParams>]); + impl_safe_deserialize_conformant_on_type!([], [<$name ConformanceParams>]); #[no_mangle] pub unsafe extern "C" fn []( diff --git a/tfhe/src/c_api/high_level_api/keys.rs b/tfhe/src/c_api/high_level_api/keys.rs index 35b5e8e62c..7405bfd48e 100644 --- a/tfhe/src/c_api/high_level_api/keys.rs +++ b/tfhe/src/c_api/high_level_api/keys.rs @@ -34,6 +34,19 @@ impl_serialize_deserialize_on_type!(CompressedCompactPublicKey); impl_serialize_deserialize_on_type!(ServerKey); impl_serialize_deserialize_on_type!(CompressedServerKey); +impl_safe_serialize_on_type!(ClientKey); +impl_safe_serialize_on_type!(PublicKey); +impl_safe_serialize_on_type!(CompactPublicKey); +impl_safe_serialize_on_type!(CompressedCompactPublicKey); +impl_safe_serialize_on_type!(ServerKey); +impl_safe_serialize_on_type!(CompressedServerKey); + +impl_safe_deserialize_on_type!(ClientKey); +impl_safe_deserialize_on_type!(PublicKey); +impl_safe_deserialize_on_type!(CompactPublicKey); +impl_safe_deserialize_on_type!(CompressedCompactPublicKey); +impl_safe_deserialize_on_type!(CompressedServerKey); + #[no_mangle] pub unsafe extern "C" fn generate_keys( config: *mut super::config::Config, diff --git a/tfhe/src/c_api/high_level_api/utils.rs b/tfhe/src/c_api/high_level_api/utils.rs index fb178aa988..c21393bac0 100644 --- a/tfhe/src/c_api/high_level_api/utils.rs +++ b/tfhe/src/c_api/high_level_api/utils.rs @@ -272,6 +272,13 @@ macro_rules! impl_safe_serialize_on_type { ($wrapper_type:ty) => { ::paste::paste! { #[no_mangle] + /// Serializes safely. + /// + /// This function adds versioning information to the serialized buffer, meaning that it will keep compatibility with future + /// versions of TFHE-rs. + /// + /// - `serialized_size_limit`: size limit (in number of byte) of the serialized object + /// (to avoid out of memory attacks) pub unsafe extern "C" fn [<$wrapper_type:snake _safe_serialize>]( sself: *const $wrapper_type, result: *mut crate::c_api::buffer::DynamicBuffer, @@ -284,7 +291,7 @@ macro_rules! impl_safe_serialize_on_type { let sself = crate::c_api::utils::get_ref_checked(sself).unwrap(); - crate::high_level_api::safe_serialize(&sself.0, &mut buffer, serialized_size_limit) + crate::safe_serialization::SerializationConfig::new(serialized_size_limit).serialize_into(&sself.0, &mut buffer) .unwrap(); *result = buffer.into(); @@ -294,42 +301,16 @@ macro_rules! impl_safe_serialize_on_type { } } -macro_rules! impl_safe_serialize_versioned_on_type { - ($wrapper_type:ty) => { - ::paste::paste! { - #[no_mangle] - pub unsafe extern "C" fn [<$wrapper_type:snake _safe_serialize_versioned>]( - sself: *const $wrapper_type, - result: *mut crate::c_api::buffer::DynamicBuffer, - serialized_size_limit: u64, - ) -> ::std::os::raw::c_int { - crate::c_api::utils::catch_panic(|| { - crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); - - let mut buffer = vec![]; - - let sself = crate::c_api::utils::get_ref_checked(sself).unwrap(); - - crate::high_level_api::safe_serialize_versioned(&sself.0, &mut buffer, serialized_size_limit) - .unwrap(); - - *result = buffer.into(); - }) - } - } - }; -} - -pub(crate) use {impl_safe_serialize_on_type, impl_safe_serialize_versioned_on_type}; +pub(crate) use impl_safe_serialize_on_type; -macro_rules! impl_safe_deserialize_conformant_integer { +macro_rules! impl_safe_deserialize_conformant_on_type { ($wrapper_type:ty, $conformance_param_type:ty) => { ::paste::paste! { #[no_mangle] /// Deserializes safely, and checks that the resulting ciphertext /// is in compliance with the shape of ciphertext that the `server_key` expects. /// - /// This function can only deserialize, types which have been serialized + /// This function can only deserialize types which have been serialized /// by a `safe_serialize` function. /// /// - `serialized_size_limit`: size limit (in number of byte) of the serialized object @@ -344,91 +325,78 @@ macro_rules! impl_safe_deserialize_conformant_integer { server_key: *const crate::c_api::high_level_api::keys::ServerKey, result: *mut *mut $wrapper_type, ) -> ::std::os::raw::c_int { - ::paste::paste! { - crate::c_api::utils::catch_panic(|| { - crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); + crate::c_api::utils::catch_panic(|| { + crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); - let sk = crate::c_api::utils::get_ref_checked(server_key).unwrap(); + let sk = crate::c_api::utils::get_ref_checked(server_key).unwrap(); - let buffer_view: &[u8] = buffer_view.as_slice(); + let buffer_view: &[u8] = buffer_view.as_slice(); - // First fill the result with a null ptr so that if we fail and the return code is not - // checked, then any access to the result pointer will segfault (mimics malloc on failure) - *result = std::ptr::null_mut(); + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); - let params = $crate::high_level_api::$conformance_param_type::from(&sk.0); - let inner = $crate::safe_deserialization::safe_deserialize_conformant( - buffer_view, - serialized_size_limit, - ¶ms, - ) - .unwrap(); + let params = $crate::high_level_api::$conformance_param_type::from(&sk.0); + let inner = $crate::safe_serialization::DeserializationConfig::new( + serialized_size_limit) + .deserialize_from(buffer_view, ¶ms) + .unwrap(); - let heap_allocated_object = Box::new($wrapper_type(inner)); + let heap_allocated_object = Box::new($wrapper_type(inner)); - *result = Box::into_raw(heap_allocated_object); - }) - } + *result = Box::into_raw(heap_allocated_object); + }) } + } }; } -macro_rules! impl_safe_deserialize_conformant_versioned_integer { - ($wrapper_type:ty, $conformance_param_type:ty) => { +pub(crate) use impl_safe_deserialize_conformant_on_type; + +macro_rules! impl_safe_deserialize_on_type { + ($wrapper_type:ty) => { ::paste::paste! { #[no_mangle] - /// Deserializes safely, and checks that the resulting ciphertext - /// is in compliance with the shape of ciphertext that the `server_key` expects. + /// Deserializes safely. /// - /// This function can only deserialize, types which have been serialized and versioned - /// by a `safe_serialize_versioned` function. + /// This function can only deserialize types which have been serialized + /// by a `safe_serialize` function. /// /// - `serialized_size_limit`: size limit (in number of byte) of the serialized object /// (to avoid out of memory attacks) - /// - `server_key`: ServerKey used in the conformance check /// - `result`: pointer where resulting deserialized object needs to be stored. /// * cannot be NULL /// * (*result) will point the deserialized object on success, else NULL - pub unsafe extern "C" fn [<$wrapper_type:snake _safe_deserialize_conformant_versioned>]( + pub unsafe extern "C" fn [<$wrapper_type:snake _safe_deserialize>]( buffer_view: crate::c_api::buffer::DynamicBufferView, serialized_size_limit: u64, - server_key: *const crate::c_api::high_level_api::keys::ServerKey, result: *mut *mut $wrapper_type, ) -> ::std::os::raw::c_int { - ::paste::paste! { - crate::c_api::utils::catch_panic(|| { - crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); - - let sk = crate::c_api::utils::get_ref_checked(server_key).unwrap(); + crate::c_api::utils::catch_panic(|| { + crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); - let buffer_view: &[u8] = buffer_view.as_slice(); + let buffer_view: &[u8] = buffer_view.as_slice(); - // First fill the result with a null ptr so that if we fail and the return code is not - // checked, then any access to the result pointer will segfault (mimics malloc on failure) - *result = std::ptr::null_mut(); + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); - let params = $crate::high_level_api::$conformance_param_type::from(&sk.0); - let inner = $crate::safe_deserialization::safe_deserialize_conformant_versioned( - buffer_view, - serialized_size_limit, - ¶ms, - ) - .unwrap(); + let inner = $crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer_view) + .unwrap(); - let heap_allocated_object = Box::new($wrapper_type(inner)); + let heap_allocated_object = Box::new($wrapper_type(inner)); - *result = Box::into_raw(heap_allocated_object); - }) - } + *result = Box::into_raw(heap_allocated_object); + }) } } }; } -pub(crate) use { - impl_safe_deserialize_conformant_integer, impl_safe_deserialize_conformant_versioned_integer, -}; +pub(crate) use impl_safe_deserialize_on_type; macro_rules! impl_binary_fn_on_type { // More general binary fn case, diff --git a/tfhe/src/c_api/high_level_api/zk.rs b/tfhe/src/c_api/high_level_api/zk.rs index 3bccb128f3..52d8f01b0e 100644 --- a/tfhe/src/c_api/high_level_api/zk.rs +++ b/tfhe/src/c_api/high_level_api/zk.rs @@ -26,10 +26,7 @@ impl_destroy_on_type!(CompactPkePublicParams); /// Serializes the public params /// /// If compress is true, the data will be compressed (less serialized bytes), however, this makes -/// the serialization process slower. -/// -/// Also, the value to `compress` should match the value given to `is_compressed` -/// when deserializing. +/// the deserialization process slower. #[no_mangle] pub unsafe extern "C" fn compact_pke_public_params_serialize( sself: *const CompactPkePublicParams, @@ -72,6 +69,65 @@ pub unsafe extern "C" fn compact_pke_public_params_deserialize( }) } +/// Serializes the public params +/// +/// If compress is true, the data will be compressed (less serialized bytes), however, this makes +/// the deserialization process slower. +#[no_mangle] +pub unsafe extern "C" fn compact_pke_public_params_safe_serialize( + sself: *const CompactPkePublicParams, + compress: bool, + serialized_size_limit: u64, + result: *mut crate::c_api::buffer::DynamicBuffer, +) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); + + let wrapper = crate::c_api::utils::get_ref_checked(sself).unwrap(); + + let mut buffer = Vec::new(); + if compress { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&wrapper.0.compress(), &mut buffer) + .unwrap(); + } else { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&wrapper.0, &mut buffer) + .unwrap(); + }; + + *result = buffer.into(); + }) +} + +/// Deserializes the public params +/// +/// If the data comes from compressed public params, then `is_compressed` must be true. +#[no_mangle] +pub unsafe extern "C" fn compact_pke_public_params_safe_deserialize( + buffer_view: crate::c_api::buffer::DynamicBufferView, + serialized_size_limit: u64, + result: *mut *mut CompactPkePublicParams, +) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); + + *result = std::ptr::null_mut(); + + let buffer_view: &[u8] = buffer_view.as_slice(); + + let deserialized = + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer_view) + .unwrap(); + + let heap_allocated_object = Box::new(CompactPkePublicParams(deserialized)); + + *result = Box::into_raw(heap_allocated_object); + }) +} + pub struct CompactPkeCrs(pub(crate) crate::core_crypto::entities::CompactPkeCrs); impl_destroy_on_type!(CompactPkeCrs); diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index e22058ba14..63f6e5cd47 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -57,6 +57,7 @@ impl Named for FheBool { const NAME: &'static str = "high_level_api::FheBool"; } +#[derive(Copy, Clone)] pub struct FheBoolConformanceParams(pub(crate) CiphertextConformanceParams); impl

From

for FheBoolConformanceParams diff --git a/tfhe/src/high_level_api/booleans/tests.rs b/tfhe/src/high_level_api/booleans/tests.rs index ddb05ddb3c..e6127628b3 100644 --- a/tfhe/src/high_level_api/booleans/tests.rs +++ b/tfhe/src/high_level_api/booleans/tests.rs @@ -318,7 +318,7 @@ fn compressed_bool_test_case(setup_fn: impl FnOnce() -> (ClientKey, Device)) { mod cpu { use super::*; - use crate::safe_deserialization::safe_deserialize_conformant; + use crate::safe_serialization::{DeserializationConfig, SerializationConfig}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; use crate::FheBoolConformanceParams; use rand::random; @@ -682,12 +682,14 @@ mod cpu { let clear_a = random::(); let a = FheBool::encrypt(clear_a, &client_key); let mut serialized = vec![]; - assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); + SerializationConfig::new(1 << 20) + .serialize_into(&a, &mut serialized) + .unwrap(); let params = FheBoolConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20) + .deserialize_from::(serialized.as_slice(), ¶ms) + .unwrap(); let decrypted: bool = deserialized_a.decrypt(&client_key); assert_eq!(decrypted, clear_a); @@ -703,15 +705,14 @@ mod cpu { let clear_a = random::(); let a = CompressedFheBool::encrypt(clear_a, &client_key); let mut serialized = vec![]; - assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); + SerializationConfig::new(1 << 20) + .serialize_into(&a, &mut serialized) + .unwrap(); let params = FheBoolConformanceParams::from(&server_key); - let deserialized_a = safe_deserialize_conformant::( - serialized.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20) + .deserialize_from::(serialized.as_slice(), ¶ms) + .unwrap(); assert!(deserialized_a.is_conformant(&FheBoolConformanceParams::from(block_params))); diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index 0f0af0e23a..cc2e4987b9 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -42,6 +42,7 @@ pub struct FheInt { pub(crate) tag: Tag, } +#[derive(Copy, Clone)] pub struct FheIntConformanceParams { pub(crate) params: RadixCiphertextConformanceParams, pub(crate) id: PhantomData, diff --git a/tfhe/src/high_level_api/integers/signed/tests.rs b/tfhe/src/high_level_api/integers/signed/tests.rs index 3cc0d539ea..7466533137 100644 --- a/tfhe/src/high_level_api/integers/signed/tests.rs +++ b/tfhe/src/high_level_api/integers/signed/tests.rs @@ -1,6 +1,6 @@ use crate::integer::I256; use crate::prelude::*; -use crate::safe_deserialization::safe_deserialize_conformant; +use crate::safe_serialization::{DeserializationConfig, SerializationConfig}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; use crate::{ generate_keys, set_server_key, ClientKey, CompactCiphertextList, CompactPublicKey, @@ -648,11 +648,14 @@ fn test_safe_deserialize_conformant_fhe_int32() { let clear_a = random::(); let a = FheInt32::encrypt(clear_a, &client_key); let mut serialized = vec![]; - assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); + SerializationConfig::new(1 << 20) + .serialize_into(&a, &mut serialized) + .unwrap(); let params = FheInt32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms).unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20) + .deserialize_from::(serialized.as_slice(), ¶ms) + .unwrap(); let decrypted: i32 = deserialized_a.decrypt(&client_key); assert_eq!(decrypted, clear_a); @@ -670,12 +673,14 @@ fn test_safe_deserialize_conformant_compressed_fhe_int32() { let clear_a = random::(); let a = CompressedFheInt32::encrypt(clear_a, &client_key); let mut serialized = vec![]; - assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); + SerializationConfig::new(1 << 20) + .serialize_into(&a, &mut serialized) + .unwrap(); let params = FheInt32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20) + .deserialize_from::(serialized.as_slice(), ¶ms) + .unwrap(); let params = FheInt32ConformanceParams::from(block_params); assert!(deserialized_a.is_conformant(¶ms)); diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 9f1edd66f6..a0dda24f4f 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -82,6 +82,7 @@ pub struct FheUint { pub(crate) tag: Tag, } +#[derive(Copy, Clone)] pub struct FheUintConformanceParams { pub(crate) params: RadixCiphertextConformanceParams, pub(crate) id: PhantomData, diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs index 0828678991..e1127b3bc7 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs @@ -2,7 +2,7 @@ use crate::conformance::ListSizeConstraint; use crate::high_level_api::prelude::*; use crate::high_level_api::{generate_keys, set_server_key, ConfigBuilder, FheUint8}; use crate::integer::U256; -use crate::safe_deserialization::safe_deserialize_conformant; +use crate::safe_serialization::{DeserializationConfig, SerializationConfig}; use crate::shortint::parameters::classic::compact_pk::*; use crate::shortint::parameters::compact_public_key_only::p_fail_2_minus_64::ks_pbs::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::key_switching::p_fail_2_minus_64::ks_pbs::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; @@ -419,11 +419,14 @@ fn test_safe_deserialize_conformant_fhe_uint32() { let clear_a = random::(); let a = FheUint32::encrypt(clear_a, &client_key); let mut serialized = vec![]; - assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); + SerializationConfig::new(1 << 20) + .serialize_into(&a, &mut serialized) + .unwrap(); let params = FheUint32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms).unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20) + .deserialize_from::(serialized.as_slice(), ¶ms) + .unwrap(); let decrypted: u32 = deserialized_a.decrypt(&client_key); assert_eq!(decrypted, clear_a); @@ -440,12 +443,14 @@ fn test_safe_deserialize_conformant_compressed_fhe_uint32() { let clear_a = random::(); let a = CompressedFheUint32::encrypt(clear_a, &client_key); let mut serialized = vec![]; - assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); + SerializationConfig::new(1 << 20) + .serialize_into(&a, &mut serialized) + .unwrap(); let params = FheUint32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20) + .deserialize_from::(serialized.as_slice(), ¶ms) + .unwrap(); assert!(deserialized_a.is_conformant(&FheUint32ConformanceParams::from(block_params))); @@ -466,18 +471,17 @@ fn test_safe_deserialize_conformant_compact_fhe_uint32() { .extend(clears.iter().copied()) .build(); let mut serialized = vec![]; - assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); + SerializationConfig::new(1 << 20) + .serialize_into(&a, &mut serialized) + .unwrap(); let params = CompactCiphertextListConformanceParams { shortint_params: block_params.to_shortint_conformance_param(), num_elements_constraint: ListSizeConstraint::exact_size(clears.len()), }; - let deserialized_a = safe_deserialize_conformant::( - serialized.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20) + .deserialize_from::(serialized.as_slice(), ¶ms) + .unwrap(); let expander = deserialized_a.expand().unwrap(); for (i, clear) in clears.into_iter().enumerate() { diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index c4178d51a2..b8df4310e1 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -51,13 +51,13 @@ expand_pub_use_fhe_type!( ); pub use crate::integer::parameters::CompactCiphertextListConformanceParams; +pub use crate::safe_serialization::{DeserializationConfig, SerializationConfig}; #[cfg(feature = "zk-pok")] pub use compact_list::ProvenCompactCiphertextList; pub use compact_list::{ CompactCiphertextList, CompactCiphertextListBuilder, CompactCiphertextListExpander, }; pub use compressed_ciphertext_list::{CompressedCiphertextList, CompressedCiphertextListBuilder}; -pub use safe_serialize::{safe_serialize, safe_serialize_versioned}; pub use tag::Tag; @@ -123,33 +123,3 @@ pub enum FheTypes { Int160, Int256, } - -pub mod safe_serialize { - use crate::named::Named; - use serde::Serialize; - use tfhe_versionable::Versionize; - - pub fn safe_serialize( - a: &T, - writer: impl std::io::Write, - serialized_size_limit: u64, - ) -> Result<(), String> - where - T: Named + Serialize, - { - crate::safe_deserialization::safe_serialize(a, writer, serialized_size_limit) - .map_err(|err| err.to_string()) - } - - pub fn safe_serialize_versioned( - a: &T, - writer: impl std::io::Write, - serialized_size_limit: u64, - ) -> Result<(), String> - where - T: Named + Versionize, - { - crate::safe_deserialization::safe_serialize_versioned(a, writer, serialized_size_limit) - .map_err(|err| err.to_string()) - } -} diff --git a/tfhe/src/integer/parameters/mod.rs b/tfhe/src/integer/parameters/mod.rs index 0fd52e7944..787523f1ac 100644 --- a/tfhe/src/integer/parameters/mod.rs +++ b/tfhe/src/integer/parameters/mod.rs @@ -176,6 +176,7 @@ pub const PARAM_MESSAGE_1_CARRY_1_KS_PBS_32_BITS: WopbsParameters = WopbsParamet encryption_key_choice: EncryptionKeyChoice::Big, }; +#[derive(Copy, Clone)] pub struct RadixCiphertextConformanceParams { pub shortint_params: CiphertextConformanceParams, pub num_blocks_per_integer: usize, @@ -210,6 +211,7 @@ impl RadixCiphertextConformanceParams { /// Structure to store the expected properties of a ciphertext list /// Can be used on a server to check if client inputs are well formed /// before running a computation on them +#[derive(Copy, Clone)] pub struct CompactCiphertextListConformanceParams { pub shortint_params: CiphertextConformanceParams, pub num_elements_constraint: ListSizeConstraint, diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs index 6168e9e736..5a58c9a2a7 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs @@ -194,7 +194,8 @@ macro_rules! create_wrapper_type_non_native_type ( #[wasm_bindgen] pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { let mut buffer = vec![]; - catch_panic_result(|| crate::safe_deserialization::safe_serialize(&self.0, &mut buffer, serialized_size_limit) + catch_panic_result(|| crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) .map_err(into_js_error))?; Ok(buffer) @@ -203,7 +204,9 @@ macro_rules! create_wrapper_type_non_native_type ( #[wasm_bindgen] pub fn safe_deserialize(buffer: &[u8], serialized_size_limit: u64) -> Result<$type_name, JsError> { catch_panic_result(|| { - crate::safe_deserialization::safe_deserialize(buffer, serialized_size_limit) + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) .map($type_name) .map_err(into_js_error) }) @@ -255,7 +258,8 @@ macro_rules! create_wrapper_type_non_native_type ( #[wasm_bindgen] pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { let mut buffer = vec![]; - catch_panic_result(|| crate::safe_deserialization::safe_serialize(&self.0, &mut buffer, serialized_size_limit) + catch_panic_result(|| crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) .map_err(into_js_error))?; Ok(buffer) @@ -264,7 +268,9 @@ macro_rules! create_wrapper_type_non_native_type ( #[wasm_bindgen] pub fn safe_deserialize(buffer: &[u8], serialized_size_limit: u64) -> Result<$compressed_type_name, JsError> { catch_panic_result(|| { - crate::safe_deserialization::safe_deserialize(buffer, serialized_size_limit) + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) .map($compressed_type_name) .map_err(into_js_error) }) @@ -432,7 +438,8 @@ macro_rules! create_wrapper_type_that_has_native_type ( #[wasm_bindgen] pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { let mut buffer = vec![]; - catch_panic_result(|| crate::safe_deserialization::safe_serialize(&self.0, &mut buffer, serialized_size_limit) + catch_panic_result(|| crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) .map_err(into_js_error))?; Ok(buffer) @@ -441,7 +448,9 @@ macro_rules! create_wrapper_type_that_has_native_type ( #[wasm_bindgen] pub fn safe_deserialize(buffer: &[u8], serialized_size_limit: u64) -> Result<$type_name, JsError> { catch_panic_result(|| { - crate::safe_deserialization::safe_deserialize(buffer, serialized_size_limit) + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) .map(Self) .map_err(into_js_error) }) @@ -490,7 +499,8 @@ macro_rules! create_wrapper_type_that_has_native_type ( #[wasm_bindgen] pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { let mut buffer = vec![]; - catch_panic_result(|| crate::safe_deserialization::safe_serialize(&self.0, &mut buffer, serialized_size_limit) + catch_panic_result(|| crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) .map_err(into_js_error))?; Ok(buffer) @@ -499,7 +509,9 @@ macro_rules! create_wrapper_type_that_has_native_type ( #[wasm_bindgen] pub fn safe_deserialize(buffer: &[u8], serialized_size_limit: u64) -> Result<$compressed_type_name, JsError> { catch_panic_result(|| { - crate::safe_deserialization::safe_deserialize(buffer, serialized_size_limit) + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) .map($compressed_type_name) .map_err(into_js_error) }) @@ -728,7 +740,8 @@ impl CompactCiphertextList { pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { let mut buffer = vec![]; catch_panic_result(|| { - crate::safe_deserialization::safe_serialize(&self.0, &mut buffer, serialized_size_limit) + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) .map_err(into_js_error) })?; @@ -741,7 +754,9 @@ impl CompactCiphertextList { serialized_size_limit: u64, ) -> Result { catch_panic_result(|| { - crate::safe_deserialization::safe_deserialize(buffer, serialized_size_limit) + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) .map(CompactCiphertextList) .map_err(into_js_error) }) @@ -821,7 +836,8 @@ impl ProvenCompactCiphertextList { pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { let mut buffer = vec![]; catch_panic_result(|| { - crate::safe_deserialization::safe_serialize(&self.0, &mut buffer, serialized_size_limit) + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) .map_err(into_js_error) })?; @@ -834,7 +850,9 @@ impl ProvenCompactCiphertextList { serialized_size_limit: u64, ) -> Result { catch_panic_result(|| { - crate::safe_deserialization::safe_deserialize(buffer, serialized_size_limit) + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) .map(ProvenCompactCiphertextList) .map_err(into_js_error) }) diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs index 80e90ef03f..4c1f661fde 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/keys.rs @@ -44,6 +44,32 @@ impl TfheClientKey { .map_err(into_js_error) }) } + + #[wasm_bindgen] + pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { + let mut buffer = vec![]; + catch_panic_result(|| { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) + .map_err(into_js_error) + })?; + + Ok(buffer) + } + + #[wasm_bindgen] + pub fn safe_deserialize( + buffer: &[u8], + serialized_size_limit: u64, + ) -> Result { + catch_panic_result(|| { + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) + .map(Self) + .map_err(into_js_error) + }) + } } // Wasm cannot generate a normal server key, only a compressed one @@ -70,6 +96,32 @@ impl TfheCompressedServerKey { .map_err(into_js_error) }) } + + #[wasm_bindgen] + pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { + let mut buffer = vec![]; + catch_panic_result(|| { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) + .map_err(into_js_error) + })?; + + Ok(buffer) + } + + #[wasm_bindgen] + pub fn safe_deserialize( + buffer: &[u8], + serialized_size_limit: u64, + ) -> Result { + catch_panic_result(|| { + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) + .map(Self) + .map_err(into_js_error) + }) + } } #[wasm_bindgen] @@ -123,6 +175,32 @@ impl TfhePublicKey { .map_err(into_js_error) }) } + + #[wasm_bindgen] + pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { + let mut buffer = vec![]; + catch_panic_result(|| { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) + .map_err(into_js_error) + })?; + + Ok(buffer) + } + + #[wasm_bindgen] + pub fn safe_deserialize( + buffer: &[u8], + serialized_size_limit: u64, + ) -> Result { + catch_panic_result(|| { + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) + .map(Self) + .map_err(into_js_error) + }) + } } #[wasm_bindgen] @@ -153,6 +231,32 @@ impl TfheCompressedPublicKey { .map_err(into_js_error) }) } + + #[wasm_bindgen] + pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { + let mut buffer = vec![]; + catch_panic_result(|| { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) + .map_err(into_js_error) + })?; + + Ok(buffer) + } + + #[wasm_bindgen] + pub fn safe_deserialize( + buffer: &[u8], + serialized_size_limit: u64, + ) -> Result { + catch_panic_result(|| { + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) + .map(Self) + .map_err(into_js_error) + }) + } } #[wasm_bindgen] @@ -178,6 +282,32 @@ impl TfheCompactPublicKey { .map_err(into_js_error) }) } + + #[wasm_bindgen] + pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { + let mut buffer = vec![]; + catch_panic_result(|| { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) + .map_err(into_js_error) + })?; + + Ok(buffer) + } + + #[wasm_bindgen] + pub fn safe_deserialize( + buffer: &[u8], + serialized_size_limit: u64, + ) -> Result { + catch_panic_result(|| { + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) + .map(Self) + .map_err(into_js_error) + }) + } } #[wasm_bindgen] @@ -208,4 +338,30 @@ impl TfheCompressedCompactPublicKey { pub fn decompress(&self) -> Result { catch_panic(|| TfheCompactPublicKey(self.0.decompress())) } + + #[wasm_bindgen] + pub fn safe_serialize(&self, serialized_size_limit: u64) -> Result, JsError> { + let mut buffer = vec![]; + catch_panic_result(|| { + crate::safe_serialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(&self.0, &mut buffer) + .map_err(into_js_error) + })?; + + Ok(buffer) + } + + #[wasm_bindgen] + pub fn safe_deserialize( + buffer: &[u8], + serialized_size_limit: u64, + ) -> Result { + catch_panic_result(|| { + crate::safe_serialization::DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(buffer) + .map(Self) + .map_err(into_js_error) + }) + } } diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index 2d346e7d28..ee4cbfd08a 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -129,7 +129,7 @@ pub use high_level_api::*; /// cbindgen:ignore pub mod keycache; -pub mod safe_deserialization; +pub mod safe_serialization; pub mod conformance; diff --git a/tfhe/src/safe_deserialization.rs b/tfhe/src/safe_deserialization.rs deleted file mode 100644 index ad7ceb21e4..0000000000 --- a/tfhe/src/safe_deserialization.rs +++ /dev/null @@ -1,513 +0,0 @@ -use std::borrow::Cow; - -use crate::conformance::ParameterSetConformant; -use crate::named::Named; -use bincode::Options; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use tfhe_versionable::{Unversionize, Versionize}; - -// The `SERIALIZATION_VERSION` is serialized along objects serialized with `safe_serialize`. -// This `SERIALIZATION_VERSION` should be changed on each release where any object serialization -// details changes (this can happen when adding a field or reorderging fields of a struct). -// When a object is deserialized using `safe_deserialize`, the deserialized version is checked -// to be equal to SERIALIZATION_VERSION. -// This helps prevent users from inadvertently deserializaing an object serialized in another -// release. -// When this happens, it also gives a clear version mismatch error rather than a generic -// deserialization error or worse, a garbage object. -const SERIALIZATION_VERSION: &str = "0.4"; - -/// Tells if this serialized object is versioned or not -#[derive(Serialize, Deserialize, PartialEq, Eq)] -// This type should not be versioned because it is part of a wrapper of versioned messages. -#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] -enum SerializationMode { - /// Serialize with type versioning for backward compatibility - Versioned, - /// Directly serialize the type as it is provided - Direct, -} - -/// Header with global metadata about the serialized object. -#[derive(Serialize, Deserialize)] -// This type should not be versioned because it is part of a wrapper of versioned messages. -#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] -struct SerializationHeader { - mode: SerializationMode, - version: Cow<'static, str>, - name: Cow<'static, str>, -} - -impl SerializationHeader { - /// Creates a new header for a versioned message - fn new_versioned() -> Self { - Self { - mode: SerializationMode::Versioned, - version: Cow::Borrowed(VERSIONING_VERSION), - name: Cow::Borrowed(T::NAME), - } - } - - /// Checks the validity of a versioned message - fn check_versioned(&self) -> Result<(), String> { - if self.mode != SerializationMode::Versioned { - return Err( - "On deserialization, expected versioned type but got unversioned one".to_string(), - ); - } - - // Since there is only one "VERSIONING_VERSION", a message with a different value than the - // expected one is clearly invalid, so we return an error. In the future, we want to - // be able to upgrade it to the new versioning scheme. - if self.version != VERSIONING_VERSION { - return Err(format!( - "On deserialization, expected versioning scheme version {VERSIONING_VERSION}, \ - got version {}", - self.version - )); - } - - if self.name != T::NAME { - return Err(format!( - "On deserialization, expected type {}, got type {}", - T::NAME, - self.name - )); - } - - Ok(()) - } -} - -// This is the version of the versioning scheme used to add backward compatibibility on tfhe-rs -// types. Similar to SERIALIZATION_VERSION, this number should be increased when the versioning -// scheme is upgraded. -const VERSIONING_VERSION: &str = "0.1"; - -// `VERSION_LENGTH_LIMIT` is the maximum `SERIALIZATION_VERSION` size which `safe_deserialization` -// is going to try to read (it returns an error if it's too big). -// It helps prevent an attacker passing a very long `SERIALIZATION_VERSION` to exhaust memory. -const VERSION_LENGTH_LIMIT: u64 = 100; - -const TYPE_NAME_LENGTH_LIMIT: u64 = 1000; - -const HEADER_LENGTH_LIMIT: u64 = 1000; - -/// Serializes an object into a [writer](std::io::Write). -/// The result contains a version of the serialization and the name of the -/// serialized type to provide checks on deserialization with [safe_deserialize]. -/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object -/// (excluding version and name serialization). -pub fn safe_serialize( - object: &T, - mut writer: impl std::io::Write, - serialized_size_limit: u64, -) -> bincode::Result<()> { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); - - options - .with_limit(VERSION_LENGTH_LIMIT) - .serialize_into::<_, String>(&mut writer, &SERIALIZATION_VERSION.to_owned())?; - - options - .with_limit(TYPE_NAME_LENGTH_LIMIT) - .serialize_into::<_, String>(&mut writer, &T::NAME.to_owned())?; - - options - .with_limit(serialized_size_limit) - .serialize_into(&mut writer, object)?; - - Ok(()) -} - -/// Serializes an object into a [writer](std::io::Write) like [`safe_serialize`] does, -/// but adds versioning information before. -pub fn safe_serialize_versioned( - object: &T, - mut writer: impl std::io::Write, - serialized_size_limit: u64, -) -> bincode::Result<()> { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); - - let header = SerializationHeader::new_versioned::(); - options - .with_limit(HEADER_LENGTH_LIMIT) - .serialize_into(&mut writer, &header)?; - - options - .with_limit(serialized_size_limit) - .serialize_into(&mut writer, &object.versionize())?; - - Ok(()) -} - -/// Deserializes an object serialized by `safe_serialize` from a [reader](std::io::Read). -/// Checks that the serialization version and the name of the -/// deserialized type are correct. -/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object -/// (excluding version and name serialization). -pub fn safe_deserialize( - mut reader: impl std::io::Read, - serialized_size_limit: u64, -) -> Result { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); - - let deserialized_version: String = options - .with_limit(VERSION_LENGTH_LIMIT) - .deserialize_from::<_, String>(&mut reader) - .map_err(|err| err.to_string())?; - - if deserialized_version != SERIALIZATION_VERSION { - return Err(format!( - "On deserialization, expected serialization version {SERIALIZATION_VERSION}, got version {deserialized_version}" - )); - } - - let deserialized_type: String = options - .with_limit(TYPE_NAME_LENGTH_LIMIT) - .deserialize_from::<_, String>(&mut reader) - .map_err(|err| err.to_string())?; - - if deserialized_type != T::NAME { - return Err(format!( - "On deserialization, expected type {}, got type {}", - T::NAME, - deserialized_type - )); - } - - options - .with_limit(serialized_size_limit) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string()) -} - -/// Deserializes an object with [safe_deserialize] and checks than it is conformant with the given -/// parameter set -pub fn safe_deserialize_conformant( - reader: impl std::io::Read, - serialized_size_limit: u64, - parameter_set: &T::ParameterSet, -) -> Result { - let deser: T = safe_deserialize(reader, serialized_size_limit)?; - - if !deser.is_conformant(parameter_set) { - return Err(format!( - "Deserialized object of type {} not conformant with given parameter set", - T::NAME - )); - } - - Ok(deser) -} - -/// Deserializes an object serialized by `safe_serialize_versioned` from a [reader](std::io::Read). -/// Checks that the serialization version and the name of the -/// deserialized type are correct. -/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object -/// (excluding version and name serialization). -pub fn safe_deserialize_versioned( - mut reader: impl std::io::Read, - serialized_size_limit: u64, -) -> Result { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); - - let deserialized_header: SerializationHeader = options - .with_limit(HEADER_LENGTH_LIMIT) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string())?; - - deserialized_header.check_versioned::()?; - - options - .with_limit(serialized_size_limit) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string()) - .and_then(|val| T::unversionize(val).map_err(|err| err.to_string())) -} - -/// Deserializes an object with [safe_deserialize] and checks than it is conformant with the given -/// parameter set -pub fn safe_deserialize_conformant_versioned( - reader: impl std::io::Read, - serialized_size_limit: u64, - parameter_set: &T::ParameterSet, -) -> Result { - let deser: T = safe_deserialize_versioned(reader, serialized_size_limit)?; - - if !deser.is_conformant(parameter_set) { - return Err(format!( - "Deserialized object of type {} not conformant with given parameter set", - T::NAME - )); - } - - Ok(deser) -} - -#[cfg(all(test, feature = "shortint"))] -mod test_shortint { - use crate::safe_deserialization::{ - safe_deserialize_conformant, safe_deserialize_conformant_versioned, safe_serialize, - safe_serialize_versioned, - }; - use crate::shortint::parameters::{ - PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, - }; - use crate::shortint::{gen_keys, Ciphertext}; - - #[test] - fn safe_deserialization_ct() { - let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); - - let msg = 2_u64; - - let ct = ck.encrypt(msg); - - let mut buffer = vec![]; - - safe_serialize(&ct, &mut buffer, 1 << 40).unwrap(); - - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), - ) - .is_err()); - - let ct2 = safe_deserialize_conformant( - buffer.as_slice(), - 1 << 20, - &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - ) - .unwrap(); - - let dec = ck.decrypt(&ct2); - assert_eq!(msg, dec); - } - - #[test] - fn safe_deserialization_ct_versioned() { - let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); - - let msg = 2_u64; - - let ct = ck.encrypt(msg); - - let mut buffer = vec![]; - - safe_serialize_versioned(&ct, &mut buffer, 1 << 40).unwrap(); - - assert!(safe_deserialize_conformant_versioned::( - buffer.as_slice(), - 1 << 20, - &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), - ) - .is_err()); - - let ct2 = safe_deserialize_conformant_versioned( - buffer.as_slice(), - 1 << 20, - &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - ) - .unwrap(); - - let dec = ck.decrypt(&ct2); - assert_eq!(msg, dec); - } -} - -#[cfg(all(test, feature = "integer"))] -mod test_integer { - use crate::conformance::ListSizeConstraint; - use crate::high_level_api::{generate_keys, ConfigBuilder}; - use crate::prelude::*; - use crate::safe_deserialization::{ - safe_deserialize_conformant, safe_deserialize_conformant_versioned, safe_serialize, - safe_serialize_versioned, - }; - use crate::shortint::parameters::{ - PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, - }; - use crate::{ - set_server_key, CompactCiphertextList, CompactCiphertextListConformanceParams, - CompactPublicKey, FheUint8, - }; - - #[test] - fn safe_deserialization_ct_list() { - let (client_key, sks) = generate_keys(ConfigBuilder::default().build()); - set_server_key(sks); - - let public_key = CompactPublicKey::new(&client_key); - - let msg = [27u8, 10, 3]; - - let ct_list = CompactCiphertextList::builder(&public_key) - .push(27u8) - .push(10u8) - .push(3u8) - .build(); - - let mut buffer = vec![]; - - safe_serialize(&ct_list, &mut buffer, 1 << 40).unwrap(); - - let to_param_set = |list_size_constraint| CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: list_size_constraint, - }; - - for param_set in [ - CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: ListSizeConstraint::exact_size(3), - }, - to_param_set(ListSizeConstraint::exact_size(2)), - to_param_set(ListSizeConstraint::exact_size(4)), - to_param_set(ListSizeConstraint::try_size_in_range(1, 2).unwrap()), - to_param_set(ListSizeConstraint::try_size_in_range(4, 5).unwrap()), - ] { - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - ¶m_set - ) - .is_err()); - } - - for len_constraint in [ - ListSizeConstraint::exact_size(3), - ListSizeConstraint::try_size_in_range(2, 3).unwrap(), - ListSizeConstraint::try_size_in_range(3, 4).unwrap(), - ListSizeConstraint::try_size_in_range(2, 4).unwrap(), - ] { - let params = CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: len_constraint, - }; - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .is_ok()); - } - - let params = CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: ListSizeConstraint::exact_size(3), - }; - let ct2 = safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); - - let mut cts = Vec::with_capacity(3); - let expander = ct2.expand().unwrap(); - for i in 0..3 { - cts.push(expander.get::(i).unwrap().unwrap()); - } - - let dec: Vec = cts.iter().map(|a| a.decrypt(&client_key)).collect(); - - assert_eq!(&msg[..], &dec); - } - - #[test] - fn safe_deserialization_ct_list_versioned() { - let (client_key, sks) = generate_keys(ConfigBuilder::default().build()); - set_server_key(sks); - - let public_key = CompactPublicKey::new(&client_key); - - let msg = [27u8, 10, 3]; - - let ct_list = CompactCiphertextList::builder(&public_key) - .push(27u8) - .push(10u8) - .push(3u8) - .build(); - - let mut buffer = vec![]; - - safe_serialize_versioned(&ct_list, &mut buffer, 1 << 40).unwrap(); - - let to_param_set = |list_size_constraint| CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: list_size_constraint, - }; - - for param_set in [ - CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: ListSizeConstraint::exact_size(3), - }, - to_param_set(ListSizeConstraint::exact_size(2)), - to_param_set(ListSizeConstraint::exact_size(4)), - to_param_set(ListSizeConstraint::try_size_in_range(1, 2).unwrap()), - to_param_set(ListSizeConstraint::try_size_in_range(4, 5).unwrap()), - ] { - assert!( - safe_deserialize_conformant_versioned::( - buffer.as_slice(), - 1 << 20, - ¶m_set - ) - .is_err() - ); - } - - for len_constraint in [ - ListSizeConstraint::exact_size(3), - ListSizeConstraint::try_size_in_range(2, 3).unwrap(), - ListSizeConstraint::try_size_in_range(3, 4).unwrap(), - ListSizeConstraint::try_size_in_range(2, 4).unwrap(), - ] { - let params = CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: len_constraint, - }; - assert!( - safe_deserialize_conformant_versioned::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .is_ok() - ); - } - - let params = CompactCiphertextListConformanceParams { - shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), - num_elements_constraint: ListSizeConstraint::exact_size(3), - }; - let ct2 = safe_deserialize_conformant_versioned::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); - - let mut cts = Vec::with_capacity(3); - let expander = ct2.expand().unwrap(); - for i in 0..3 { - cts.push(expander.get::(i).unwrap().unwrap()); - } - - let dec: Vec = cts.iter().map(|a| a.decrypt(&client_key)).collect(); - - assert_eq!(&msg[..], &dec); - } -} diff --git a/tfhe/src/safe_serialization.rs b/tfhe/src/safe_serialization.rs new file mode 100644 index 0000000000..80891ed555 --- /dev/null +++ b/tfhe/src/safe_serialization.rs @@ -0,0 +1,662 @@ +//! Serialization utilities with some safety checks + +use std::borrow::Cow; +use std::fmt::Display; + +use crate::conformance::ParameterSetConformant; +use crate::named::Named; +use bincode::Options; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use tfhe_versionable::{Unversionize, Versionize}; + +/// This is the global version of the serialization scheme that is used. This should be updated when +/// the SerializationHeader is updated. +const SERIALIZATION_VERSION: &str = "0.5"; + +/// This is the version of the versioning scheme used to add backward compatibibility on tfhe-rs +/// types. Similar to SERIALIZATION_VERSION, this number should be increased when the versioning +/// scheme is upgraded. +const VERSIONING_VERSION: &str = "0.1"; + +/// This is the current version of this crate. This is used to be able to reject unversioned data +/// if they come from a previous version. +const CRATE_VERSION: &str = concat!( + env!("CARGO_PKG_VERSION_MAJOR"), + ".", + env!("CARGO_PKG_VERSION_MINOR") +); + +/// Tells if this serialized object is versioned or not +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] +// This type should not be versioned because it is part of a wrapper of versioned messages. +#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] +enum SerializationVersioningMode { + /// Serialize with type versioning for backward compatibility + Versioned { + /// Version of the versioning scheme in use + versioning_version: Cow<'static, str>, + }, + /// Serialize the type without versioning information + Unversioned { + /// Version of tfhe-rs where this data was generated + crate_version: Cow<'static, str>, + }, +} + +impl Display for SerializationVersioningMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Versioned { .. } => write!(f, "versioned"), + Self::Unversioned { .. } => write!(f, "unversioned"), + } + } +} + +impl SerializationVersioningMode { + fn versioned() -> Self { + Self::Versioned { + versioning_version: Cow::Borrowed(VERSIONING_VERSION), + } + } + + fn unversioned() -> Self { + Self::Unversioned { + crate_version: Cow::Borrowed(CRATE_VERSION), + } + } +} + +/// `HEADER_LENGTH_LIMIT` is the maximum `SerializationHeader` size which +/// `DeserializationConfig::deserialize_from` is going to try to read (it returns an error if +/// it's too big). +/// It helps prevent an attacker passing a very long header to exhaust memory. +const HEADER_LENGTH_LIMIT: u64 = 1000; + +/// Header with global metadata about the serialized object. This help checking that we are not +/// deserializing data that we can't handle. +#[derive(Serialize, Deserialize)] +// This type should not be versioned because it is part of a wrapper of versioned messages. +#[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] +struct SerializationHeader { + header_version: Cow<'static, str>, + versioning_mode: SerializationVersioningMode, + name: Cow<'static, str>, +} + +impl SerializationHeader { + /// Creates a new header for a versioned message + fn new_versioned() -> Self { + Self { + header_version: Cow::Borrowed(SERIALIZATION_VERSION), + versioning_mode: SerializationVersioningMode::versioned(), + name: Cow::Borrowed(T::NAME), + } + } + + /// Creates a new header for an unversioned message + fn new_unversioned() -> Self { + Self { + header_version: Cow::Borrowed(SERIALIZATION_VERSION), + versioning_mode: SerializationVersioningMode::unversioned(), + name: Cow::Borrowed(T::NAME), + } + } + + /// Checks the validity of the header + fn validate(&self) -> Result<(), String> { + match &self.versioning_mode { + SerializationVersioningMode::Versioned { versioning_version } => { + // For the moment there is only one versioning scheme, so another value is + // a hard error. But maybe if we upgrade it we will be able to automatically convert + // it. + if versioning_version != VERSIONING_VERSION { + return Err(format!( + "On deserialization, expected versioning scheme version {VERSIONING_VERSION}, \ +got version {versioning_version}" + )); + } + } + SerializationVersioningMode::Unversioned { crate_version } => { + if crate_version != CRATE_VERSION { + return Err(format!( + "This {} has been saved from TFHE-rs v{crate_version}, without versioning informations. \ +Please use the versioned serialization mode for backward compatibility.", + self.name + )); + } + } + } + + if self.name != T::NAME { + return Err(format!( + "On deserialization, expected type {}, got type {}", + T::NAME, + self.name + )); + } + + Ok(()) + } +} + +/// A configuration used to Serialize *TFHE-rs* objects. This configuration decides +/// if the object will be versioned and holds the max byte size of the written data. +#[derive(Clone)] +pub struct SerializationConfig { + versioned: SerializationVersioningMode, + serialized_size_limit: u64, +} + +impl SerializationConfig { + /// Creates a new serialization config. The default configuration will serialize the object + /// with versioning information for backward compatibility. + /// `serialized_size_limit` is the size limit (in number of byte) of the serialized object + /// (excluding the header). + pub fn new(serialized_size_limit: u64) -> Self { + Self { + versioned: SerializationVersioningMode::versioned(), + serialized_size_limit, + } + } + + /// Creates a new serialization config without any size check. + pub fn new_with_unlimited_size() -> Self { + Self { + versioned: SerializationVersioningMode::versioned(), + serialized_size_limit: 0, + } + } + + /// Disables the size limit for serialized objects + pub fn disable_size_limit(self) -> Self { + Self { + serialized_size_limit: 0, + ..self + } + } + + /// Disable the versioning of serialized objects + pub fn disable_versioning(self) -> Self { + Self { + versioned: SerializationVersioningMode::unversioned(), + ..self + } + } + + /// Create a serialization header based on the current config + fn create_header(&self) -> SerializationHeader { + match self.versioned { + SerializationVersioningMode::Versioned { .. } => { + SerializationHeader::new_versioned::() + } + SerializationVersioningMode::Unversioned { .. } => { + SerializationHeader::new_unversioned::() + } + } + } + + /// Returns the max length of the serialized header + fn header_length_limit(&self) -> u64 { + if self.serialized_size_limit == 0 { + 0 + } else { + HEADER_LENGTH_LIMIT + } + } + + /// Serializes an object into a [writer](std::io::Write), based on the current config. + /// The written bytes can be deserialized using [`DeserializationConfig::deserialize_from`]. + pub fn serialize_into( + self, + object: &T, + mut writer: impl std::io::Write, + ) -> bincode::Result<()> { + let options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + let header = self.create_header::(); + options + .with_limit(self.header_length_limit()) + .serialize_into(&mut writer, &header)?; + + match self.versioned { + SerializationVersioningMode::Versioned { .. } => options + .with_limit(self.serialized_size_limit) + .serialize_into(&mut writer, &object.versionize())?, + SerializationVersioningMode::Unversioned { .. } => options + .with_limit(self.serialized_size_limit) + .serialize_into(&mut writer, &object)?, + }; + + Ok(()) + } +} + +/// A configuration used to Serialize *TFHE-rs* objects. This configuration decides +/// the various sanity checks that will be performed during deserialization. +#[derive(Copy, Clone)] +pub struct DeserializationConfig { + serialized_size_limit: u64, + validate_header: bool, +} + +/// A configuration used to Serialize *TFHE-rs* objects. This is similar to +/// [`DeserializationConfig`] but it will not require conformance parameters. +/// +/// This type should be created with [`DeserializationConfig::disable_conformance`] +#[derive(Copy, Clone)] +pub struct NonConformantDeserializationConfig { + serialized_size_limit: u64, + validate_header: bool, +} + +impl NonConformantDeserializationConfig { + /// Deserializes an object serialized by [`SerializationConfig::serialize_into`] from a + /// [reader](std::io::Read). Performs various sanity checks based on the deserialization config, + /// but skips conformance checks. + pub fn deserialize_from( + self, + mut reader: impl std::io::Read, + ) -> Result { + if self.serialized_size_limit != 0 && self.serialized_size_limit <= HEADER_LENGTH_LIMIT { + return Err(format!( + "The provided size limit is too small, provide a limit of at least \ +{HEADER_LENGTH_LIMIT} bytes" + )); + } + + let options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + let deserialized_header: SerializationHeader = options + .with_limit(self.header_length_limit()) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string())?; + + if self.validate_header { + deserialized_header.validate::()?; + } + + match deserialized_header.versioning_mode { + SerializationVersioningMode::Versioned { .. } => { + let deser_versioned = options + .with_limit(self.serialized_size_limit - self.header_length_limit()) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string())?; + + T::unversionize(deser_versioned).map_err(|e| e.to_string()) + } + SerializationVersioningMode::Unversioned { .. } => options + .with_limit(self.serialized_size_limit - self.header_length_limit()) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string()), + } + } + + /// Enables the conformance check on an existing config. + pub fn enable_conformance(self) -> DeserializationConfig { + DeserializationConfig { + serialized_size_limit: self.serialized_size_limit, + validate_header: self.validate_header, + } + } + + fn header_length_limit(&self) -> u64 { + if self.serialized_size_limit == 0 { + 0 + } else { + HEADER_LENGTH_LIMIT + } + } +} + +impl DeserializationConfig { + /// Creates a new deserialization config. + /// + /// By default, it will check that the serialization version and the name of the + /// deserialized type are correct. + /// `serialized_size_limit` is the size limit (in number of byte) of the serialized object + /// (excluding version and name serialization). + /// + /// It will also check that the object is conformant with the parameter set given in + /// `conformance_params`. Finally, it will check the compatibility of the loaded data with + /// the current *TFHE-rs* version. + pub fn new(serialized_size_limit: u64) -> Self { + Self { + serialized_size_limit, + validate_header: true, + } + } + + /// Creates a new config without any size limit for the deserialized objects. + pub fn new_with_unlimited_size() -> Self { + Self { + serialized_size_limit: 0, + validate_header: true, + } + } + + /// Disables the size limit for the serialized objects. + pub fn disable_size_limit(self) -> Self { + Self { + serialized_size_limit: 0, + ..self + } + } + + /// Disables the header validation on the object. This header validations + /// checks that the serialized object is the one that is supposed to be loaded + /// and is compatible with this version of *TFHE-rs*. + pub fn disable_header_validation(self) -> Self { + Self { + validate_header: false, + ..self + } + } + + /// Disables the conformance check on an existing config. + pub fn disable_conformance(self) -> NonConformantDeserializationConfig { + NonConformantDeserializationConfig { + serialized_size_limit: self.serialized_size_limit, + validate_header: self.validate_header, + } + } + + /// Deserializes an object serialized by [`SerializationConfig::serialize_into`] from a + /// [reader](std::io::Read). Performs various sanity checks based on the deserialization config. + pub fn deserialize_from( + self, + reader: impl std::io::Read, + parameter_set: &T::ParameterSet, + ) -> Result { + let deser: T = self.disable_conformance().deserialize_from(reader)?; + if !deser.is_conformant(parameter_set) { + return Err(format!( + "Deserialized object of type {} not conformant with given parameter set", + T::NAME + )); + } + + Ok(deser) + } +} + +/// Serialize an object with the default configuration (with size limit and versioning). +/// This is an alias for `SerializationConfig::new(serialized_size_limit).serialize_into` +pub fn safe_serialize( + object: &T, + writer: impl std::io::Write, + serialized_size_limit: u64, +) -> bincode::Result<()> { + SerializationConfig::new(serialized_size_limit).serialize_into(object, writer) +} + +/// Serialize an object with the default configuration (with size limit, header check and +/// versioning). This is an alias for +/// `DeserializationConfig::new(serialized_size_limit).disable_conformance().deserialize_from` +pub fn safe_deserialize( + reader: impl std::io::Read, + serialized_size_limit: u64, +) -> Result { + DeserializationConfig::new(serialized_size_limit) + .disable_conformance() + .deserialize_from(reader) +} + +/// Serialize an object with the default configuration and conformance checks (with size limit, +/// header check and versioning). This is an alias for +/// `DeserializationConfig::new(serialized_size_limit).deserialize_from` +pub fn safe_deserialize_conformant< + T: DeserializeOwned + Unversionize + Named + ParameterSetConformant, +>( + reader: impl std::io::Read, + serialized_size_limit: u64, + parameter_set: &T::ParameterSet, +) -> Result { + DeserializationConfig::new(serialized_size_limit).deserialize_from(reader, parameter_set) +} + +#[cfg(all(test, feature = "shortint"))] +mod test_shortint { + use crate::safe_serialization::{DeserializationConfig, SerializationConfig}; + use crate::shortint::parameters::{ + PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, + }; + use crate::shortint::{gen_keys, Ciphertext}; + + #[test] + fn safe_deserialization_ct() { + let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); + + let msg = 2_u64; + + let ct = ck.encrypt(msg); + + let mut buffer = vec![]; + + SerializationConfig::new(1 << 20) + .disable_versioning() + .serialize_into(&ct, &mut buffer) + .unwrap(); + + assert!(DeserializationConfig::new(1 << 20) + .deserialize_from::( + buffer.as_slice(), + &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param() + ) + .is_err()); + + let ct2 = DeserializationConfig::new(1 << 20) + .deserialize_from::( + buffer.as_slice(), + &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + ) + .unwrap(); + + let dec = ck.decrypt(&ct2); + assert_eq!(msg, dec); + } + + #[test] + fn safe_deserialization_ct_versioned() { + let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); + + let msg = 2_u64; + + let ct = ck.encrypt(msg); + + let mut buffer = vec![]; + + SerializationConfig::new(1 << 20) + .serialize_into(&ct, &mut buffer) + .unwrap(); + + assert!(DeserializationConfig::new(1 << 20,) + .deserialize_from::( + buffer.as_slice(), + &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param() + ) + .is_err()); + + let ct2 = DeserializationConfig::new(1 << 20) + .deserialize_from::( + buffer.as_slice(), + &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + ) + .unwrap(); + + let dec = ck.decrypt(&ct2); + assert_eq!(msg, dec); + } +} + +#[cfg(all(test, feature = "integer"))] +mod test_integer { + use crate::conformance::ListSizeConstraint; + use crate::high_level_api::{generate_keys, ConfigBuilder}; + use crate::prelude::*; + use crate::safe_serialization::{DeserializationConfig, SerializationConfig}; + use crate::shortint::parameters::{ + PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, + }; + use crate::{ + set_server_key, CompactCiphertextList, CompactCiphertextListConformanceParams, + CompactPublicKey, FheUint8, + }; + + #[test] + fn safe_deserialization_ct_list() { + let (client_key, sks) = generate_keys(ConfigBuilder::default().build()); + set_server_key(sks); + + let public_key = CompactPublicKey::new(&client_key); + + let msg = [27u8, 10, 3]; + + let ct_list = CompactCiphertextList::builder(&public_key) + .push(27u8) + .push(10u8) + .push(3u8) + .build(); + + let mut buffer = vec![]; + + SerializationConfig::new(1 << 20) + .disable_versioning() + .serialize_into(&ct_list, &mut buffer) + .unwrap(); + + let to_param_set = |list_size_constraint| CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: list_size_constraint, + }; + + for param_set in [ + CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: ListSizeConstraint::exact_size(3), + }, + to_param_set(ListSizeConstraint::exact_size(2)), + to_param_set(ListSizeConstraint::exact_size(4)), + to_param_set(ListSizeConstraint::try_size_in_range(1, 2).unwrap()), + to_param_set(ListSizeConstraint::try_size_in_range(4, 5).unwrap()), + ] { + assert!(DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), ¶m_set) + .is_err()); + } + + for len_constraint in [ + ListSizeConstraint::exact_size(3), + ListSizeConstraint::try_size_in_range(2, 3).unwrap(), + ListSizeConstraint::try_size_in_range(3, 4).unwrap(), + ListSizeConstraint::try_size_in_range(2, 4).unwrap(), + ] { + let params = CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: len_constraint, + }; + + DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), ¶ms) + .unwrap(); + } + + let params = CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: ListSizeConstraint::exact_size(3), + }; + let ct2 = DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), ¶ms) + .unwrap(); + + let mut cts = Vec::with_capacity(3); + let expander = ct2.expand().unwrap(); + for i in 0..3 { + cts.push(expander.get::(i).unwrap().unwrap()); + } + + let dec: Vec = cts.iter().map(|a| a.decrypt(&client_key)).collect(); + + assert_eq!(&msg[..], &dec); + } + + #[test] + fn safe_deserialization_ct_list_versioned() { + let (client_key, sks) = generate_keys(ConfigBuilder::default().build()); + set_server_key(sks); + + let public_key = CompactPublicKey::new(&client_key); + + let msg = [27u8, 10, 3]; + + let ct_list = CompactCiphertextList::builder(&public_key) + .push(27u8) + .push(10u8) + .push(3u8) + .build(); + + let mut buffer = vec![]; + + SerializationConfig::new(1 << 20) + .serialize_into(&ct_list, &mut buffer) + .unwrap(); + + let to_param_set = |list_size_constraint| CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: list_size_constraint, + }; + + for param_set in [ + CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: ListSizeConstraint::exact_size(3), + }, + to_param_set(ListSizeConstraint::exact_size(2)), + to_param_set(ListSizeConstraint::exact_size(4)), + to_param_set(ListSizeConstraint::try_size_in_range(1, 2).unwrap()), + to_param_set(ListSizeConstraint::try_size_in_range(4, 5).unwrap()), + ] { + assert!(DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), ¶m_set) + .is_err()); + } + + for len_constraint in [ + ListSizeConstraint::exact_size(3), + ListSizeConstraint::try_size_in_range(2, 3).unwrap(), + ListSizeConstraint::try_size_in_range(3, 4).unwrap(), + ListSizeConstraint::try_size_in_range(2, 4).unwrap(), + ] { + let params = CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: len_constraint, + }; + + DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), ¶ms) + .unwrap(); + } + + let params = CompactCiphertextListConformanceParams { + shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + num_elements_constraint: ListSizeConstraint::exact_size(3), + }; + let ct2 = DeserializationConfig::new(1 << 20) + .deserialize_from::(buffer.as_slice(), ¶ms) + .unwrap(); + + let mut cts = Vec::with_capacity(3); + let expander = ct2.expand().unwrap(); + for i in 0..3 { + cts.push(expander.get::(i).unwrap().unwrap()); + } + + let dec: Vec = cts.iter().map(|a| a.decrypt(&client_key)).collect(); + + assert_eq!(&msg[..], &dec); + } +}