Skip to content

Commit

Permalink
feat(versionable): support version deprecations in the dispatch enum
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Oct 21, 2024
1 parent 2c52bd1 commit b92a3cd
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 3 deletions.
7 changes: 4 additions & 3 deletions utils/tfhe-versionable-derive/src/dispatch_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ impl DispatchType {
fn generate_conversion_constructor_owned(&self, arg_name: &str) -> syn::Result<TokenStream> {
let arg_ident = Ident::new(arg_name, Span::call_site());
let error_ty: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME);
let upgrade_trait: Path = parse_const_str(UPGRADE_TRAIT_NAME);

let match_cases =
self.orig_type
Expand All @@ -354,12 +355,12 @@ impl DispatchType {
// Add chained calls to the upgrade method, with error handling
let upgrades_chain = (0..upgrades_needed).map(|upgrade_idx| {
// Here we can unwrap because src_idx + upgrade_idx < version_count or we wouldn't need to upgrade
let src_type = self.version_type_at(src_idx + upgrade_idx).unwrap();
let src_variant = self.variant_at(src_idx + upgrade_idx).unwrap().ident.to_string();
let dest_variant = self.variant_at(src_idx + upgrade_idx + 1).unwrap().ident.to_string();
quote! {
.and_then(|value| {
value
.upgrade()
.and_then(|value: #src_type| {
#upgrade_trait::upgrade(value)
.map_err(|e|
#error_ty::upgrade(#src_variant, #dest_variant, e)
)
Expand Down
192 changes: 192 additions & 0 deletions utils/tfhe-versionable/examples/deprecation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
//! Example of a version deprecation, to remove support for types up to a chosen point.
//!
//! In this example, we have an application with 3 versions: v0, v1, v2. We know that v0 and v1 are
//! not used in the wild, so we want to remove backward compatibility with them to be able to
//! clean-up some code. We can use this feature to create a v3 version that will be compatible with
//! v2 but remove support for the previous ones.
use tfhe_versionable::{Unversionize, Versionize};

// The newer version of the app, where you want to cut compatibility with versions that are too old
mod v3 {
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

use backward_compat::MyStructVersions;

#[derive(Serialize, Deserialize, Versionize)]
#[versionize(MyStructVersions)]
pub struct MyStruct<T> {
pub count: u32,
pub attr: T,
}

mod backward_compat {
use tfhe_versionable::deprecation::{Deprecable, Deprecated};
use tfhe_versionable::VersionsDispatch;

use super::MyStruct;

// The `Deprecation` trait will be used to give meaningful error messages to you users
impl<T> Deprecable for MyStruct<T> {
// The name of the type, as seen by the user
const TYPE_NAME: &'static str = "MyStruct";

// The minimum version of the application/library that we still support. You can include
// the name of your app/library.
const MIN_SUPPORTED_APP_VERSION: &'static str = "app v2";
}

// Replace the deprecation versions with the `Deprecated` type in the dispatch enum
#[derive(VersionsDispatch)]
#[allow(unused)]
pub enum MyStructVersions<T> {
V0(Deprecated<MyStruct<T>>),
V1(Deprecated<MyStruct<T>>),
V2(MyStruct<T>),
}
}
}

fn main() {
// A version that will be deprecated
let v0 = v0::MyStruct(37);

let serialized = bincode::serialize(&v0.versionize()).unwrap();

// We can upgrade it until the last supported version
let v2 = v2::MyStruct::<u64>::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap();

assert_eq!(v0.0, v2.count);
assert_eq!(v2.attr, u64::default());

// But trying to upgrade it into the newer version with dropped support will fail.
let v3_deser: Result<v3::MyStruct<u64>, _> = bincode::deserialize(&serialized);

assert!(v3_deser.is_err());

// However you can still update from the last supported version
let _serialized_v2 = bincode::serialize(&v2.versionize()).unwrap();
}

// Older versions of the application

mod v0 {
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

use backward_compat::MyStructVersions;

#[derive(Serialize, Deserialize, Versionize)]
#[versionize(MyStructVersions)]
pub struct MyStruct(pub u32);

mod backward_compat {
use tfhe_versionable::VersionsDispatch;

use super::MyStruct;

#[derive(VersionsDispatch)]
#[allow(unused)]
pub enum MyStructVersions {
V0(MyStruct),
}
}
}

mod v1 {
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

use backward_compat::MyStructVersions;

#[derive(Serialize, Deserialize, Versionize)]
#[versionize(MyStructVersions)]
pub struct MyStruct<T>(pub u32, pub T);

mod backward_compat {
use std::convert::Infallible;

use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use super::MyStruct;

#[derive(Version)]
pub struct MyStructV0(pub u32);

impl<T: Default> Upgrade<MyStruct<T>> for MyStructV0 {
type Error = Infallible;

fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct(self.0, T::default()))
}
}

#[derive(VersionsDispatch)]
#[allow(unused)]
pub enum MyStructVersions<T> {
V0(MyStructV0),
V1(MyStruct<T>),
}
}
}

mod v2 {
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

use backward_compat::MyStructVersions;

#[derive(Serialize, Deserialize, Versionize)]
#[versionize(MyStructVersions)]
pub struct MyStruct<T> {
pub count: u32,
pub attr: T,
}

mod backward_compat {
use std::convert::Infallible;

use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use super::MyStruct;

#[derive(Version)]
pub struct MyStructV0(pub u32);

impl<T: Default> Upgrade<MyStructV1<T>> for MyStructV0 {
type Error = Infallible;

fn upgrade(self) -> Result<MyStructV1<T>, Self::Error> {
Ok(MyStructV1(self.0, T::default()))
}
}

#[derive(Version)]
pub struct MyStructV1<T>(pub u32, pub T);

impl<T: Default> Upgrade<MyStruct<T>> for MyStructV1<T> {
type Error = Infallible;

fn upgrade(self) -> Result<MyStruct<T>, Self::Error> {
Ok(MyStruct {
count: self.0,
attr: T::default(),
})
}
}

#[derive(VersionsDispatch)]
#[allow(unused)]
pub enum MyStructVersions<T> {
V0(MyStructV0),
V1(MyStructV1<T>),
V2(MyStruct<T>),
}
}
}

#[test]
fn test() {
main()
}
140 changes: 140 additions & 0 deletions utils/tfhe-versionable/src/deprecation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//! Handle the deprecation of older versions of some types
use std::error::Error;
use std::fmt::Display;
use std::marker::PhantomData;

use serde::{Deserialize, Serialize};

use crate::{UnversionizeError, Upgrade, Version};

/// This trait should be implemented for types that have deprecated versions. You can then use them
/// inside the dispatch enum by wrapping them into the [`Deprecated`] type.
pub trait Deprecable {
const TYPE_NAME: &'static str;
const MIN_SUPPORTED_APP_VERSION: &'static str;

fn error() -> DeprecatedVersionError {
DeprecatedVersionError {
type_name: Self::TYPE_NAME.to_string(),
min_supported_app_version: Self::MIN_SUPPORTED_APP_VERSION.to_string(),
}
}
}

/// An error returned when trying to interact (unserialize or unversionize) with a deprecated type.
#[derive(Debug)]
pub struct DeprecatedVersionError {
type_name: String,
min_supported_app_version: String,
}

impl Display for DeprecatedVersionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Deprecated {} found in serialized data, minimal supported version is {}",
self.type_name, self.min_supported_app_version
)
}
}

impl Error for DeprecatedVersionError {}

/// Wrapper type that can be used inside the dispatch enum for a type to mark a version that has
/// been deprecated.
///
/// For example:
/// ```rust
/// use tfhe_versionable::deprecation::{Deprecable, Deprecated};
/// use tfhe_versionable::{Versionize, VersionsDispatch};
/// #[derive(Versionize)]
/// #[versionize(MyStructVersions)]
/// struct MyStruct;
///
/// impl Deprecable for MyStruct {
/// const TYPE_NAME: &'static str = "MyStruct";
/// const MIN_SUPPORTED_APP_VERSION: &'static str = "my_app v2";
/// }
///
/// #[derive(VersionsDispatch)]
/// #[allow(unused)]
/// pub enum MyStructVersions {
/// V0(Deprecated<MyStruct>),
/// V1(Deprecated<MyStruct>),
/// V2(MyStruct),
/// }
/// ```
pub struct Deprecated<T> {
_phantom: PhantomData<T>,
}

/// This type is used in the [`Version`] trait but should not be manually used.
pub struct DeprecatedVersion<T> {
_phantom: PhantomData<T>,
}

// Manual impl of Serialize/Deserialize to be able to catch them and return a meaningful error to
// the user.

impl<T: Deprecable> Serialize for DeprecatedVersion<T> {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
Err(serde::ser::Error::custom(
"a DeprecatedVersion should never be serialized",
))
}
}

impl<'de, T: Deprecable> Deserialize<'de> for DeprecatedVersion<T> {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Err(<D::Error as serde::de::Error>::custom(T::error()))
}
}

impl<T: Deprecable> Version for Deprecated<T> {
// Since the type is a ZST we directly use it without a reference
type Ref<'vers>
= DeprecatedVersion<T>
where
T: 'vers;

type Owned = DeprecatedVersion<T>;
}

impl<T: Deprecable> From<Deprecated<T>> for DeprecatedVersion<T> {
fn from(_value: Deprecated<T>) -> Self {
Self {
_phantom: PhantomData,
}
}
}

impl<T: Deprecable> From<&Deprecated<T>> for DeprecatedVersion<T> {
fn from(_value: &Deprecated<T>) -> Self {
Self {
_phantom: PhantomData,
}
}
}

impl<T: Deprecable> TryFrom<DeprecatedVersion<T>> for Deprecated<T> {
type Error = UnversionizeError;

fn try_from(_value: DeprecatedVersion<T>) -> Result<Self, Self::Error> {
Err(UnversionizeError::DeprecatedVersion(T::error()))
}
}

impl<T: Deprecable, U> Upgrade<U> for Deprecated<T> {
type Error = DeprecatedVersionError;

fn upgrade(self) -> Result<U, Self::Error> {
Err(T::error())
}
}
7 changes: 7 additions & 0 deletions utils/tfhe-versionable/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
//! that has a variant for each version of the type.
//! These traits can be generated using the [`tfhe_versionable_derive::Versionize`] proc macro.
pub mod deprecation;
pub mod derived_traits;
pub mod upgrade;

use aligned_vec::{ABox, AVec};
use deprecation::DeprecatedVersionError;
use num_complex::Complex;
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
Expand Down Expand Up @@ -89,6 +91,9 @@ pub enum UnversionizeError {
expected_size: usize,
found_size: usize,
},

/// A deprecated version has been found
DeprecatedVersion(DeprecatedVersionError),
}

impl Display for UnversionizeError {
Expand All @@ -114,6 +119,7 @@ impl Display for UnversionizeError {
"Expected array of size {expected_size}, found array of size {found_size}"
)
}
Self::DeprecatedVersion(deprecation_error) => deprecation_error.fmt(f),
}
}
}
Expand All @@ -124,6 +130,7 @@ impl Error for UnversionizeError {
UnversionizeError::Upgrade { source, .. } => Some(source.as_ref()),
UnversionizeError::Conversion { source, .. } => Some(source.as_ref()),
UnversionizeError::ArrayLength { .. } => None,
UnversionizeError::DeprecatedVersion(_) => None,
}
}
}
Expand Down

0 comments on commit b92a3cd

Please sign in to comment.