From 972e300ce5a20e3a3580c761def76c8e27dba814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun?= Date: Mon, 8 Jul 2024 11:07:01 +0200 Subject: [PATCH] Added `enabled` attribute (#8) --- macros/src/lib.rs | 30 ++++++++++++++++++++++++++++-- macros/src/memoize.rs | 38 +++++++++++++++++++++++++++++++++----- macros/src/utils.rs | 28 ++++++++++++++++++++++++++++ src/cache.rs | 31 +++++++++++++++++++++++++++++++ tests/tests.rs | 39 +++++++++++++++++++++++++++------------ 5 files changed, 147 insertions(+), 19 deletions(-) create mode 100644 macros/src/utils.rs diff --git a/macros/src/lib.rs b/macros/src/lib.rs index cb0a559..4fa6bfb 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -12,6 +12,7 @@ macro_rules! bail { mod memoize; mod track; +mod utils; use proc_macro::TokenStream as BoundaryStream; use proc_macro2::TokenStream; @@ -76,10 +77,35 @@ use syn::{parse_quote, Error, Result}; /// } /// ``` /// +/// # Disabling memoization conditionally +/// If you want to enable or disable memoization for a function conditionally, +/// you can use the `enabled` attribute. This is useful for cheap function calls +/// where dealing with the caching is more expensive than recomputing the +/// result. This allows you to bypass hashing and constraint validation while +/// still keeping the same function signature, which saves both memory and +/// time. +/// +/// By default, all functions are unconditionally memoized. To disable +/// memoization conditionally, you must specify an `enabled = <expr>` attribute. +/// The expression can use the parameters and must evaluate to a boolean value. +/// If the expression is `false`, the function will be executed without hashing +/// and caching. +/// +/// ## Example +/// ``` +/// /// Compute the sum of a slice of floats, but only memoize if the slice is +/// /// longer than 1024 elements. 
+/// #[comemo::memoize(enabled = add.len() > 1024)] +/// fn evaluate(add: &[f32]) -> f32 { +/// add.iter().copied().sum() +/// } +/// ``` +/// #[proc_macro_attribute] -pub fn memoize(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream { +pub fn memoize(args: BoundaryStream, stream: BoundaryStream) -> BoundaryStream { + let args = syn::parse_macro_input!(args as TokenStream); let func = syn::parse_macro_input!(stream as syn::Item); - memoize::expand(&func) + memoize::expand(args, &func) .unwrap_or_else(|err| err.to_compile_error()) .into() } diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index 09afbce..a077459 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -1,13 +1,15 @@ +use utils::parse_key_value; + use super::*; /// Memoize a function. -pub fn expand(item: &syn::Item) -> Result { +pub fn expand(attrs: TokenStream, item: &syn::Item) -> Result { let syn::Item::Fn(item) = item else { bail!(item, "`memoize` can only be applied to functions and methods"); }; // Preprocess and validate the function. - let function = prepare(item)?; + let function = prepare(attrs, item)?; // Rewrite the function's body to memoize it. process(&function) @@ -18,6 +20,18 @@ struct Function { item: syn::ItemFn, args: Vec, output: syn::Type, + enabled: Option, +} + +/// Additional metadata for a memoized function. +struct Meta { + enabled: Option, +} + +impl syn::parse::Parse for Meta { + fn parse(input: syn::parse::ParseStream) -> Result { + Ok(Self { enabled: parse_key_value::(input)? }) + } } /// An argument to a memoized function. @@ -27,9 +41,10 @@ enum Argument { } /// Preprocess and validate a function. 
-fn prepare(function: &syn::ItemFn) -> Result { - let mut args = vec![]; +fn prepare(attrs: TokenStream, function: &syn::ItemFn) -> Result { + let meta = syn::parse2::(attrs.clone())?; + let mut args = vec![]; for input in &function.sig.inputs { args.push(prepare_arg(input)?); } @@ -39,7 +54,12 @@ fn prepare(function: &syn::ItemFn) -> Result { syn::ReturnType::Type(_, ty) => ty.as_ref().clone(), }; - Ok(Function { item: function.clone(), args, output }) + Ok(Function { + item: function.clone(), + args, + output, + enabled: meta.enabled, + }) } /// Preprocess a function argument. @@ -124,6 +144,8 @@ fn process(function: &Function) -> Result { ident.mutability = None; } + let enabled = function.enabled.clone().unwrap_or(parse_quote! { true }); + wrapped.block = parse_quote! { { static __CACHE: ::comemo::internal::Cache< <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, @@ -134,13 +156,19 @@ fn process(function: &Function) -> Result { }); #(#bounds;)* + ::comemo::internal::memoized( ::comemo::internal::Args(#arg_tuple), &::core::default::Default::default(), &__CACHE, + #enabled, #closure, ) } }; Ok(quote! { #wrapped }) } + +pub mod kw { + syn::custom_keyword!(enabled); +} diff --git a/macros/src/utils.rs b/macros/src/utils.rs new file mode 100644 index 0000000..bb13d42 --- /dev/null +++ b/macros/src/utils.rs @@ -0,0 +1,28 @@ +use syn::{ + parse::{Parse, ParseStream}, + token::Token, +}; + +use super::*; + +/// Parse a metadata key-value pair, separated by `=`. +pub fn parse_key_value( + input: ParseStream, +) -> Result> { + if !input.peek(|_| K::default()) { + return Ok(None); + } + + let _: K = input.parse()?; + let _: syn::Token![=] = input.parse()?; + let value: V = input.parse::()?; + eat_comma(input); + Ok(Some(value)) +} + +/// Parse a comma if there is one. 
+pub fn eat_comma(input: ParseStream) { + if input.peek(syn::Token![,]) { + let _: syn::Token![,] = input.parse().unwrap(); + } +} diff --git a/src/cache.rs b/src/cache.rs index 1ff0719..d0a5d28 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -23,6 +23,7 @@ pub fn memoized<'c, In, Out, F>( mut input: In, constraint: &'c In::Constraint, cache: &Cache, + enabled: bool, func: F, ) -> Out where @@ -30,6 +31,12 @@ where Out: Clone + 'static, F: FnOnce(In::Tracked<'c>) -> Out, { + // Early bypass if memoization is disabled. + // Hopefully the compiler will optimize this away, if the condition is constant. + if !enabled { + return memoized_disabled(input, constraint, func); + } + // Compute the hash of the input's key part. let key = { let mut state = SipHasher13::new(); @@ -73,6 +80,30 @@ where output } +fn memoized_disabled<'c, In, Out, F>( + input: In, + constraint: &'c In::Constraint, + func: F, +) -> Out +where + In: Input + 'c, + Out: Clone + 'static, + F: FnOnce(In::Tracked<'c>) -> Out, +{ + // Execute the function with the new constraints hooked in. + let (input, outer) = input.retrack(constraint); + let output = func(input); + + // Add the new constraints to the outer ones. + outer.join(constraint); + + // Ensure that the last call was a miss during testing. + #[cfg(feature = "testing")] + LAST_WAS_HIT.with(|cell| cell.set(false)); + + output +} + /// Evict the global cache. /// /// This removes all memoized results from the cache whose age is larger than or diff --git a/tests/tests.rs b/tests/tests.rs index db6bf31..fd17310 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -73,22 +73,21 @@ fn test_basic() { test!(hit: sum_iter(1000), 499500); } +#[memoize] +fn evaluate(script: &str, files: Tracked) -> i32 { + script + .split('+') + .map(str::trim) + .map(|part| match part.strip_prefix("eval ") { + Some(path) => evaluate(&files.read(path), files), + None => part.parse::().unwrap(), + }) + .sum() +} /// Test the calc language. 
#[test] #[serial] fn test_calc() { - #[memoize] - fn evaluate(script: &str, files: Tracked) -> i32 { - script - .split('+') - .map(str::trim) - .map(|part| match part.strip_prefix("eval ") { - Some(path) => evaluate(&files.read(path), files), - None => part.parse::().unwrap(), - }) - .sum() - } - let mut files = Files(HashMap::new()); files.write("alpha.calc", "2 + eval beta.calc"); files.write("beta.calc", "2 + 3"); @@ -452,3 +451,19 @@ impl Impure { VAL.fetch_add(1, Ordering::SeqCst) } } + +#[test] +#[serial] +#[cfg(debug_assertions)] +fn test_with_disabled() { + #[comemo::memoize(enabled = size >= 1000)] + fn disabled(size: usize) -> usize { + size + } + + test!(miss: disabled(0), 0); + test!(miss: disabled(0), 0); + + test!(miss: disabled(2000), 2000); + test!(hit: disabled(2000), 2000); +}