diff --git a/derive/examples/extension-dispatch.rs b/derive/examples/extension-dispatch.rs index bcc4862c13b..256dfc13621 100644 --- a/derive/examples/extension-dispatch.rs +++ b/derive/examples/extension-dispatch.rs @@ -1,7 +1,9 @@ use trussed::Error; mod backends { - use super::extensions::{TestExtension, TestReply, TestRequest}; + use super::extensions::{ + SampleExtension, SampleReply, SampleRequest, TestExtension, TestReply, TestRequest, + }; use trussed::{ backend::Backend, platform::Platform, serde_extensions::ExtensionImpl, service::ServiceResources, types::CoreContext, Error, @@ -26,6 +28,18 @@ mod backends { } } + impl ExtensionImpl for ABackend { + fn extension_request( + &mut self, + _core_ctx: &mut CoreContext, + _backend_ctx: &mut Self::Context, + _request: &SampleRequest, + _resources: &mut ServiceResources

, + ) -> Result { + Ok(SampleReply) + } + } + #[derive(Default)] pub struct BBackend; @@ -88,6 +102,7 @@ mod extensions { enum Backend { A, + ASample, B, } @@ -106,6 +121,11 @@ enum Extension { struct Dispatch { #[extensions("Test")] a: backends::ABackend, + + #[dispatch(delegate_to = "a")] + #[extensions("Sample")] + a_sample: (), + b: backends::BBackend, } @@ -135,5 +155,9 @@ fn main() { &[BackendId::Custom(Backend::B)], Some(Error::RequestNotAvailable), ); + run( + &[BackendId::Custom(Backend::ASample)], + Some(Error::RequestNotAvailable), + ); run(&[BackendId::Custom(Backend::A)], None); } diff --git a/derive/src/extension_dispatch.rs b/derive/src/extension_dispatch.rs index 54d0e2a03af..09f073e28d9 100644 --- a/derive/src/extension_dispatch.rs +++ b/derive/src/extension_dispatch.rs @@ -15,6 +15,7 @@ pub struct ExtensionDispatch { dispatch_attrs: DispatchAttrs, extension_attrs: ExtensionAttrs, backends: Vec, + delegated_backends: Vec, } impl ExtensionDispatch { @@ -27,11 +28,29 @@ impl ExtensionDispatch { }; let dispatch_attrs = DispatchAttrs::new(&input)?; let extension_attrs = ExtensionAttrs::new(&input)?; - let backends = data_struct + let raw_backends: Vec<_> = data_struct .fields .iter() - .enumerate() - .map(|(i, field)| Backend::new(i, field, &extension_attrs.extensions)) + .map(RawBackend::new) + .collect::>()?; + let mut backends = Vec::new(); + let mut delegated_backends = Vec::new(); + for raw_backend in raw_backends { + if let Some(delegate_to) = raw_backend.delegate_to.clone() { + delegated_backends.push((raw_backend, delegate_to)); + } else { + backends.push(Backend::new( + backends.len(), + raw_backend, + &extension_attrs.extensions, + )?); + } + } + let delegated_backends = delegated_backends + .into_iter() + .map(|(raw, delegate_to)| { + DelegatedBackend::new(raw, delegate_to, &backends, &extension_attrs.extensions) + }) .collect::>()?; Ok(Self { name: input.ident, @@ -39,6 +58,7 @@ impl ExtensionDispatch { dispatch_attrs, extension_attrs, backends, + delegated_backends, }) } @@ -49,7 +69,15 @@ impl ExtensionDispatch { let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); let context = self.backends.iter().map(Backend::context); let requests = self.backends.iter().map(Backend::request); + let delegated_requests = self + .delegated_backends + .iter() + .map(DelegatedBackend::request); let extension_requests = self.backends.iter().map(Backend::extension_request); + let delegated_extension_requests = self + .delegated_backends + .iter() + .map(DelegatedBackend::extension_request); let extension_impls = self .extension_attrs .extensions @@ -71,6 +99,7 @@ impl ExtensionDispatch { ) -> ::core::result::Result<::trussed::api::Reply, ::trussed::error::Error> { match backend { #(#requests)* + #(#delegated_requests)* } } @@ -84,6 +113,7 @@ impl ExtensionDispatch { ) -> ::core::result::Result<::trussed::api::reply::SerdeExtension, ::trussed::error::Error> { match backend { #(#extension_requests)* + #(#delegated_extension_requests)* } } } @@ -165,32 +195,70 @@ impl ExtensionAttrs { } } -struct Backend { +struct RawBackend { id: Ident, field: Ident, ty: Type, - index: Index, - extensions: Vec, + delegate_to: Option, + extensions: Vec, } -impl Backend { - fn new(i: usize, field: &Field, extension_types: &HashMap) -> Result { +impl RawBackend { + fn new(field: &Field) -> Result { let ident = field.ident.clone().ok_or_else(|| { Error::new_spanned( field, "ExtensionDispatch can only be derived for a struct with named fields", ) })?; + let mut delegate_to = None; + for attr in util::get_attrs(&field.attrs, "dispatch") { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("delegate_to") { + let s: LitStr = meta.value()?.parse()?; + delegate_to = Some(s.parse()?); + Ok(()) + } else { + Err(meta.error("unsupported dispatch attribute")) + } + })?; + } let mut extensions = Vec::new(); for attr in util::get_attrs(&field.attrs, "extensions") { for s in attr.parse_args_with(Punctuated::::parse_terminated)? { - extensions.push(Extension::new(&s, extension_types)?); + extensions.push(s.parse()?); } } Ok(Self { id: util::to_camelcase(&ident), field: ident, ty: field.ty.clone(), + delegate_to, + extensions, + }) + } +} + +#[derive(Clone)] +struct Backend { + id: Ident, + field: Ident, + ty: Type, + index: Index, + extensions: Vec, +} + +impl Backend { + fn new(i: usize, raw: RawBackend, extensions: &HashMap) -> Result { + let extensions = raw + .extensions + .into_iter() + .map(|i| Extension::new(i, extensions)) + .collect::>()?; + Ok(Self { + id: raw.id, + field: raw.field, + ty: raw.ty, index: Index::from(i), extensions, }) @@ -224,17 +292,98 @@ impl Backend { } } +struct DelegatedBackend { + id: Ident, + field: Ident, + backend: Backend, + extensions: Vec, +} + +impl DelegatedBackend { + fn new( + raw: RawBackend, + delegate_to: Ident, + backends: &[Backend], + extensions: &HashMap, + ) -> Result { + match raw.ty { + Type::Tuple(tuple) if tuple.elems.is_empty() => (), + _ => { + return Err(Error::new_spanned( + &raw.ty, + "delegated backends must use the unit type ()", + )); + } + } + + let extensions = raw + .extensions + .into_iter() + .map(|i| Extension::new(i, extensions)) + .collect::>()?; + let backend = backends + .iter() + .find(|backend| backend.field == delegate_to) + .ok_or_else(|| Error::new_spanned(delegate_to, "unknown backend"))? + .clone(); + Ok(Self { + id: raw.id, + field: raw.field, + backend, + extensions, + }) + } + + fn request(&self) -> TokenStream { + let Self { + id, backend, field, .. + } = self; + let Backend { + field: delegated_field, + index: delegated_index, + .. + } = backend; + quote! { + Self::BackendId::#id => { + let _ = self.#field; + ::trussed::backend::Backend::request( + &mut self.#delegated_field, &mut ctx.core, &mut ctx.backends.#delegated_index, request, resources, + ) + } + } + } + + fn extension_request(&self) -> TokenStream { + let Self { + id, + extensions, + backend, + field, + } = self; + let extension_requests = extensions.iter().map(|e| e.extension_request(backend)); + quote! { + Self::BackendId::#id => { + let _ = self.#field; + match extension { + #(#extension_requests)* + _ => Err(::trussed::error::Error::RequestNotAvailable), + } + } + } + } +} + +#[derive(Clone)] struct Extension { id: Ident, ty: Path, } impl Extension { - fn new(s: &LitStr, extensions: &HashMap) -> Result { - let id = s.parse()?; + fn new(id: Ident, extensions: &HashMap) -> Result { let ty = extensions .get(&id) - .ok_or_else(|| Error::new_spanned(s, "unknown extension ID"))? + .ok_or_else(|| Error::new_spanned(&id, "unknown extension ID"))? .clone(); Ok(Self { id, ty }) }