From b92a3cd565dc22f56ff70c42b3b356e417ce8a61 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Wed, 16 Oct 2024 17:48:16 +0200 Subject: [PATCH] feat(versionable): support version deprecations in the dispatch enum --- .../src/dispatch_type.rs | 7 +- .../tfhe-versionable/examples/deprecation.rs | 192 ++++++++++++++++++ utils/tfhe-versionable/src/deprecation.rs | 140 +++++++++++++ utils/tfhe-versionable/src/lib.rs | 7 + 4 files changed, 343 insertions(+), 3 deletions(-) create mode 100644 utils/tfhe-versionable/examples/deprecation.rs create mode 100644 utils/tfhe-versionable/src/deprecation.rs diff --git a/utils/tfhe-versionable-derive/src/dispatch_type.rs b/utils/tfhe-versionable-derive/src/dispatch_type.rs index c39aad57ad..b2f9363625 100644 --- a/utils/tfhe-versionable-derive/src/dispatch_type.rs +++ b/utils/tfhe-versionable-derive/src/dispatch_type.rs @@ -336,6 +336,7 @@ impl DispatchType { fn generate_conversion_constructor_owned(&self, arg_name: &str) -> syn::Result { let arg_ident = Ident::new(arg_name, Span::call_site()); let error_ty: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME); + let upgrade_trait: Path = parse_const_str(UPGRADE_TRAIT_NAME); let match_cases = self.orig_type @@ -354,12 +355,12 @@ impl DispatchType { // Add chained calls to the upgrade method, with error handling let upgrades_chain = (0..upgrades_needed).map(|upgrade_idx| { // Here we can unwrap because src_idx + upgrade_idx < version_count or we wouldn't need to upgrade + let src_type = self.version_type_at(src_idx + upgrade_idx).unwrap(); let src_variant = self.variant_at(src_idx + upgrade_idx).unwrap().ident.to_string(); let dest_variant = self.variant_at(src_idx + upgrade_idx + 1).unwrap().ident.to_string(); quote! { - .and_then(|value| { - value - .upgrade() + .and_then(|value: #src_type| { + #upgrade_trait::upgrade(value) .map_err(|e| #error_ty::upgrade(#src_variant, #dest_variant, e) ) diff --git a/utils/tfhe-versionable/examples/deprecation.rs b/utils/tfhe-versionable/examples/deprecation.rs new file mode 100644 index 0000000000..2fd05a9acd --- /dev/null +++ b/utils/tfhe-versionable/examples/deprecation.rs @@ -0,0 +1,192 @@ +//! Example of a version deprecation, to remove support for types up to a chosen point. +//! +//! In this example, we have an application with 3 versions: v0, v1, v2. We know that v0 and v1 are +//! not used in the wild, so we want to remove backward compatibility with them to be able to +//! clean-up some code. We can use this feature to create a v3 version that will be compatible with +//! v2 but remove support for the previous ones. + +use tfhe_versionable::{Unversionize, Versionize}; + +// The newer version of the app, where you want to cut compatibility with versions that are too old +mod v3 { + use serde::{Deserialize, Serialize}; + use tfhe_versionable::Versionize; + + use backward_compat::MyStructVersions; + + #[derive(Serialize, Deserialize, Versionize)] + #[versionize(MyStructVersions)] + pub struct MyStruct { + pub count: u32, + pub attr: T, + } + + mod backward_compat { + use tfhe_versionable::deprecation::{Deprecable, Deprecated}; + use tfhe_versionable::VersionsDispatch; + + use super::MyStruct; + + // The `Deprecation` trait will be used to give meaningful error messages to you users + impl Deprecable for MyStruct { + // The name of the type, as seen by the user + const TYPE_NAME: &'static str = "MyStruct"; + + // The minimum version of the application/library that we still support. You can include + // the name of your app/library. + const MIN_SUPPORTED_APP_VERSION: &'static str = "app v2"; + } + + // Replace the deprecation versions with the `Deprecated` type in the dispatch enum + #[derive(VersionsDispatch)] + #[allow(unused)] + pub enum MyStructVersions { + V0(Deprecated>), + V1(Deprecated>), + V2(MyStruct), + } + } +} + +fn main() { + // A version that will be deprecated + let v0 = v0::MyStruct(37); + + let serialized = bincode::serialize(&v0.versionize()).unwrap(); + + // We can upgrade it until the last supported version + let v2 = v2::MyStruct::::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + + assert_eq!(v0.0, v2.count); + assert_eq!(v2.attr, u64::default()); + + // But trying to upgrade it into the newer version with dropped support will fail. + let v3_deser: Result, _> = bincode::deserialize(&serialized); + + assert!(v3_deser.is_err()); + + // However you can still update from the last supported version + let _serialized_v2 = bincode::serialize(&v2.versionize()).unwrap(); +} + +// Older versions of the application + +mod v0 { + use serde::{Deserialize, Serialize}; + use tfhe_versionable::Versionize; + + use backward_compat::MyStructVersions; + + #[derive(Serialize, Deserialize, Versionize)] + #[versionize(MyStructVersions)] + pub struct MyStruct(pub u32); + + mod backward_compat { + use tfhe_versionable::VersionsDispatch; + + use super::MyStruct; + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub enum MyStructVersions { + V0(MyStruct), + } + } +} + +mod v1 { + use serde::{Deserialize, Serialize}; + use tfhe_versionable::Versionize; + + use backward_compat::MyStructVersions; + + #[derive(Serialize, Deserialize, Versionize)] + #[versionize(MyStructVersions)] + pub struct MyStruct(pub u32, pub T); + + mod backward_compat { + use std::convert::Infallible; + + use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; + + use super::MyStruct; + + #[derive(Version)] + pub struct MyStructV0(pub u32); + + impl Upgrade> for MyStructV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStruct(self.0, T::default())) + } + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub enum MyStructVersions { + V0(MyStructV0), + V1(MyStruct), + } + } +} + +mod v2 { + use serde::{Deserialize, Serialize}; + use tfhe_versionable::Versionize; + + use backward_compat::MyStructVersions; + + #[derive(Serialize, Deserialize, Versionize)] + #[versionize(MyStructVersions)] + pub struct MyStruct { + pub count: u32, + pub attr: T, + } + + mod backward_compat { + use std::convert::Infallible; + + use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; + + use super::MyStruct; + + #[derive(Version)] + pub struct MyStructV0(pub u32); + + impl Upgrade> for MyStructV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStructV1(self.0, T::default())) + } + } + + #[derive(Version)] + pub struct MyStructV1(pub u32, pub T); + + impl Upgrade> for MyStructV1 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStruct { + count: self.0, + attr: T::default(), + }) + } + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub enum MyStructVersions { + V0(MyStructV0), + V1(MyStructV1), + V2(MyStruct), + } + } +} + +#[test] +fn test() { + main() +} diff --git a/utils/tfhe-versionable/src/deprecation.rs b/utils/tfhe-versionable/src/deprecation.rs new file mode 100644 index 0000000000..d5831d41db --- /dev/null +++ b/utils/tfhe-versionable/src/deprecation.rs @@ -0,0 +1,140 @@ +//! Handle the deprecation of older versions of some types + +use std::error::Error; +use std::fmt::Display; +use std::marker::PhantomData; + +use serde::{Deserialize, Serialize}; + +use crate::{UnversionizeError, Upgrade, Version}; + +/// This trait should be implemented for types that have deprecated versions. You can then use them +/// inside the dispatch enum by wrapping them into the [`Deprecated`] type. +pub trait Deprecable { + const TYPE_NAME: &'static str; + const MIN_SUPPORTED_APP_VERSION: &'static str; + + fn error() -> DeprecatedVersionError { + DeprecatedVersionError { + type_name: Self::TYPE_NAME.to_string(), + min_supported_app_version: Self::MIN_SUPPORTED_APP_VERSION.to_string(), + } + } +} + +/// An error returned when trying to interact (unserialize or unversionize) with a deprecated type. +#[derive(Debug)] +pub struct DeprecatedVersionError { + type_name: String, + min_supported_app_version: String, +} + +impl Display for DeprecatedVersionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Deprecated {} found in serialized data, minimal supported version is {}", + self.type_name, self.min_supported_app_version + ) + } +} + +impl Error for DeprecatedVersionError {} + +/// Wrapper type that can be used inside the dispatch enum for a type to mark a version that has +/// been deprecated. +/// +/// For example: +/// ```rust +/// use tfhe_versionable::deprecation::{Deprecable, Deprecated}; +/// use tfhe_versionable::{Versionize, VersionsDispatch}; +/// #[derive(Versionize)] +/// #[versionize(MyStructVersions)] +/// struct MyStruct; +/// +/// impl Deprecable for MyStruct { +/// const TYPE_NAME: &'static str = "MyStruct"; +/// const MIN_SUPPORTED_APP_VERSION: &'static str = "my_app v2"; +/// } +/// +/// #[derive(VersionsDispatch)] +/// #[allow(unused)] +/// pub enum MyStructVersions { +/// V0(Deprecated), +/// V1(Deprecated), +/// V2(MyStruct), +/// } +/// ``` +pub struct Deprecated { + _phantom: PhantomData, +} + +/// This type is used in the [`Version`] trait but should not be manually used. +pub struct DeprecatedVersion { + _phantom: PhantomData, +} + +// Manual impl of Serialize/Deserialize to be able to catch them and return a meaningful error to +// the user. + +impl Serialize for DeprecatedVersion { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + Err(serde::ser::Error::custom( + "a DeprecatedVersion should never be serialized", + )) + } +} + +impl<'de, T: Deprecable> Deserialize<'de> for DeprecatedVersion { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Err(::custom(T::error())) + } +} + +impl Version for Deprecated { + // Since the type is a ZST we directly use it without a reference + type Ref<'vers> + = DeprecatedVersion + where + T: 'vers; + + type Owned = DeprecatedVersion; +} + +impl From> for DeprecatedVersion { + fn from(_value: Deprecated) -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl From<&Deprecated> for DeprecatedVersion { + fn from(_value: &Deprecated) -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl TryFrom> for Deprecated { + type Error = UnversionizeError; + + fn try_from(_value: DeprecatedVersion) -> Result { + Err(UnversionizeError::DeprecatedVersion(T::error())) + } +} + +impl Upgrade for Deprecated { + type Error = DeprecatedVersionError; + + fn upgrade(self) -> Result { + Err(T::error()) + } +} diff --git a/utils/tfhe-versionable/src/lib.rs b/utils/tfhe-versionable/src/lib.rs index 2613536df7..e7ee2d3a38 100644 --- a/utils/tfhe-versionable/src/lib.rs +++ b/utils/tfhe-versionable/src/lib.rs @@ -6,10 +6,12 @@ //! that has a variant for each version of the type. //! These traits can be generated using the [`tfhe_versionable_derive::Versionize`] proc macro. +pub mod deprecation; pub mod derived_traits; pub mod upgrade; use aligned_vec::{ABox, AVec}; +use deprecation::DeprecatedVersionError; use num_complex::Complex; use std::collections::{HashMap, HashSet}; use std::convert::Infallible; @@ -89,6 +91,9 @@ pub enum UnversionizeError { expected_size: usize, found_size: usize, }, + + /// A deprecated version has been found + DeprecatedVersion(DeprecatedVersionError), } impl Display for UnversionizeError { @@ -114,6 +119,7 @@ impl Display for UnversionizeError { "Expected array of size {expected_size}, found array of size {found_size}" ) } + Self::DeprecatedVersion(deprecation_error) => deprecation_error.fmt(f), } } } @@ -124,6 +130,7 @@ impl Error for UnversionizeError { UnversionizeError::Upgrade { source, .. } => Some(source.as_ref()), UnversionizeError::Conversion { source, .. } => Some(source.as_ref()), UnversionizeError::ArrayLength { .. } => None, + UnversionizeError::DeprecatedVersion(_) => None, } } }