diff --git a/utils/tfhe-versionable/src/lib.rs b/utils/tfhe-versionable/src/lib.rs index 8975c88936..00d5327625 100644 --- a/utils/tfhe-versionable/src/lib.rs +++ b/utils/tfhe-versionable/src/lib.rs @@ -222,8 +222,6 @@ macro_rules! impl_scalar_versionize { } impl NotVersioned for $t {} - - impl NotVersioned for Vec<$t> {} }; } @@ -315,7 +313,35 @@ impl Unversionize for Box<[T]> { } } -impl NotVersioned for Box<[T]> {} +impl VersionizeVec for Box<[T]> { + type VersionedVec = Vec; + + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec.into_iter() + .map(|inner| inner.versionize_owned()) + .collect() + } +} + +impl VersionizeSlice for Box<[T]> { + type VersionedSlice<'vers> = Vec> where T: 'vers; + + fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { + slice + .iter() + .map(|inner| T::versionize_slice(inner)) + .collect() + } +} + +impl UnversionizeVec for Box<[T]> { + fn unversionize_vec(versioned: Self::VersionedVec) -> Result, UnversionizeError> { + versioned + .into_iter() + .map(Box::<[T]>::unversionize) + .collect() + } +} impl Versionize for Vec { type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers; @@ -333,6 +359,42 @@ impl VersionizeOwned for Vec { } } +impl Unversionize for Vec { + fn unversionize(versioned: Self::VersionedOwned) -> Result { + T::unversionize_vec(versioned) + } +} + +impl VersionizeVec for Vec { + type VersionedVec = Vec; + + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec.into_iter() + .map(|inner| T::versionize_vec(inner)) + .collect() + } +} + +impl VersionizeSlice for Vec { + type VersionedSlice<'vers> = Vec> where T: 'vers; + + fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { + slice + .iter() + .map(|inner| T::versionize_slice(inner)) + .collect() + } +} + +impl UnversionizeVec for Vec { + fn unversionize_vec(versioned: Self::VersionedVec) -> Result, UnversionizeError> { + versioned + .into_iter() + .map(|inner| T::unversionize_vec(inner)) + .collect() + } +} + impl Versionize for [T] { type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers; @@ -349,9 +411,24 @@ impl VersionizeOwned for &[T] { } } -impl Unversionize for Vec { - fn unversionize(versioned: Self::VersionedOwned) -> Result { - T::unversionize_vec(versioned) +impl VersionizeVec for &[T] { + type VersionedVec = Vec; + + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec.into_iter() + .map(|inner| T::versionize_vec(inner.to_vec())) + .collect() + } +} + +impl<'a, T: VersionizeSlice> VersionizeSlice for &'a [T] { + type VersionedSlice<'vers> = Vec> where T: 'vers, 'a: 'vers; + + fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { + slice + .iter() + .map(|inner| T::versionize_slice(inner)) + .collect() } } @@ -386,6 +463,33 @@ impl Unversionize for [T; N] { } } +impl VersionizeVec for [T; N] { + type VersionedVec = Vec; + + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec.into_iter() + .map(|inner| inner.versionize_owned()) + .collect() + } +} + +impl VersionizeSlice for [T; N] { + type VersionedSlice<'vers> = Vec> where T: 'vers; + + fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { + slice + .iter() + .map(|inner| T::versionize_slice(inner)) + .collect() + } +} + +impl UnversionizeVec for [T; N] { + fn unversionize_vec(versioned: Self::VersionedVec) -> Result, UnversionizeError> { + versioned.into_iter().map(<[T; N]>::unversionize).collect() + } +} + impl Versionize for String { type Versioned<'vers> = &'vers str; diff --git a/utils/tfhe-versionable/tests/types.rs b/utils/tfhe-versionable/tests/types.rs index 57760e6b7e..42ab36e00d 100644 --- a/utils/tfhe-versionable/tests/types.rs +++ b/utils/tfhe-versionable/tests/types.rs @@ -10,7 +10,11 @@ use aligned_vec::{ABox, AVec}; use num_complex::Complex; use tfhe_versionable::{Unversionize, Versionize}; -use backward_compat::MyStructVersions; +use backward_compat::{CustomVersions, MyStructVersions}; + +#[derive(PartialEq, Clone, Debug, Versionize)] +#[versionize(CustomVersions)] +struct Custom(u32); #[derive(PartialEq, Clone, Debug, Versionize)] #[versionize(MyStructVersions)] @@ -19,6 +23,9 @@ pub struct MyStruct { base_box: Box, sliced_box: Box<[u16; 50]>, base_vec: Vec, + base_vec_vec: Vec>, + custom_vec_vec: Vec>, + custom_vec_vec_vec: Vec>>, s: String, opt: Option, phantom: PhantomData, @@ -35,6 +42,8 @@ pub struct MyStruct { mod backward_compat { use tfhe_versionable::VersionsDispatch; + use crate::Custom; + use super::MyStruct; #[derive(VersionsDispatch)] @@ -42,6 +51,12 @@ mod backward_compat { pub enum MyStructVersions { V0(MyStruct), } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub enum CustomVersions { + V0(Custom), + } } #[test] @@ -51,6 +66,28 @@ fn test_types() { base_box: Box::new(42), sliced_box: vec![11; 50].into_boxed_slice().try_into().unwrap(), base_vec: vec![1234, 5678], + base_vec_vec: vec![vec![1234, 5678], vec![9012, 3456]], + custom_vec_vec: vec![ + vec![9876, 5432, 1987, 6543] + .into_iter() + .map(Custom) + .collect(), + vec![1098, 7654, 3210, 9876] + .into_iter() + .map(Custom) + .collect(), + ], + custom_vec_vec_vec: vec![ + vec![ + vec![9876, 5432].into_iter().map(Custom).collect(), + vec![1987, 6543].into_iter().map(Custom).collect(), + ], + vec![ + vec![1098, 7654].into_iter().map(Custom).collect(), + vec![3210, 9876].into_iter().map(Custom).collect(), + ], + ], + s: String::from("test"), opt: Some(0xdeadbeef), phantom: PhantomData,