Skip to content

Commit

Permalink
Namespace derive attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
djkoloski committed Sep 8, 2024
1 parent d52f9df commit 47faccb
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 79 deletions.
32 changes: 14 additions & 18 deletions bytecheck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ mod tests {
#[test]
fn test_unit_struct() {
#[derive(CheckBytes)]
#[check_bytes(crate)]
#[bytecheck(crate)]
struct Test;

unsafe {
Expand All @@ -1103,7 +1103,7 @@ mod tests {
#[test]
fn test_tuple_struct() {
#[derive(CheckBytes, Debug)]
#[check_bytes(crate)]
#[bytecheck(crate)]
struct Test(u32, bool, CharLE);

let value = Test(42, true, 'x'.into());
Expand Down Expand Up @@ -1168,7 +1168,7 @@ mod tests {
#[test]
fn test_struct() {
#[derive(CheckBytes, Debug)]
#[check_bytes(crate)]
#[bytecheck(crate)]
struct Test {
a: u32,
b: bool,
Expand Down Expand Up @@ -1241,7 +1241,7 @@ mod tests {
#[test]
fn test_generic_struct() {
#[derive(CheckBytes, Debug)]
#[check_bytes(crate)]
#[bytecheck(crate)]
struct Test<T> {
a: u32,
b: T,
Expand Down Expand Up @@ -1277,7 +1277,7 @@ mod tests {
fn test_enum() {
#[allow(dead_code)]
#[derive(CheckBytes, Debug)]
#[check_bytes(crate)]
#[bytecheck(crate)]
#[repr(u8)]
enum Test {
A(u32, bool, CharLE),
Expand Down Expand Up @@ -1346,7 +1346,7 @@ mod tests {
#[test]
fn test_explicit_enum_values() {
#[derive(CheckBytes, Debug)]
#[check_bytes(crate)]
#[bytecheck(crate)]
#[repr(u8)]
enum Test {
A,
Expand Down Expand Up @@ -1403,11 +1403,11 @@ mod tests {

#[allow(dead_code)]
#[derive(CheckBytes)]
#[check_bytes(crate)]
#[bytecheck(crate)]
#[repr(u8)]
enum Node {
Nil,
Cons(#[omit_bounds] MyBox<Node>),
Cons(#[bytecheck(omit_bounds)] MyBox<Node>),
}

unsafe {
Expand All @@ -1427,15 +1427,15 @@ mod tests {
}

#[derive(CheckBytes)]
#[check_bytes(crate = m::bc)]
#[bytecheck(crate = m::bc)]
struct Test;

unsafe {
check_bytes::<_, Infallible>(&Test).unwrap();
}

#[derive(CheckBytes)]
#[check_bytes(crate = crate)]
#[bytecheck(crate = crate)]
struct Test2;

unsafe {
Expand Down Expand Up @@ -1473,8 +1473,7 @@ mod tests {
}

#[derive(CheckBytes)]
#[check_bytes(crate)]
#[check_bytes(verify)]
#[bytecheck(crate, verify)]
struct UnitStruct;

let mut context = FooContext { value: 0 };
Expand All @@ -1499,8 +1498,7 @@ mod tests {
}

#[derive(CheckBytes)]
#[check_bytes(crate)]
#[check_bytes(verify)]
#[bytecheck(crate, verify)]
struct Struct {
value: i32,
}
Expand Down Expand Up @@ -1530,8 +1528,7 @@ mod tests {
}

#[derive(CheckBytes)]
#[check_bytes(crate)]
#[check_bytes(verify)]
#[bytecheck(crate, verify)]
struct TupleStruct(i32);

let mut context = FooContext { value: 0 };
Expand Down Expand Up @@ -1560,8 +1557,7 @@ mod tests {
}

#[derive(CheckBytes)]
#[check_bytes(crate)]
#[check_bytes(verify)]
#[bytecheck(crate, verify)]
#[repr(u8)]
enum Enum {
A,
Expand Down
34 changes: 31 additions & 3 deletions bytecheck_derive/src/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use quote::ToTokens;
use syn::{
meta::ParseNestedMeta, parenthesized, parse::Parse, parse_quote,
punctuated::Punctuated, AttrStyle, DeriveInput, Error, Path, Token,
punctuated::Punctuated, AttrStyle, DeriveInput, Error, Field, Path, Token,
WherePredicate,
};

Expand Down Expand Up @@ -62,7 +62,7 @@ impl Attributes {

try_set_attribute(&mut self.verify, meta.path, "verify")
} else {
Err(meta.error("unrecognized check_bytes argument"))
Err(meta.error("unrecognized bytecheck argument"))
}
}

Expand All @@ -74,7 +74,7 @@ impl Attributes {
continue;
}

if attr.path().is_ident("check_bytes") {
if attr.path().is_ident("bytecheck") {
attr.parse_nested_meta(|nested| {
result.parse_check_bytes_attributes(nested)
})?;
Expand All @@ -94,3 +94,31 @@ impl Attributes {
.unwrap_or_else(|| parse_quote! { ::bytecheck })
}
}

#[derive(Default)]
pub struct FieldAttributes {
pub omit_bounds: Option<Path>,
}

impl FieldAttributes {
fn parse_meta(&mut self, meta: ParseNestedMeta<'_>) -> Result<(), Error> {
if meta.path.is_ident("omit_bounds") {
self.omit_bounds = Some(meta.path);
Ok(())
} else {
Err(meta.error("unrecognized bytecheck arguments"))
}
}

pub fn parse(input: &Field) -> Result<Self, Error> {
let mut result = Self::default();

for attr in input.attrs.iter() {
if attr.path().is_ident("bytecheck") {
attr.parse_nested_meta(|meta| result.parse_meta(meta))?;
}
}

Ok(result)
}
}
90 changes: 32 additions & 58 deletions bytecheck_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

mod attributes;
mod repr;
mod util;

use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -19,32 +20,45 @@ use syn::{
Field, Fields, Ident, Index, Path,
};

use crate::{attributes::Attributes, repr::BaseRepr};
use crate::{
attributes::{Attributes, FieldAttributes},
repr::BaseRepr,
util::iter_fields,
};

/// Derives `CheckBytes` for the labeled type.
///
/// This derive macro automatically adds a type bound `field: CheckBytes<__C>`
/// for each field type. This can cause an overflow while evaluating trait
/// bounds if the structure eventually references its own type, as the
/// implementation of `CheckBytes` for a struct depends on each field type
/// implementing it as well. Adding the attribute `#[omit_bounds]` to a field
/// will suppress this trait bound and allow recursive structures. This may be
/// too coarse for some types, in which case additional type bounds may be
/// required with `bounds(...)`.
/// implementing it as well. Adding the attribute `#[check_bytes(omit_bounds)]`
/// to a field will suppress this trait bound and allow recursive structures.
/// This may be too coarse for some types, in which case additional type bounds
/// may be required with `bounds(...)`.
///
/// # Attributes
///
/// Additional arguments can be specified using attributes.
///
/// `#[check_bytes(...)]` accepts the following attributes:
/// `#[bytecheck(...)]` accepts the following attributes:
///
/// ## Types only
///
/// - `bounds(...)`: Adds additional bounds to the `CheckBytes` implementation.
/// This can be especially useful when dealing with recursive structures,
/// where bounds may need to be omitted to prevent recursive type definitions.
/// In the context of the added bounds, `__C` is the name of the context
/// generic (e.g. `__C: MyContext`).
/// - `crate = ...`: Chooses an alternative crate path to import bytecheck from.
#[proc_macro_derive(CheckBytes, attributes(check_bytes, omit_bounds))]
/// - `verify`: Adds an additional verification step after the validity of each
/// field has been checked. See the `Verify` trait for more information.
///
/// ## Fields only
///
/// - `omit_bounds`: Omits trait bounds for the annotated field in the generated
/// impl.
#[proc_macro_derive(CheckBytes, attributes(bytecheck))]
pub fn check_bytes_derive(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
Expand Down Expand Up @@ -117,6 +131,17 @@ fn derive_check_bytes(mut input: DeriveInput) -> Result<TokenStream, Error> {
None
};

let mut check_where = trait_where_clause.clone();
for field in iter_fields(&input.data) {
let field_attrs = FieldAttributes::parse(field)?;
if field_attrs.omit_bounds.is_none() {
let ty = &field.ty;
check_where.predicates.push(parse_quote! {
#ty: #crate_path::CheckBytes<__C>
});
}
}

// Split trait generics for use later
let (trait_impl_generics, _, trait_where_clause) =
trait_generics.split_for_impl();
Expand All @@ -126,16 +151,6 @@ fn derive_check_bytes(mut input: DeriveInput) -> Result<TokenStream, Error> {
let check_bytes_impl = match input.data {
Data::Struct(ref data) => match data.fields {
Fields::Named(ref fields) => {
let mut check_where = trait_where_clause.clone();
for field in fields.named.iter().filter(|f| {
!f.attrs.iter().any(|a| a.path().is_ident("omit_bounds"))
}) {
let ty = &field.ty;
check_where.predicates.push(
parse_quote! { #ty: #crate_path::CheckBytes<__C> },
);
}

let field_checks = fields.named.iter().map(|f| {
let field = &f.ident;
let ty = &f.ty;
Expand Down Expand Up @@ -183,16 +198,6 @@ fn derive_check_bytes(mut input: DeriveInput) -> Result<TokenStream, Error> {
}
}
Fields::Unnamed(ref fields) => {
let mut check_where = trait_where_clause.clone();
for field in fields.unnamed.iter().filter(|f| {
!f.attrs.iter().any(|a| a.path().is_ident("omit_bounds"))
}) {
let ty = &field.ty;
check_where.predicates.push(parse_quote! {
#ty: #crate_path::CheckBytes<__C>
});
}

let field_checks =
fields.unnamed.iter().enumerate().map(|(i, f)| {
let ty = &f.ty;
Expand Down Expand Up @@ -291,37 +296,6 @@ fn derive_check_bytes(mut input: DeriveInput) -> Result<TokenStream, Error> {
Some((BaseRepr::Int(i), _)) => i,
};

let mut check_where = trait_where_clause.clone();
for v in data.variants.iter() {
match v.fields {
Fields::Named(ref fields) => {
for field in fields.named.iter().filter(|f| {
!f.attrs
.iter()
.any(|a| a.path().is_ident("omit_bounds"))
}) {
let ty = &field.ty;
check_where.predicates.push(parse_quote! {
#ty: #crate_path::CheckBytes<__C>
});
}
}
Fields::Unnamed(ref fields) => {
for field in fields.unnamed.iter().filter(|f| {
!f.attrs
.iter()
.any(|a| a.path().is_ident("omit_bounds"))
}) {
let ty = &field.ty;
check_where.predicates.push(parse_quote! {
#ty: #crate_path::CheckBytes<__C>
});
}
}
Fields::Unit => (),
}
}

let tag_variant_defs = data.variants.iter().map(|v| {
let variant = &v.ident;
if let Some((_, expr)) = &v.discriminant {
Expand Down
41 changes: 41 additions & 0 deletions bytecheck_derive/src/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use core::iter::FlatMap;

use syn::{
punctuated::Iter, Data, DataEnum, DataStruct, DataUnion, Field, Variant,
};

type VariantFieldsFn = fn(&Variant) -> Iter<'_, Field>;

fn variant_fields(variant: &Variant) -> Iter<'_, Field> {
variant.fields.iter()
}

pub enum FieldsIter<'a> {
Struct(Iter<'a, Field>),
Enum(FlatMap<Iter<'a, Variant>, Iter<'a, Field>, VariantFieldsFn>),
}

impl<'a> Iterator for FieldsIter<'a> {
type Item = &'a Field;

fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Struct(iter) => iter.next(),
Self::Enum(iter) => iter.next(),
}
}
}

pub fn iter_fields(data: &Data) -> FieldsIter<'_> {
match data {
Data::Struct(DataStruct { fields, .. }) => {
FieldsIter::Struct(fields.iter())
}
Data::Enum(DataEnum { variants, .. }) => {
FieldsIter::Enum(variants.iter().flat_map(variant_fields))
}
Data::Union(DataUnion { fields, .. }) => {
FieldsIter::Struct(fields.named.iter())
}
}
}

0 comments on commit 47faccb

Please sign in to comment.