From 8cfa7b7eea539762c1ae5bdb6656a8f2270eb92c Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Wed, 2 Oct 2024 15:30:11 +0200 Subject: [PATCH] fix(versionable): compatibility between "convert" and generics --- tfhe-zk-pok/src/curve_api/bls12_381.rs | 40 +- tfhe-zk-pok/src/curve_api/bls12_446.rs | 41 +- .../tfhe-versionable-derive/src/associated.rs | 11 +- utils/tfhe-versionable-derive/src/lib.rs | 136 ++--- .../src/versionize_attribute.rs | 474 ++++++++++++++++-- utils/tfhe-versionable/examples/convert.rs | 33 +- .../tests/convert_with_bounds.rs | 51 ++ .../tests/convert_with_generics.rs | 58 +++ 8 files changed, 642 insertions(+), 202 deletions(-) create mode 100644 utils/tfhe-versionable/tests/convert_with_bounds.rs create mode 100644 utils/tfhe-versionable/tests/convert_with_generics.rs diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs index ccd8f1da18..9a5c979bb1 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_381.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -36,18 +36,13 @@ fn bigint_to_le_bytes(x: [u64; 6]) -> [u8; 6 * 8] { mod g1 { use tfhe_versionable::Versionize; - use crate::backward_compatibility::SerializableG1AffineVersions; use crate::serialization::{InvalidSerializedAffineError, SerializableG1Affine}; use super::*; #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] - #[versionize( - SerializableG1AffineVersions, - try_from = "SerializableG1Affine", - into = "SerializableG1Affine" - )] + #[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[repr(transparent)] pub struct G1Affine { pub(crate) inner: ark_bls12_381::g1::G1Affine, @@ -99,11 +94,7 @@ mod g1 { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] - #[versionize( - SerializableG1AffineVersions, - try_from = "SerializableG1Affine", - into = "SerializableG1Affine" - )] + #[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[repr(transparent)] pub struct G1 { pub(crate) inner: ark_bls12_381::G1Projective, @@ -264,18 +255,13 @@ mod g1 { mod g2 { use tfhe_versionable::Versionize; - use crate::backward_compatibility::SerializableG2AffineVersions; use crate::serialization::{InvalidSerializedAffineError, SerializableG2Affine}; use super::*; #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] - #[versionize( - SerializableG2AffineVersions, - try_from = "SerializableG2Affine", - into = "SerializableG2Affine" - )] + #[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[repr(transparent)] pub struct G2Affine { pub(crate) inner: ark_bls12_381::g2::G2Affine, @@ -328,11 +314,7 @@ mod g2 { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] - #[versionize( - SerializableG2AffineVersions, - try_from = "SerializableG2Affine", - into = "SerializableG2Affine" - )] + #[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[repr(transparent)] pub struct G2 { pub(crate) inner: ark_bls12_381::G2Projective, @@ -539,7 +521,6 @@ mod g2 { } mod gt { - use crate::backward_compatibility::SerializableFp12Versions; use crate::serialization::InvalidArraySizeError; use super::*; @@ -548,11 +529,7 @@ mod gt { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash)] #[serde(try_from = "SerializableFp12", into = "SerializableFp12")] - #[versionize( - SerializableFp12Versions, - try_from = "SerializableFp12", - into = "SerializableFp12" - )] + #[versionize(try_from = "SerializableFp12", into = "SerializableFp12")] #[repr(transparent)] pub struct Gt { inner: ark_ec::pairing::PairingOutput, @@ -697,7 +674,6 @@ mod gt { } mod zp { - use crate::backward_compatibility::SerializableFpVersions; use crate::serialization::InvalidArraySizeError; use super::*; @@ -741,11 +717,7 @@ mod zp { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)] #[serde(try_from = "SerializableFp", into = "SerializableFp")] - #[versionize( - SerializableFpVersions, - try_from = "SerializableFp", - into = "SerializableFp" - )] + #[versionize(try_from = "SerializableFp", into = "SerializableFp")] #[repr(transparent)] pub struct Zp { pub(crate) inner: ark_bls12_381::Fr, diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index 342ef68f09..53ea960ef0 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -36,18 +36,13 @@ fn bigint_to_le_bytes(x: [u64; 7]) -> [u8; 7 * 8] { mod g1 { use tfhe_versionable::Versionize; - use crate::backward_compatibility::SerializableG1AffineVersions; use crate::serialization::{InvalidSerializedAffineError, SerializableG1Affine}; use super::*; #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] - #[versionize( - SerializableG1AffineVersions, - try_from = "SerializableG1Affine", - into = "SerializableG1Affine" - )] + #[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[repr(transparent)] pub struct G1Affine { pub(crate) inner: crate::curve_446::g1::G1Affine, @@ -101,11 +96,7 @@ mod g1 { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] - #[versionize( - SerializableG1AffineVersions, - try_from = "SerializableG1Affine", - into = "SerializableG1Affine" - )] + #[versionize(try_from = "SerializableG1Affine", into = "SerializableG1Affine")] #[repr(transparent)] pub struct G1 { pub(crate) inner: crate::curve_446::g1::G1Projective, @@ -267,7 +258,6 @@ mod g1 { mod g2 { use tfhe_versionable::Versionize; - use crate::backward_compatibility::SerializableG2AffineVersions; use crate::serialization::SerializableG2Affine; use super::*; @@ -275,11 +265,7 @@ mod g2 { #[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] - #[versionize( - SerializableG2AffineVersions, - try_from = "SerializableG2Affine", - into = "SerializableG2Affine" - )] + #[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[repr(transparent)] pub struct G2Affine { pub(crate) inner: crate::curve_446::g2::G2Affine, @@ -423,11 +409,7 @@ mod g2 { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)] #[serde(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] - #[versionize( - SerializableG2AffineVersions, - try_from = "SerializableG2Affine", - into = "SerializableG2Affine" - )] + #[versionize(try_from = "SerializableG2Affine", into = "SerializableG2Affine")] #[repr(transparent)] pub struct G2 { pub(crate) inner: crate::curve_446::g2::G2Projective, @@ -633,7 +615,6 @@ mod g2 { } mod gt { - use crate::backward_compatibility::SerializableFp12Versions; use crate::curve_446::{Fq, Fq12, Fq2}; use crate::serialization::InvalidSerializedAffineError; @@ -812,11 +793,7 @@ mod gt { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash)] #[serde(try_from = "SerializableFp12", into = "SerializableFp12")] - #[versionize( - SerializableFp12Versions, - try_from = "SerializableFp12", - into = "SerializableFp12" - )] + #[versionize(try_from = "SerializableFp12", into = "SerializableFp12")] #[repr(transparent)] pub struct Gt { pub(crate) inner: ark_ec::pairing::PairingOutput, @@ -959,8 +936,6 @@ mod gt { } mod zp { - use crate::backward_compatibility::SerializableFpVersions; - use super::*; use crate::serialization::InvalidArraySizeError; use ark_ff::Fp; @@ -1003,11 +978,7 @@ mod zp { #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)] #[serde(try_from = "SerializableFp", into = "SerializableFp")] - #[versionize( - SerializableFpVersions, - try_from = "SerializableFp", - into = "SerializableFp" - )] + #[versionize(try_from = "SerializableFp", into = "SerializableFp")] #[repr(transparent)] pub struct Zp { pub inner: crate::curve_446::Fr, diff --git a/utils/tfhe-versionable-derive/src/associated.rs b/utils/tfhe-versionable-derive/src/associated.rs index 8931e44cb2..026d5b200e 100644 --- a/utils/tfhe-versionable-derive/src/associated.rs +++ b/utils/tfhe-versionable-derive/src/associated.rs @@ -6,8 +6,9 @@ use syn::{ }; use crate::{ - add_lifetime_bound, add_trait_where_clause, add_where_lifetime_bound, extend_where_clause, - parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME, SERIALIZE_TRAIT_NAME, + add_lifetime_param, add_trait_where_clause, add_where_lifetime_bound_to_generics, + extend_where_clause, parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME, + SERIALIZE_TRAIT_NAME, }; /// Generates an impl block for the From trait. This will be: @@ -116,7 +117,7 @@ pub(crate) trait AssociatedType: Sized { let mut generics = self.orig_type_generics().clone(); if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind() { if let Some(lifetime) = opt_lifetime { - add_lifetime_bound(&mut generics, lifetime); + add_lifetime_param(&mut generics, lifetime); } add_trait_where_clause(&mut generics, self.inner_types()?, Self::REF_BOUNDS)?; } else { @@ -214,8 +215,8 @@ impl AssociatingTrait { let mut ref_type_generics = self.ref_type.orig_type_generics().clone(); // If the original type has some generics, we need to add a lifetime bound on them if let Some(lifetime) = self.ref_type.lifetime() { - add_lifetime_bound(&mut ref_type_generics, lifetime); - add_where_lifetime_bound(&mut ref_type_generics, lifetime); + add_lifetime_param(&mut ref_type_generics, lifetime); + add_where_lifetime_bound_to_generics(&mut ref_type_generics, lifetime); } let (impl_generics, orig_generics, where_clause) = generics.split_for_impl(); diff --git a/utils/tfhe-versionable-derive/src/lib.rs b/utils/tfhe-versionable-derive/src/lib.rs index f40374daf0..a64340c7dd 100644 --- a/utils/tfhe-versionable-derive/src/lib.rs +++ b/utils/tfhe-versionable-derive/src/lib.rs @@ -42,6 +42,13 @@ pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeE pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize"; pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize"; +pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From"; +pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto"; +pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into"; +pub(crate) const ERROR_TRAIT_NAME: &str = "::core::error::Error"; +pub(crate) const SYNC_TRAIT_NAME: &str = "::core::marker::Sync"; +pub(crate) const SEND_TRAIT_NAME: &str = "::core::marker::Send"; +pub(crate) const STATIC_LIFETIME_NAME: &str = "'static"; use associated::AssociatingTrait; @@ -140,47 +147,7 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { Some(impl_version_trait(&input)) }; - let dispatch_enum_path = attributes.dispatch_enum(); - let dispatch_target = attributes.dispatch_target(); - let input_ident = &input.ident; - let mut ref_generics = input.generics.clone(); - let mut trait_generics = input.generics.clone(); - let (_, ty_generics, owned_where_clause) = input.generics.split_for_impl(); - - // If the original type has some generics, we need to add bounds on them for - // the impl - let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site()); - add_where_lifetime_bound(&mut ref_generics, &lifetime); - - // The versionize method takes a ref. We need to own the input type in the conversion case - // to apply `From for Target`. This adds a `Clone` bound to have a better error message - // if the input type is not Clone. - if attributes.needs_conversion() { - syn_unwrap!(add_trait_where_clause( - &mut trait_generics, - [&parse_quote! { Self }], - &["Clone"] - )); - }; - - let dispatch_generics = if attributes.needs_conversion() { - None - } else { - Some(&ty_generics) - }; - - let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME); - - syn_unwrap!(add_trait_where_clause( - &mut trait_generics, - [&parse_quote!(#dispatch_enum_path #dispatch_generics)], - &[format!( - "{}<{}>", - DISPATCH_TRAIT_NAME, - dispatch_target.to_token_stream() - )] - )); - + // Parse the name of the traits that we will implement let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME); let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); let unversionize_trait: Path = parse_const_str(UNVERSIONIZE_TRAIT_NAME); @@ -188,19 +155,33 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME); let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME); + let input_ident = &input.ident; + let lifetime = Lifetime::new(LIFETIME_NAME, Span::call_site()); + // split generics so they can be used inside the generated code - let (_, _, ref_where_clause) = ref_generics.split_for_impl(); - let (trait_impl_generics, _, trait_where_clause) = trait_generics.split_for_impl(); + let (_, ty_generics, _) = input.generics.split_for_impl(); - // If we want to apply a conversion before the call to versionize we need to use the "owned" - // alternative of the dispatch enum to be able to store the conversion result. - let versioned_type_kind = if attributes.needs_conversion() { - quote! { Owned #owned_where_clause } - } else { - quote! { Ref<#lifetime> #ref_where_clause } - }; + // Generates the associated types required by the traits + let versioned_type = attributes.versioned_type(&lifetime, &input.generics); + let versioned_owned_type = attributes.versioned_owned_type(&input.generics); + let versioned_type_where_clause = + attributes.versioned_type_where_clause(&lifetime, &input.generics); + let versioned_owned_type_where_clause = + attributes.versioned_owned_type_where_clause(&input.generics); + + // If the original type has some generics, we need to add bounds on them for + // the traits impl + let versionize_trait_where_clause = + syn_unwrap!(attributes.versionize_trait_where_clause(&input.generics)); + let versionize_owned_trait_where_clause = + syn_unwrap!(attributes.versionize_owned_trait_where_clause(&input.generics)); + let unversionize_trait_where_clause = + syn_unwrap!(attributes.unversionize_trait_where_clause(&input.generics)); + + let trait_impl_generics = input.generics.split_for_impl().0; let versionize_body = attributes.versionize_method_body(); + let versionize_owned_body = attributes.versionize_owned_method_body(); let unversionize_arg_name = Ident::new("versioned", Span::call_site()); let unversionize_body = attributes.unversionize_method_body(&unversionize_arg_name); let unversionize_error: Path = parse_const_str(UNVERSIONIZE_ERROR_NAME); @@ -210,11 +191,9 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { #[automatically_derived] impl #trait_impl_generics #versionize_trait for #input_ident #ty_generics - #trait_where_clause + #versionize_trait_where_clause { - type Versioned<#lifetime> = - <#dispatch_enum_path #dispatch_generics as - #dispatch_trait<#dispatch_target>>::#versioned_type_kind; + type Versioned<#lifetime> = #versioned_type #versioned_type_where_clause; fn versionize(&self) -> Self::Versioned<'_> { #versionize_body @@ -223,20 +202,18 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { #[automatically_derived] impl #trait_impl_generics #versionize_owned_trait for #input_ident #ty_generics - #trait_where_clause + #versionize_owned_trait_where_clause { - type VersionedOwned = - <#dispatch_enum_path #dispatch_generics as - #dispatch_trait<#dispatch_target>>::Owned #owned_where_clause; + type VersionedOwned = #versioned_owned_type #versioned_owned_type_where_clause; fn versionize_owned(self) -> Self::VersionedOwned { - #versionize_body + #versionize_owned_body } } #[automatically_derived] impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics - #trait_where_clause + #unversionize_trait_where_clause { fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> Result { #unversionize_body @@ -245,20 +222,21 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { #[automatically_derived] impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics - #trait_where_clause + #versionize_trait_where_clause { - type VersionedSlice<#lifetime> = Vec<::Versioned<#lifetime>> #ref_where_clause; + type VersionedSlice<#lifetime> = Vec<::Versioned<#lifetime>> #versioned_type_where_clause; fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { slice.iter().map(|val| #versionize_trait::versionize(val)).collect() } } + #[automatically_derived] impl #trait_impl_generics #versionize_vec_trait for #input_ident #ty_generics - #trait_where_clause + #versionize_owned_trait_where_clause { - type VersionedVec = Vec<::VersionedOwned> #owned_where_clause; + type VersionedVec = Vec<::VersionedOwned> #versioned_owned_type_where_clause; fn versionize_vec(vec: Vec) -> Self::VersionedVec { vec.into_iter().map(|val| #versionize_owned_trait::versionize_owned(val)).collect() @@ -267,7 +245,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { #[automatically_derived] impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics - #trait_where_clause { + #unversionize_trait_where_clause + { fn unversionize_vec(versioned: Self::VersionedVec) -> Result, #unversionize_error> { versioned .into_iter() @@ -335,7 +314,7 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream { } /// Adds a where clause with a lifetime bound on all the generic types and lifetimes in `generics` -fn add_where_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) { +fn add_where_lifetime_bound_to_generics(generics: &mut Generics, lifetime: &Lifetime) { let mut params = Vec::new(); for param in generics.params.iter() { let param_ident = match param { @@ -359,8 +338,8 @@ fn add_where_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) { } } -/// Adds a lifetime bound for all the generic types in `generics` -fn add_lifetime_bound(generics: &mut Generics, lifetime: &Lifetime) { +/// Adds a new lifetime param with a bound for all the generic types in `generics` +fn add_lifetime_param(generics: &mut Generics, lifetime: &Lifetime) { generics .params .push(GenericParam::Lifetime(LifetimeParam::new(lifetime.clone()))); @@ -398,6 +377,27 @@ fn add_trait_where_clause<'a, S: AsRef, I: IntoIterator>( Ok(()) } +/// Adds a "where clause" bound for all the input types with all the input lifetimes +fn add_lifetime_where_clause<'a, S: AsRef, I: IntoIterator>( + generics: &mut Generics, + types: I, + lifetimes: &[S], +) -> syn::Result<()> { + let preds = &mut generics.make_where_clause().predicates; + + if !lifetimes.is_empty() { + let bounds: Vec = lifetimes + .iter() + .map(|lifetime| syn::parse_str(lifetime.as_ref())) + .collect::>()?; + for ty in types { + preds.push(parse_quote! { #ty: #(#bounds)+* }); + } + } + + Ok(()) +} + /// Extends a where clause with predicates from another one, filtering duplicates fn extend_where_clause(base_clause: &mut WhereClause, extension_clause: &WhereClause) { for extend_predicate in &extension_clause.predicates { diff --git a/utils/tfhe-versionable-derive/src/versionize_attribute.rs b/utils/tfhe-versionable-derive/src/versionize_attribute.rs index 0a0334ddb7..a950dc6cef 100644 --- a/utils/tfhe-versionable-derive/src/versionize_attribute.rs +++ b/utils/tfhe-versionable-derive/src/versionize_attribute.rs @@ -2,47 +2,146 @@ use proc_macro2::Span; use quote::{quote, ToTokens}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{Attribute, Expr, Ident, Lit, Meta, Path, Token, TraitBound, Type}; +use syn::{ + parse_quote, Attribute, Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, Lit, + Meta, Path, PathArguments, Token, TraitBound, Type, TypeParam, WhereClause, +}; -use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME, VERSIONIZE_OWNED_TRAIT_NAME}; +use crate::{ + add_lifetime_where_clause, add_trait_where_clause, add_where_lifetime_bound_to_generics, + parse_const_str, DISPATCH_TRAIT_NAME, ERROR_TRAIT_NAME, FROM_TRAIT_NAME, INTO_TRAIT_NAME, + SEND_TRAIT_NAME, STATIC_LIFETIME_NAME, SYNC_TRAIT_NAME, TRY_INTO_TRAIT_NAME, + UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, +}; /// Name of the attribute used to give arguments to the `Versionize` macro const VERSIONIZE_ATTR_NAME: &str = "versionize"; -pub(crate) struct VersionizeAttribute { +pub(crate) struct ClassicVersionizeAttribute { dispatch_enum: Path, - from: Option, - try_from: Option, - into: Option, +} + +pub(crate) enum ConversionType { + Direct, + Try, +} + +pub(crate) struct ConvertVersionizeAttribute { + conversion_target: Path, + conversion_type: ConversionType, +} + +pub(crate) enum VersionizeAttribute { + Classic(ClassicVersionizeAttribute), + Convert(ConvertVersionizeAttribute), } #[derive(Default)] struct VersionizeAttributeBuilder { dispatch_enum: Option, + convert: Option, + try_convert: Option, from: Option, try_from: Option, into: Option, } impl VersionizeAttributeBuilder { - fn build(self) -> Option { - // These attributes are mutually exclusive - if self.from.is_some() && self.try_from.is_some() { - return None; + fn build(self, base_span: &Span) -> syn::Result { + let convert_is_try = self.try_convert.is_some() || self.try_from.is_some(); + // User should not use `from` and `try_from` at the same time + let from_target = match (self.from, self.try_from) { + (None, None) => None, + (Some(_), Some(try_from)) => { + return Err(syn::Error::new( + try_from.span(), + "'try_from' and 'from' attributes are mutually exclusive", + )) + } + (None, Some(try_from)) => Some(try_from), + (Some(from), None) => Some(from), + }; + + // Same with `convert`/`try_convert` + let convert_target = match (self.convert, self.try_convert) { + (None, None) => None, + (Some(_), Some(try_convert)) => { + return Err(syn::Error::new( + try_convert.span(), + "'try_convert' and 'convert' attributes are mutually exclusive", + )) + } + (None, Some(try_convert)) => Some(try_convert), + (Some(convert), None) => Some(convert), + }; + + // from/into are here for similarity with serde, but we don't actually support having + // different target inside. So we check this to warn the user + let from_target = + match (from_target, self.into) { + (None, None) => None, + (None, Some(into)) => return Err(syn::Error::new( + into.span(), + "unidirectional conversions are not handled, please add a 'from'/'try_from' \ +attribute or use the 'convert'/'try_convert' attribute instead", + )), + (Some(from), None) => return Err(syn::Error::new( + from.span(), + "unidirectional conversions are not handled, please add a 'into' attribute or \ +use the 'convert'/'try_convert' attribute instead", + )), + (Some(from), Some(into)) => { + if format!("{}", from.to_token_stream()) + != format!("{}", into.to_token_stream()) + { + return Err(syn::Error::new( + from.span(), + "unidirectional conversions are not handled, 'from' and 'into' parameters \ +should have the same value", + )); + } else { + Some(from) + } + } + }; + + // Finally, checks that the user doesn't use both from/into and convert + let conversion_target = match (from_target, convert_target) { + (None, None) => None, + (Some(_), Some(convert)) => { + return Err(syn::Error::new( + convert.span(), + "'convert' and 'from'/'into' attributes are mutually exclusive", + )) + } + (None, Some(convert)) => Some(convert), + (Some(from), None) => Some(from), + }; + + if let Some(conversion_target) = conversion_target { + Ok(VersionizeAttribute::Convert(ConvertVersionizeAttribute { + conversion_target, + conversion_type: if convert_is_try { + ConversionType::Try + } else { + ConversionType::Direct + }, + })) + } else { + Ok(VersionizeAttribute::Classic(ClassicVersionizeAttribute { + dispatch_enum: self.dispatch_enum.ok_or(syn::Error::new( + *base_span, + "Missing dispatch enum argument", + ))?, + })) } - Some(VersionizeAttribute { - dispatch_enum: self.dispatch_enum?, - from: self.from, - try_from: self.try_from, - into: self.into, - }) } } impl VersionizeAttribute { /// Find and parse an attribute with the form `#[versionize(DispatchType)]`, where /// `DispatchType` is the name of the type holding the dispatch enum. - /// Returns an error if no `versionize` attribute has been found, if multiple attributes are + /// Return an error if no `versionize` attribute has been found, if multiple attributes are /// present on the same struct or if the attribute is malformed. pub(crate) fn parse_from_attributes_list( attributes: &[Attribute], @@ -82,8 +181,24 @@ impl VersionizeAttribute { } } Meta::NameValue(name_value) => { + // parse versionize(convert = "TypeConvert") + if name_value.path.is_ident("convert") { + if attribute_builder.convert.is_some() { + return Err(Self::default_error(meta.span())); + } else { + attribute_builder.convert = + Some(parse_path_ignore_quotes(&name_value.value)?); + } + // parse versionize(try_convert = "TypeTryConvert") + } else if name_value.path.is_ident("try_convert") { + if attribute_builder.try_convert.is_some() { + return Err(Self::default_error(meta.span())); + } else { + attribute_builder.try_convert = + Some(parse_path_ignore_quotes(&name_value.value)?); + } // parse versionize(from = "TypeFrom") - if name_value.path.is_ident("from") { + } else if name_value.path.is_ident("from") { if attribute_builder.from.is_some() { return Err(Self::default_error(meta.span())); } else { @@ -122,60 +237,289 @@ impl VersionizeAttribute { } } - attribute_builder - .build() - .ok_or_else(|| Self::default_error(attribute.span())) + attribute_builder.build(&attribute.span()) } - pub(crate) fn dispatch_enum(&self) -> &Path { - &self.dispatch_enum + pub(crate) fn needs_conversion(&self) -> bool { + match self { + VersionizeAttribute::Classic(_) => false, + VersionizeAttribute::Convert(_) => true, + } } - pub(crate) fn needs_conversion(&self) -> bool { - self.try_from.is_some() || self.from.is_some() + /// Return the associated type used in the `Versionize` trait: `MyType::Versioned<'vers>` + /// + /// If the type is directly versioned, this will be a type generated by the `VersionDispatch`. + /// + /// If we have a conversion before the versioning, we re-use the versioned_owned type of the + /// conversion target. The versioned_owned is needed because the conversion will create a new + /// value, so we can't just use a reference. + pub(crate) fn versioned_type( + &self, + lifetime: &Lifetime, + input_generics: &Generics, + ) -> proc_macro2::TokenStream { + match self { + VersionizeAttribute::Classic(attr) => { + let (_, ty_generics, _) = input_generics.split_for_impl(); + + let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME); + let dispatch_enum_path = &attr.dispatch_enum; + quote! { + <#dispatch_enum_path #ty_generics as + #dispatch_trait>::Ref<#lifetime> + } + } + VersionizeAttribute::Convert(_) => { + // If we want to apply a conversion before the call to versionize we need to use the + // "owned" alternative of the dispatch enum to be able to store the + // conversion result. + self.versioned_owned_type(input_generics) + } + } + } + + /// Return the where clause for `MyType::Versioned<'vers>`. if `MyType` has generics, this means + /// adding a 'vers lifetime bound on them. + pub(crate) fn versioned_type_where_clause( + &self, + lifetime: &Lifetime, + input_generics: &Generics, + ) -> Option { + let mut generics = input_generics.clone(); + + add_where_lifetime_bound_to_generics(&mut generics, lifetime); + let (_, _, where_clause) = generics.split_for_impl(); + where_clause.cloned() + } + + /// Return the associated type used in the `VersionizeOwned` trait: `MyType::VersionedOwned` + /// + /// If the type is directly versioned, this will be a type generated by the `VersionDispatch`. + /// + /// If we have a conversion before the versioning, we re-use the versioned_owned type of the + /// conversion target. + pub(crate) fn versioned_owned_type( + &self, + input_generics: &Generics, + ) -> proc_macro2::TokenStream { + let (_, ty_generics, _) = input_generics.split_for_impl(); + match self { + VersionizeAttribute::Classic(attr) => { + let dispatch_trait: Path = parse_const_str(DISPATCH_TRAIT_NAME); + let dispatch_enum_path = &attr.dispatch_enum; + quote! { + <#dispatch_enum_path #ty_generics as + #dispatch_trait>::Owned + } + } + VersionizeAttribute::Convert(convert_attr) => { + let convert_type_path = &convert_attr.conversion_target; + let versionize_owned_trait: Path = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); + + quote! { + <#convert_type_path as #versionize_owned_trait>::VersionedOwned + } + } + } + } + + /// Return the where clause for `MyType::VersionedOwned`. + /// + /// This is simply the where clause of the input type. + pub(crate) fn versioned_owned_type_where_clause( + &self, + input_generics: &Generics, + ) -> Option { + match self { + VersionizeAttribute::Classic(_) => input_generics.split_for_impl().2.cloned(), + VersionizeAttribute::Convert(convert_attr) => { + extract_generics(&convert_attr.conversion_target) + .split_for_impl() + .2 + .cloned() + } + } } - pub(crate) fn dispatch_target(&self) -> Path { - self.from - .as_ref() - .or(self.try_from.as_ref()) - .map(|target| target.to_owned()) - .unwrap_or_else(|| { - syn::parse_str("Self").expect("Parsing of const value should never fail") - }) + /// Return the where clause needed to implement the Versionize trait. + /// + /// This is the same as the one for the VersionizeOwned, with an additional "Clone" bound in the + /// case where we need to perform a conversion before the versioning. + pub(crate) fn versionize_trait_where_clause( + &self, + input_generics: &Generics, + ) -> syn::Result> { + // The base bounds for the owned traits are also used for the ref traits + let mut generics = input_generics.clone(); + if self.needs_conversion() { + // The versionize method takes a ref. We need to own the input type in the conversion + // case to apply `From for Target`. This adds a `Clone` bound to have + // a better error message if the input type is not Clone. + add_trait_where_clause(&mut generics, [&parse_quote! { Self }], &["Clone"])?; + } + + self.versionize_owned_trait_where_clause(&generics) + } + + /// Return the where clause needed to implement the VersionizeOwned trait. + /// + /// If the type is directly versioned, the bound states that the argument points to a valid + /// DispatchEnum for this type. This is done by adding a bound on this argument to + /// `VersionsDisaptch`. + /// + /// If there is a conversion, the target of the conversion should implement `VersionizeOwned` + /// and `From`. + pub(crate) fn versionize_owned_trait_where_clause( + &self, + input_generics: &Generics, + ) -> syn::Result> { + let mut generics = input_generics.clone(); + match self { + VersionizeAttribute::Classic(attr) => { + let dispatch_generics = generics.clone(); + let dispatch_ty_generics = dispatch_generics.split_for_impl().1; + let dispatch_enum_path = &attr.dispatch_enum; + + add_trait_where_clause( + &mut generics, + [&parse_quote!(#dispatch_enum_path #dispatch_ty_generics)], + &[format!("{}", DISPATCH_TRAIT_NAME,)], + )?; + } + VersionizeAttribute::Convert(convert_attr) => { + let convert_type_path = &convert_attr.conversion_target; + add_trait_where_clause( + &mut generics, + [&parse_quote!(#convert_type_path)], + &[ + VERSIONIZE_OWNED_TRAIT_NAME, + &format!("{}", FROM_TRAIT_NAME), + ], + )?; + } + } + + Ok(generics.split_for_impl().2.cloned()) } + /// Return the where clause for the `Unversionize` trait. + /// + /// If the versioning is direct, this is the same bound as the one used for `VersionizeOwned`. + /// + /// If there is a conversion, the target of the conversion need to implement `Unversionize` and + /// `Into` or `TryInto`, with `E: Error + Send + Sync + 'static` + pub(crate) fn unversionize_trait_where_clause( + &self, + input_generics: &Generics, + ) -> syn::Result> { + match self { + VersionizeAttribute::Classic(_) => { + self.versionize_owned_trait_where_clause(input_generics) + } + VersionizeAttribute::Convert(convert_attr) => { + let mut generics = input_generics.clone(); + let convert_type_path = &convert_attr.conversion_target; + let into_trait = match convert_attr.conversion_type { + ConversionType::Direct => format!("{}", INTO_TRAIT_NAME), + ConversionType::Try => { + // Doing a TryFrom requires that the error + // impl Error + Send + Sync + 'static + let try_into_trait: Path = parse_const_str(TRY_INTO_TRAIT_NAME); + add_trait_where_clause( + &mut generics, + [&parse_quote!(<#convert_type_path as #try_into_trait>::Error)], + &[ERROR_TRAIT_NAME, SYNC_TRAIT_NAME, SEND_TRAIT_NAME], + )?; + add_lifetime_where_clause( + &mut generics, + [&parse_quote!(<#convert_type_path as #try_into_trait>::Error)], + &[STATIC_LIFETIME_NAME], + )?; + + format!("{}", TRY_INTO_TRAIT_NAME) + } + }; + add_trait_where_clause( + &mut generics, + [&parse_quote!(#convert_type_path)], + &[ + UNVERSIONIZE_TRAIT_NAME, + &format!("{}", FROM_TRAIT_NAME), + &into_trait, + ], + )?; + + Ok(generics.split_for_impl().2.cloned()) + } + } + } + + /// Return the body of the versionize method. pub(crate) fn versionize_method_body(&self) -> proc_macro2::TokenStream { let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); - self.into - .as_ref() - .map(|target| { + + match self { + VersionizeAttribute::Classic(_) => { quote! { - #versionize_owned_trait::versionize_owned(#target::from(self.to_owned())) + self.into() } - }) - .unwrap_or_else(|| { + } + VersionizeAttribute::Convert(convert_attr) => { + let convert_type_path = with_turbofish(&convert_attr.conversion_target); + quote! { + #versionize_owned_trait::versionize_owned(#convert_type_path::from(self.to_owned())) + } + } + } + } + + /// Return the body of the versionize_owned method. + pub(crate) fn versionize_owned_method_body(&self) -> proc_macro2::TokenStream { + let versionize_owned_trait: TraitBound = parse_const_str(VERSIONIZE_OWNED_TRAIT_NAME); + + match self { + VersionizeAttribute::Classic(_) => { quote! { self.into() } - }) + } + VersionizeAttribute::Convert(convert_attr) => { + let convert_type_path = with_turbofish(&convert_attr.conversion_target); + quote! { + #versionize_owned_trait::versionize_owned(#convert_type_path::from(self)) + } + } + } } + /// Return the body of the unversionize method. pub(crate) fn unversionize_method_body(&self, arg_name: &Ident) -> proc_macro2::TokenStream { let error: Type = parse_const_str(UNVERSIONIZE_ERROR_NAME); - if let Some(target) = &self.from { - quote! { #target::unversionize(#arg_name).map(|value| value.into()) } - } else if let Some(target) = &self.try_from { - let target_name = format!("{}", target.to_token_stream()); - quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::::try_into(value) - .map_err(|e| #error::conversion(#target_name, e))) + match self { + VersionizeAttribute::Classic(_) => { + quote! { #arg_name.try_into() } + } + VersionizeAttribute::Convert(convert_attr) => { + let target = with_turbofish(&convert_attr.conversion_target); + match convert_attr.conversion_type { + ConversionType::Direct => { + quote! { #target::unversionize(#arg_name).map(|value| value.into()) } + } + ConversionType::Try => { + let target_name = format!("{}", target.to_token_stream()); + quote! { #target::unversionize(#arg_name).and_then(|value| TryInto::::try_into(value) + .map_err(|e| #error::conversion(#target_name, e))) + } + } + } } - } else { - quote! { #arg_name.try_into() } } } } +/// Allow the user to give type arguments as `#[versionize(MyType)]` as well as +/// `#[versionize("MyType")]` fn parse_path_ignore_quotes(value: &Expr) -> syn::Result { match &value { Expr::Path(expr_path) => Ok(expr_path.path.clone()), @@ -192,3 +536,37 @@ fn parse_path_ignore_quotes(value: &Expr) -> syn::Result { )), } } + +/// Return the same type but with generics that use the turbofish syntax. Converts +/// `MyStruct` into `MyStruct::` +fn with_turbofish(path: &Path) -> Path { + let mut with_turbo = path.clone(); + + for segment in with_turbo.segments.iter_mut() { + if let PathArguments::AngleBracketed(generics) = &mut segment.arguments { + generics.colon2_token = Some(Token![::](generics.span())); + } + } + + with_turbo +} + +/// Extract the generics inside a type +fn extract_generics(path: &Path) -> Generics { + let mut generics = Generics::default(); + + if let Some(last_segment) = path.segments.last() { + if let PathArguments::AngleBracketed(args) = &last_segment.arguments { + for arg in &args.args { + if let GenericArgument::Type(Type::Path(type_path)) = arg { + if let Some(ident) = type_path.path.get_ident() { + let param = TypeParam::from(ident.clone()); + generics.params.push(GenericParam::Type(param)); + } + } + } + } + } + + generics +} diff --git a/utils/tfhe-versionable/examples/convert.rs b/utils/tfhe-versionable/examples/convert.rs index 230fe02c53..92029c7101 100644 --- a/utils/tfhe-versionable/examples/convert.rs +++ b/utils/tfhe-versionable/examples/convert.rs @@ -3,48 +3,57 @@ use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch}; #[derive(Clone, Versionize)] -#[versionize(SerializableMyStructVersions, from = SerializableMyStruct, into = SerializableMyStruct)] -struct MyStruct { +// To mimic serde parameters, this can also be expressed as +// "#[versionize(from = SerializableMyStruct, into = SerializableMyStruct)]" +#[versionize(convert = "SerializableMyStruct")] +struct MyStruct { val: u64, + generics: T, } #[derive(Versionize)] #[versionize(SerializableMyStructVersions)] -struct SerializableMyStruct { +struct SerializableMyStruct { high: u32, low: u32, + generics: T, } #[derive(VersionsDispatch)] #[allow(unused)] -enum SerializableMyStructVersions { - V0(SerializableMyStruct), +enum SerializableMyStructVersions { + V0(SerializableMyStruct), } -impl From for SerializableMyStruct { - fn from(value: MyStruct) -> Self { - println!("{}", value.val); +impl From> for SerializableMyStruct { + fn from(value: MyStruct) -> Self { Self { high: (value.val >> 32) as u32, low: (value.val & 0xffffffff) as u32, + generics: value.generics, } } } -impl From for MyStruct { - fn from(value: SerializableMyStruct) -> Self { +impl From> for MyStruct { + fn from(value: SerializableMyStruct) -> Self { Self { val: ((value.high as u64) << 32) | (value.low as u64), + generics: value.generics, } } } fn main() { - let stru = MyStruct { val: 37 }; + let stru = MyStruct { + val: 37, + generics: 90, + }; let serialized = bincode::serialize(&stru.versionize()).unwrap(); - let stru_decoded = MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + let stru_decoded: MyStruct = + MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); assert_eq!(stru.val, stru_decoded.val) } diff --git a/utils/tfhe-versionable/tests/convert_with_bounds.rs b/utils/tfhe-versionable/tests/convert_with_bounds.rs new file mode 100644 index 0000000000..7ed3947338 --- /dev/null +++ b/utils/tfhe-versionable/tests/convert_with_bounds.rs @@ -0,0 +1,51 @@ +//! Checks compatibility between the "convert" feature and bounds on the From/Into trait + +use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch}; + +#[derive(Clone, Versionize)] +#[versionize(try_convert = "SerializableMyStruct")] +struct MyStruct { + generics: T, +} + +#[derive(Versionize)] +#[versionize(SerializableMyStructVersions)] +struct SerializableMyStruct { + concrete: u64, +} + +#[derive(VersionsDispatch)] +#[allow(unused)] +enum SerializableMyStructVersions { + V0(SerializableMyStruct), +} + +impl> From> for SerializableMyStruct { + fn from(value: MyStruct) -> Self { + Self { + concrete: value.generics.into(), + } + } +} + +impl> TryFrom for MyStruct { + fn try_from(value: SerializableMyStruct) -> Result { + Ok(Self { + generics: value.concrete.try_into()?, + }) + } + + type Error = T::Error; +} + +#[test] +fn test() { + let stru = MyStruct { generics: 90u32 }; + + let serialized = bincode::serialize(&stru.versionize()).unwrap(); + + let stru_decoded: MyStruct = + MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + + assert_eq!(stru.generics, stru_decoded.generics) +} diff --git a/utils/tfhe-versionable/tests/convert_with_generics.rs b/utils/tfhe-versionable/tests/convert_with_generics.rs new file mode 100644 index 0000000000..ac1359e8bf --- /dev/null +++ b/utils/tfhe-versionable/tests/convert_with_generics.rs @@ -0,0 +1,58 @@ +//! Checks compatibility between the "convert" feature and generics + +use tfhe_versionable::{Unversionize, Versionize, VersionsDispatch}; + +#[derive(Clone, Versionize)] +#[versionize(convert = "SerializableMyStruct")] +struct MyStruct { + val: u64, + generics: T, +} + +#[derive(Versionize)] +#[versionize(SerializableMyStructVersions)] +struct SerializableMyStruct { + high: u32, + low: u32, + generics: T, +} + +#[derive(VersionsDispatch)] +#[allow(unused)] +enum SerializableMyStructVersions { + V0(SerializableMyStruct), +} + +impl From> for SerializableMyStruct { + fn from(value: MyStruct) -> Self { + Self { + high: (value.val >> 32) as u32, + low: (value.val & 0xffffffff) as u32, + generics: value.generics, + } + } +} + +impl From> for MyStruct { + fn from(value: SerializableMyStruct) -> Self { + Self { + val: ((value.high as u64) << 32) | (value.low as u64), + generics: value.generics, + } + } +} + +#[test] +fn test() { + let stru = MyStruct { + val: 37, + generics: 90, + }; + + let serialized = bincode::serialize(&stru.versionize()).unwrap(); + + let stru_decoded: MyStruct = + MyStruct::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + + assert_eq!(stru.val, stru_decoded.val) +}