From 750787380dee26d2228cc3857e9d3f30ff2aafd4 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Thu, 17 Oct 2024 17:02:45 +0200 Subject: [PATCH] fix(serialization): safe_serialization with unlimited size --- tfhe/src/safe_serialization.rs | 161 +++++++++++++++++++++++++-------- 1 file changed, 124 insertions(+), 37 deletions(-) diff --git a/tfhe/src/safe_serialization.rs b/tfhe/src/safe_serialization.rs index e62a0a3d74..61e9b0232b 100644 --- a/tfhe/src/safe_serialization.rs +++ b/tfhe/src/safe_serialization.rs @@ -139,7 +139,7 @@ Please use the versioned serialization mode for backward compatibility.", #[derive(Clone)] pub struct SerializationConfig { versioned: SerializationVersioningMode, - serialized_size_limit: u64, + serialized_size_limit: Option, } impl SerializationConfig { @@ -150,7 +150,7 @@ impl SerializationConfig { pub fn new(serialized_size_limit: u64) -> Self { Self { versioned: SerializationVersioningMode::versioned(), - serialized_size_limit, + serialized_size_limit: Some(serialized_size_limit), } } @@ -158,14 +158,14 @@ impl SerializationConfig { pub fn new_with_unlimited_size() -> Self { Self { versioned: SerializationVersioningMode::versioned(), - serialized_size_limit: 0, + serialized_size_limit: None, } } /// Disables the size limit for serialized objects pub fn disable_size_limit(self) -> Self { Self { - serialized_size_limit: 0, + serialized_size_limit: None, ..self } } @@ -178,6 +178,14 @@ impl SerializationConfig { } } + /// Sets the size limit for this serialization config + pub fn with_size_limit(self, size: u64) -> Self { + Self { + serialized_size_limit: Some(size), + ..self + } + } + /// Create a serialization header based on the current config fn create_header(&self) -> SerializationHeader { match self.versioned { @@ -226,17 +234,30 @@ impl SerializationConfig { let header = self.create_header::(); let header_size = options.serialized_size(&header)?; - options - .with_limit(self.serialized_size_limit) - .serialize_into(&mut writer, &header)?; + if let Some(size_limit) = self.serialized_size_limit { + options + .with_limit(size_limit) + .serialize_into(&mut writer, &header)?; + + match self.versioned { + SerializationVersioningMode::Versioned { .. } => options + .with_limit(size_limit - header_size) + .serialize_into(&mut writer, &object.versionize())?, + SerializationVersioningMode::Unversioned { .. } => options + .with_limit(size_limit - header_size) + .serialize_into(&mut writer, &object)?, + }; + } else { + options.serialize_into(&mut writer, &header)?; - match self.versioned { - SerializationVersioningMode::Versioned { .. } => options - .with_limit(self.serialized_size_limit - header_size) - .serialize_into(&mut writer, &object.versionize())?, - SerializationVersioningMode::Unversioned { .. } => options - .with_limit(self.serialized_size_limit - header_size) - .serialize_into(&mut writer, &object)?, + match self.versioned { + SerializationVersioningMode::Versioned { .. } => { + options.serialize_into(&mut writer, &object.versionize())? + } + SerializationVersioningMode::Unversioned { .. } => { + options.serialize_into(&mut writer, &object)? + } + }; }; Ok(()) @@ -247,7 +268,7 @@ impl SerializationConfig { /// the various sanity checks that will be performed during deserialization. #[derive(Copy, Clone)] pub struct DeserializationConfig { - serialized_size_limit: u64, + serialized_size_limit: Option, validate_header: bool, } @@ -257,11 +278,30 @@ pub struct DeserializationConfig { /// This type should be created with [`DeserializationConfig::disable_conformance`] #[derive(Copy, Clone)] pub struct NonConformantDeserializationConfig { - serialized_size_limit: u64, + serialized_size_limit: Option, validate_header: bool, } impl NonConformantDeserializationConfig { + /// Deserialize a header using the current config + fn deserialize_header( + &self, + reader: &mut impl std::io::Read, + ) -> Result { + let options = bincode::DefaultOptions::new().with_fixint_encoding(); + + if let Some(size_limit) = self.serialized_size_limit { + options + .with_limit(size_limit) + .deserialize_from(reader) + .map_err(|err| err.to_string()) + } else { + options + .deserialize_from(reader) + .map_err(|err| err.to_string()) + } + } + /// Deserializes an object serialized by [`SerializationConfig::serialize_into`] from a /// [reader](std::io::Read). Performs various sanity checks based on the deserialization config, /// but skips conformance checks. @@ -271,10 +311,7 @@ impl NonConformantDeserializationConfig { ) -> Result { let options = bincode::DefaultOptions::new().with_fixint_encoding(); - let deserialized_header: SerializationHeader = options - .with_limit(self.serialized_size_limit) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string())?; + let deserialized_header: SerializationHeader = self.deserialize_header(&mut reader)?; let header_size = options .serialized_size(&deserialized_header) @@ -284,19 +321,33 @@ impl NonConformantDeserializationConfig { deserialized_header.validate::()?; } - match deserialized_header.versioning_mode { - SerializationVersioningMode::Versioned { .. } => { - let deser_versioned = options - .with_limit(self.serialized_size_limit - header_size) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string())?; + if let Some(size_limit) = self.serialized_size_limit { + let options = options.with_limit(size_limit - header_size); + match deserialized_header.versioning_mode { + SerializationVersioningMode::Versioned { .. } => { + let deser_versioned = options + .deserialize_from(&mut reader) + .map_err(|err| err.to_string())?; - T::unversionize(deser_versioned).map_err(|e| e.to_string()) + T::unversionize(deser_versioned).map_err(|e| e.to_string()) + } + SerializationVersioningMode::Unversioned { .. } => options + .deserialize_from(&mut reader) + .map_err(|err| err.to_string()), + } + } else { + match deserialized_header.versioning_mode { + SerializationVersioningMode::Versioned { .. } => { + let deser_versioned = options + .deserialize_from(&mut reader) + .map_err(|err| err.to_string())?; + + T::unversionize(deser_versioned).map_err(|e| e.to_string()) + } + SerializationVersioningMode::Unversioned { .. } => options + .deserialize_from(&mut reader) + .map_err(|err| err.to_string()), } - SerializationVersioningMode::Unversioned { .. } => options - .with_limit(self.serialized_size_limit - header_size) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string()), } } @@ -322,7 +373,7 @@ impl DeserializationConfig { /// the current *TFHE-rs* version. pub fn new(serialized_size_limit: u64) -> Self { Self { - serialized_size_limit, + serialized_size_limit: Some(serialized_size_limit), validate_header: true, } } @@ -330,7 +381,7 @@ impl DeserializationConfig { /// Creates a new config without any size limit for the deserialized objects. pub fn new_with_unlimited_size() -> Self { Self { - serialized_size_limit: 0, + serialized_size_limit: None, validate_header: true, } } @@ -338,7 +389,15 @@ impl DeserializationConfig { /// Disables the size limit for the serialized objects. pub fn disable_size_limit(self) -> Self { Self { - serialized_size_limit: 0, + serialized_size_limit: None, + ..self + } + } + + /// Sets the size limit for this deserialization config + pub fn with_size_limit(self, size: u64) -> Self { + Self { + serialized_size_limit: Some(size), ..self } } @@ -429,7 +488,7 @@ mod test_shortint { use crate::shortint::{gen_keys, Ciphertext}; #[test] - fn safe_deserialization_ct() { + fn safe_deserialization_ct_unversioned() { let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); let msg = 2_u64; @@ -464,7 +523,7 @@ mod test_shortint { } #[test] - fn safe_deserialization_ct_versioned() { + fn safe_deserialization_ct() { let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); let msg = 2_u64; @@ -498,6 +557,34 @@ mod test_shortint { assert_eq!(msg, dec); } + #[test] + fn safe_deserialization_ct_unlimited_size() { + let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); + + let msg = 2_u64; + + let ct = ck.encrypt(msg); + + let mut buffer = vec![]; + + let config = SerializationConfig::new_with_unlimited_size(); + + let size = config.serialized_size(&ct).unwrap(); + config.serialize_into(&ct, &mut buffer).unwrap(); + + assert_eq!(size as usize, buffer.len()); + + let ct2 = DeserializationConfig::new_with_unlimited_size() + .deserialize_from::( + buffer.as_slice(), + &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), + ) + .unwrap(); + + let dec = ck.decrypt(&ct2); + assert_eq!(msg, dec); + } + #[test] fn safe_deserialization_size_limit() { let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); @@ -508,7 +595,7 @@ mod test_shortint { let mut buffer = vec![]; - let config = SerializationConfig::new(1 << 20).disable_versioning(); + let config = SerializationConfig::new_with_unlimited_size().disable_versioning(); let size = config.serialized_size(&ct).unwrap(); config.serialize_into(&ct, &mut buffer).unwrap();