Skip to content

Commit

Permalink
Added enabled attribute (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dherse authored Jul 8, 2024
1 parent 1275982 commit 972e300
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 19 deletions.
30 changes: 28 additions & 2 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ macro_rules! bail {

mod memoize;
mod track;
mod utils;

use proc_macro::TokenStream as BoundaryStream;
use proc_macro2::TokenStream;
Expand Down Expand Up @@ -76,10 +77,35 @@ use syn::{parse_quote, Error, Result};
/// }
/// ```
///
/// # Disabling memoization conditionally
/// If you want to enable or disable memoization for a function conditionally,
/// you can use the `enabled` attribute. This is useful for cheap function calls
/// where dealing with the caching is more expensive than recomputing the
/// result. This allows you to bypass hashing and constraint validation while
/// still dealing with the same function signature. And allows saving memory and
/// time.
///
/// By default, all functions are unconditionally memoized. To disable
/// memoization conditionally, you must specify an `enabled = <expr>` attribute.
/// The expression can use the parameters and must evaluate to a boolean value.
/// If the expression is `false`, the function will be executed without hashing
/// and caching.
///
/// ## Example
/// ```
/// /// Compute the sum of a slice of floats, but only memoize if the slice is
/// /// longer than 1024 elements.
/// #[comemo::memoize(enabled = add.len() > 1024)]
/// fn evaluate(add: &[f32]) -> f32 {
/// add.iter().copied().sum()
/// }
/// ```
///
#[proc_macro_attribute]
pub fn memoize(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
pub fn memoize(args: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
let args = syn::parse_macro_input!(args as TokenStream);
let func = syn::parse_macro_input!(stream as syn::Item);
memoize::expand(&func)
memoize::expand(args, &func)
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
Expand Down
38 changes: 33 additions & 5 deletions macros/src/memoize.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use utils::parse_key_value;

use super::*;

/// Memoize a function.
pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
pub fn expand(attrs: TokenStream, item: &syn::Item) -> Result<proc_macro2::TokenStream> {
let syn::Item::Fn(item) = item else {
bail!(item, "`memoize` can only be applied to functions and methods");
};

// Preprocess and validate the function.
let function = prepare(item)?;
let function = prepare(attrs, item)?;

// Rewrite the function's body to memoize it.
process(&function)
Expand All @@ -18,6 +20,18 @@ struct Function {
item: syn::ItemFn,
args: Vec<Argument>,
output: syn::Type,
enabled: Option<syn::Expr>,
}

/// Additional metadata for a memoized function.
struct Meta {
enabled: Option<syn::Expr>,
}

impl syn::parse::Parse for Meta {
fn parse(input: syn::parse::ParseStream) -> Result<Self> {
Ok(Self { enabled: parse_key_value::<kw::enabled, _>(input)? })
}
}

/// An argument to a memoized function.
Expand All @@ -27,9 +41,10 @@ enum Argument {
}

/// Preprocess and validate a function.
fn prepare(function: &syn::ItemFn) -> Result<Function> {
let mut args = vec![];
fn prepare(attrs: TokenStream, function: &syn::ItemFn) -> Result<Function> {
let meta = syn::parse2::<Meta>(attrs.clone())?;

let mut args = vec![];
for input in &function.sig.inputs {
args.push(prepare_arg(input)?);
}
Expand All @@ -39,7 +54,12 @@ fn prepare(function: &syn::ItemFn) -> Result<Function> {
syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
};

Ok(Function { item: function.clone(), args, output })
Ok(Function {
item: function.clone(),
args,
output,
enabled: meta.enabled,
})
}

/// Preprocess a function argument.
Expand Down Expand Up @@ -124,6 +144,8 @@ fn process(function: &Function) -> Result<TokenStream> {
ident.mutability = None;
}

let enabled = function.enabled.clone().unwrap_or(parse_quote! { true });

wrapped.block = parse_quote! { {
static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
Expand All @@ -134,13 +156,19 @@ fn process(function: &Function) -> Result<TokenStream> {
});

#(#bounds;)*

::comemo::internal::memoized(
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&__CACHE,
#enabled,
#closure,
)
} };

Ok(quote! { #wrapped })
}

pub mod kw {
syn::custom_keyword!(enabled);
}
28 changes: 28 additions & 0 deletions macros/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use syn::{
parse::{Parse, ParseStream},
token::Token,
};

use super::*;

/// Parse a metadata key-value pair, separated by `=`.
pub fn parse_key_value<K: Token + Default + Parse, V: Parse>(
input: ParseStream,
) -> Result<Option<V>> {
if !input.peek(|_| K::default()) {
return Ok(None);
}

let _: K = input.parse()?;
let _: syn::Token![=] = input.parse()?;
let value: V = input.parse::<V>()?;
eat_comma(input);

This comment has been minimized.

Copy link
@JohnMeyerhoff

JohnMeyerhoff Jul 11, 2024

Where would this comma even come from in our input?

Ok(Some(value))
}

/// Parse a comma if there is one.
pub fn eat_comma(input: ParseStream) {
if input.peek(syn::Token![,]) {
let _: syn::Token![,] = input.parse().unwrap();
}
}
31 changes: 31 additions & 0 deletions src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,20 @@ pub fn memoized<'c, In, Out, F>(
mut input: In,
constraint: &'c In::Constraint,
cache: &Cache<In::Constraint, Out>,
enabled: bool,
func: F,
) -> Out
where
In: Input + 'c,
Out: Clone + 'static,
F: FnOnce(In::Tracked<'c>) -> Out,
{
// Early bypass if memoization is disabled.
// Hopefully the compiler will optimize this away, if the condition is constant.
if !enabled {
return memoized_disabled(input, constraint, func);
}

// Compute the hash of the input's key part.
let key = {
let mut state = SipHasher13::new();
Expand Down Expand Up @@ -73,6 +80,30 @@ where
output
}

fn memoized_disabled<'c, In, Out, F>(
input: In,
constraint: &'c In::Constraint,
func: F,
) -> Out
where
In: Input + 'c,
Out: Clone + 'static,
F: FnOnce(In::Tracked<'c>) -> Out,
{
// Execute the function with the new constraints hooked in.
let (input, outer) = input.retrack(constraint);
let output = func(input);

// Add the new constraints to the outer ones.
outer.join(constraint);

// Ensure that the last call was a miss during testing.
#[cfg(feature = "testing")]
LAST_WAS_HIT.with(|cell| cell.set(false));

output
}

/// Evict the global cache.
///
/// This removes all memoized results from the cache whose age is larger than or
Expand Down
39 changes: 27 additions & 12 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,21 @@ fn test_basic() {
test!(hit: sum_iter(1000), 499500);
}

#[memoize]
fn evaluate(script: &str, files: Tracked<Files>) -> i32 {
script
.split('+')
.map(str::trim)
.map(|part| match part.strip_prefix("eval ") {
Some(path) => evaluate(&files.read(path), files),
None => part.parse::<i32>().unwrap(),
})
.sum()
}
/// Test the calc language.
#[test]
#[serial]
fn test_calc() {
#[memoize]
fn evaluate(script: &str, files: Tracked<Files>) -> i32 {
script
.split('+')
.map(str::trim)
.map(|part| match part.strip_prefix("eval ") {
Some(path) => evaluate(&files.read(path), files),
None => part.parse::<i32>().unwrap(),
})
.sum()
}

let mut files = Files(HashMap::new());
files.write("alpha.calc", "2 + eval beta.calc");
files.write("beta.calc", "2 + 3");
Expand Down Expand Up @@ -452,3 +451,19 @@ impl Impure {
VAL.fetch_add(1, Ordering::SeqCst)
}
}

#[test]
#[serial]
#[cfg(debug_assertions)]
fn test_with_disabled() {
#[comemo::memoize(enabled = size >= 1000)]
fn disabled(size: usize) -> usize {
size
}

test!(miss: disabled(0), 0);
test!(miss: disabled(0), 0);

test!(miss: disabled(2000), 2000);
test!(hit: disabled(2000), 2000);
}

0 comments on commit 972e300

Please sign in to comment.