Skip to content

Commit

Permalink
derive: Add delegate_to option for ExtensionDispatch
Browse files Browse the repository at this point in the history
This patch adds a delegate_to option to the ExtensionDispatch derive
macro that makes it possible to create an alias for a backend using a
different set of extensions.  To avoid misunderstandings, delegating
backends are forced to use the unit type ().
  • Loading branch information
robin-nitrokey committed Mar 25, 2024
1 parent 9ca8866 commit ec23e99
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 13 deletions.
26 changes: 25 additions & 1 deletion derive/examples/extension-dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use trussed::Error;

mod backends {
use super::extensions::{TestExtension, TestReply, TestRequest};
use super::extensions::{
SampleExtension, SampleReply, SampleRequest, TestExtension, TestReply, TestRequest,
};
use trussed::{
backend::Backend, platform::Platform, serde_extensions::ExtensionImpl,
service::ServiceResources, types::CoreContext, Error,
Expand All @@ -26,6 +28,18 @@ mod backends {
}
}

impl ExtensionImpl<SampleExtension> for ABackend {
fn extension_request<P: Platform>(
&mut self,
_core_ctx: &mut CoreContext,
_backend_ctx: &mut Self::Context,
_request: &SampleRequest,
_resources: &mut ServiceResources<P>,
) -> Result<SampleReply, Error> {
Ok(SampleReply)
}
}

#[derive(Default)]
pub struct BBackend;

Expand Down Expand Up @@ -88,6 +102,7 @@ mod extensions {

enum Backend {
A,
ASample,
B,
}

Expand All @@ -106,6 +121,11 @@ enum Extension {
struct Dispatch {
#[extensions("Test")]
a: backends::ABackend,

#[dispatch(delegate_to = "a")]
#[extensions("Sample")]
a_sample: (),

b: backends::BBackend,
}

Expand Down Expand Up @@ -135,5 +155,9 @@ fn main() {
&[BackendId::Custom(Backend::B)],
Some(Error::RequestNotAvailable),
);
run(
&[BackendId::Custom(Backend::ASample)],
Some(Error::RequestNotAvailable),
);
run(&[BackendId::Custom(Backend::A)], None);
}
173 changes: 161 additions & 12 deletions derive/src/extension_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct ExtensionDispatch {
dispatch_attrs: DispatchAttrs,
extension_attrs: ExtensionAttrs,
backends: Vec<Backend>,
delegated_backends: Vec<DelegatedBackend>,
}

impl ExtensionDispatch {
Expand All @@ -27,18 +28,37 @@ impl ExtensionDispatch {
};
let dispatch_attrs = DispatchAttrs::new(&input)?;
let extension_attrs = ExtensionAttrs::new(&input)?;
let backends = data_struct
let raw_backends: Vec<_> = data_struct
.fields
.iter()
.enumerate()
.map(|(i, field)| Backend::new(i, field, &extension_attrs.extensions))
.map(RawBackend::new)
.collect::<Result<_>>()?;
let mut backends = Vec::new();
let mut delegated_backends = Vec::new();
for raw_backend in raw_backends {
if let Some(delegate_to) = raw_backend.delegate_to.clone() {
delegated_backends.push((raw_backend, delegate_to));
} else {
backends.push(Backend::new(
backends.len(),
raw_backend,
&extension_attrs.extensions,
)?);
}
}
let delegated_backends = delegated_backends
.into_iter()
.map(|(raw, delegate_to)| {
DelegatedBackend::new(raw, delegate_to, &backends, &extension_attrs.extensions)
})
.collect::<Result<_>>()?;
Ok(Self {
name: input.ident,
generics: input.generics,
dispatch_attrs,
extension_attrs,
backends,
delegated_backends,
})
}

Expand All @@ -49,7 +69,15 @@ impl ExtensionDispatch {
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let context = self.backends.iter().map(Backend::context);
let requests = self.backends.iter().map(Backend::request);
let delegated_requests = self
.delegated_backends
.iter()
.map(DelegatedBackend::request);
let extension_requests = self.backends.iter().map(Backend::extension_request);
let delegated_extension_requests = self
.delegated_backends
.iter()
.map(DelegatedBackend::extension_request);
let extension_impls = self
.extension_attrs
.extensions
Expand All @@ -71,6 +99,7 @@ impl ExtensionDispatch {
) -> ::core::result::Result<::trussed::api::Reply, ::trussed::error::Error> {
match backend {
#(#requests)*
#(#delegated_requests)*
}
}

Expand All @@ -84,6 +113,7 @@ impl ExtensionDispatch {
) -> ::core::result::Result<::trussed::api::reply::SerdeExtension, ::trussed::error::Error> {
match backend {
#(#extension_requests)*
#(#delegated_extension_requests)*
}
}
}
Expand Down Expand Up @@ -165,32 +195,70 @@ impl ExtensionAttrs {
}
}

struct Backend {
struct RawBackend {
id: Ident,
field: Ident,
ty: Type,
index: Index,
extensions: Vec<Extension>,
delegate_to: Option<Ident>,
extensions: Vec<Ident>,
}

impl Backend {
fn new(i: usize, field: &Field, extension_types: &HashMap<Ident, Path>) -> Result<Self> {
impl RawBackend {
fn new(field: &Field) -> Result<Self> {
let ident = field.ident.clone().ok_or_else(|| {
Error::new_spanned(
field,
"ExtensionDispatch can only be derived for a struct with named fields",
)
})?;
let mut delegate_to = None;
for attr in util::get_attrs(&field.attrs, "dispatch") {
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("delegate_to") {
let s: LitStr = meta.value()?.parse()?;
delegate_to = Some(s.parse()?);
Ok(())
} else {
Err(meta.error("unsupported dispatch attribute"))
}
})?;
}
let mut extensions = Vec::new();
for attr in util::get_attrs(&field.attrs, "extensions") {
for s in attr.parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated)? {
extensions.push(Extension::new(&s, extension_types)?);
extensions.push(s.parse()?);
}
}
Ok(Self {
id: util::to_camelcase(&ident),
field: ident,
ty: field.ty.clone(),
delegate_to,
extensions,
})
}
}

#[derive(Clone)]
struct Backend {
id: Ident,
field: Ident,
ty: Type,
index: Index,
extensions: Vec<Extension>,
}

impl Backend {
fn new(i: usize, raw: RawBackend, extensions: &HashMap<Ident, Path>) -> Result<Self> {
let extensions = raw
.extensions
.into_iter()
.map(|i| Extension::new(i, extensions))
.collect::<Result<_>>()?;
Ok(Self {
id: raw.id,
field: raw.field,
ty: raw.ty,
index: Index::from(i),
extensions,
})
Expand Down Expand Up @@ -224,17 +292,98 @@ impl Backend {
}
}

struct DelegatedBackend {
id: Ident,
field: Ident,
backend: Backend,
extensions: Vec<Extension>,
}

impl DelegatedBackend {
fn new(
raw: RawBackend,
delegate_to: Ident,
backends: &[Backend],
extensions: &HashMap<Ident, Path>,
) -> Result<Self> {
match raw.ty {
Type::Tuple(tuple) if tuple.elems.is_empty() => (),
_ => {
return Err(Error::new_spanned(
&raw.ty,
"delegated backends must use the unit type ()",
));
}
}

let extensions = raw
.extensions
.into_iter()
.map(|i| Extension::new(i, extensions))
.collect::<Result<_>>()?;
let backend = backends
.iter()
.find(|backend| backend.field == delegate_to)
.ok_or_else(|| Error::new_spanned(delegate_to, "unknown backend"))?
.clone();
Ok(Self {
id: raw.id,
field: raw.field,
backend,
extensions,
})
}

fn request(&self) -> TokenStream {
let Self {
id, backend, field, ..
} = self;
let Backend {
field: delegated_field,
index: delegated_index,
..
} = backend;
quote! {
Self::BackendId::#id => {
let _ = self.#field;
::trussed::backend::Backend::request(
&mut self.#delegated_field, &mut ctx.core, &mut ctx.backends.#delegated_index, request, resources,
)
}
}
}

fn extension_request(&self) -> TokenStream {
let Self {
id,
extensions,
backend,
field,
} = self;
let extension_requests = extensions.iter().map(|e| e.extension_request(backend));
quote! {
Self::BackendId::#id => {
let _ = self.#field;
match extension {
#(#extension_requests)*
_ => Err(::trussed::error::Error::RequestNotAvailable),
}
}
}
}
}

#[derive(Clone)]
struct Extension {
id: Ident,
ty: Path,
}

impl Extension {
fn new(s: &LitStr, extensions: &HashMap<Ident, Path>) -> Result<Self> {
let id = s.parse()?;
fn new(id: Ident, extensions: &HashMap<Ident, Path>) -> Result<Self> {
let ty = extensions
.get(&id)
.ok_or_else(|| Error::new_spanned(s, "unknown extension ID"))?
.ok_or_else(|| Error::new_spanned(&id, "unknown extension ID"))?
.clone();
Ok(Self { id, ty })
}
Expand Down

0 comments on commit ec23e99

Please sign in to comment.