Skip to content

Commit

Permalink
fix(serialization): serialized_size_limit includes the header
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Oct 18, 2024
1 parent f3a1b6b commit f8480c0
Showing 1 changed file with 44 additions and 44 deletions.
88 changes: 44 additions & 44 deletions tfhe/src/safe_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ impl SerializationVersioningMode {
}
}

/// `HEADER_LENGTH_LIMIT` is the maximum `SerializationHeader` size which
/// `DeserializationConfig::deserialize_from` is going to try to read (it returns an error if
/// it's too big).
/// It helps prevent an attacker passing a very long header to exhaust memory.
const HEADER_LENGTH_LIMIT: u64 = 1000;

/// Header with global metadata about the serialized object. This help checking that we are not
/// deserializing data that we can't handle.
#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -152,7 +146,7 @@ impl SerializationConfig {
/// Creates a new serialization config. The default configuration will serialize the object
/// with versioning information for backward compatibility.
/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object
/// (excluding the header).
/// (including the header).
pub fn new(serialized_size_limit: u64) -> Self {
Self {
versioned: SerializationVersioningMode::versioned(),
Expand Down Expand Up @@ -196,15 +190,6 @@ impl SerializationConfig {
}
}

/// Returns the max length of the serialized header
fn header_length_limit(&self) -> u64 {
if self.serialized_size_limit == 0 {
0
} else {
HEADER_LENGTH_LIMIT
}
}

/// Returns the size the object would take if serialized using the current config
///
/// The size is returned as a u64 to handle the serialization of large buffers under 32b
Expand Down Expand Up @@ -236,21 +221,21 @@ impl SerializationConfig {
object: &T,
mut writer: impl std::io::Write,
) -> bincode::Result<()> {
let options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0);
let options = bincode::DefaultOptions::new().with_fixint_encoding();

let header = self.create_header::<T>();
let header_size = options.serialized_size(&header)?;

options
.with_limit(self.header_length_limit())
.with_limit(self.serialized_size_limit)
.serialize_into(&mut writer, &header)?;

match self.versioned {
SerializationVersioningMode::Versioned { .. } => options
.with_limit(self.serialized_size_limit)
.with_limit(self.serialized_size_limit - header_size)
.serialize_into(&mut writer, &object.versionize())?,
SerializationVersioningMode::Unversioned { .. } => options
.with_limit(self.serialized_size_limit)
.with_limit(self.serialized_size_limit - header_size)
.serialize_into(&mut writer, &object)?,
};

Expand Down Expand Up @@ -284,37 +269,32 @@ impl NonConformantDeserializationConfig {
self,
mut reader: impl std::io::Read,
) -> Result<T, String> {
if self.serialized_size_limit != 0 && self.serialized_size_limit <= HEADER_LENGTH_LIMIT {
return Err(format!(
"The provided size limit is too small, provide a limit of at least \
{HEADER_LENGTH_LIMIT} bytes"
));
}

let options = bincode::DefaultOptions::new()
.with_fixint_encoding()
.with_limit(0);
let options = bincode::DefaultOptions::new().with_fixint_encoding();

let deserialized_header: SerializationHeader = options
.with_limit(self.header_length_limit())
.with_limit(self.serialized_size_limit)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())?;

let header_size = options
.serialized_size(&deserialized_header)
.map_err(|err| err.to_string())?;

if self.validate_header {
deserialized_header.validate::<T>()?;
}

match deserialized_header.versioning_mode {
SerializationVersioningMode::Versioned { .. } => {
let deser_versioned = options
.with_limit(self.serialized_size_limit - self.header_length_limit())
.with_limit(self.serialized_size_limit - header_size)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string())?;

T::unversionize(deser_versioned).map_err(|e| e.to_string())
}
SerializationVersioningMode::Unversioned { .. } => options
.with_limit(self.serialized_size_limit - self.header_length_limit())
.with_limit(self.serialized_size_limit - header_size)
.deserialize_from(&mut reader)
.map_err(|err| err.to_string()),
}
Expand All @@ -327,14 +307,6 @@ impl NonConformantDeserializationConfig {
validate_header: self.validate_header,
}
}

fn header_length_limit(&self) -> u64 {
if self.serialized_size_limit == 0 {
0
} else {
HEADER_LENGTH_LIMIT
}
}
}

impl DeserializationConfig {
Expand All @@ -343,7 +315,7 @@ impl DeserializationConfig {
/// By default, it will check that the serialization version and the name of the
/// deserialized type are correct.
/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object
/// (excluding version and name serialization).
/// (include the safe serialization header).
///
/// It will also check that the object is conformant with the parameter set given in
/// `conformance_params`. Finally, it will check the compatibility of the loaded data with
Expand Down Expand Up @@ -525,6 +497,34 @@ mod test_shortint {
let dec = ck.decrypt(&ct2);
assert_eq!(msg, dec);
}

#[test]
fn safe_deserialization_size_limit() {
let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);

let msg = 2_u64;

let ct = ck.encrypt(msg);

let mut buffer = vec![];

let config = SerializationConfig::new(1 << 20).disable_versioning();

let size = config.serialized_size(&ct).unwrap();
config.serialize_into(&ct, &mut buffer).unwrap();

assert_eq!(size as usize, buffer.len());

let ct2 = DeserializationConfig::new(size)
.deserialize_from::<Ciphertext>(
buffer.as_slice(),
&PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(),
)
.unwrap();

let dec = ck.decrypt(&ct2);
assert_eq!(msg, dec);
}
}

#[cfg(all(test, feature = "integer"))]
Expand Down

0 comments on commit f8480c0

Please sign in to comment.