Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local memoization #6

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ use syn::{parse_quote, Error, Result};
///
/// Furthermore, memoized functions cannot use destructuring patterns in their
/// arguments.
///
/// # Local memoization
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to move this below the example.

///
/// In the case where you would want to explicitely restrict usage to a single
/// thread, you can use the [`#[memoize(local)]`](macro@memoize) attribute.
/// This will use thread-local storage for the cache and requires a call to the
/// [`local_evict`](comemo::local_evict) function to clear the cache.
///
/// Additionally, if you wish to pass borrowed arguments to a memoized function
/// you will need to manually annotate it with the `'local` lifetime. This allows
/// the macro to specify the lifetime of the tracked value correctly.
///
/// # Example
/// ```
Expand All @@ -71,12 +82,26 @@ use syn::{parse_quote, Error, Result};
/// })
/// .sum()
/// }
///
/// // Evaluate a `.calc` script in a thread-local cache.
/// // /!\ Notice the `'local` lifetime annotation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does /!\ do?

/// #[comemo::memoize(local)]
/// fn evaluate(script: &'local str, files: &comemo::Tracked<'local, Files>) -> i32 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there an extra & before the Tracked here?

/// script
/// .split('+')
/// .map(str::trim)
/// .map(|part| match part.strip_prefix("eval ") {
/// Some(path) => evaluate(&files.read(path), files),
/// None => part.parse::<i32>().unwrap(),
/// })
/// .sum()
Comment on lines +90 to +97
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And then this can just be ... since it's the same as above.

/// }
/// ```
///
#[proc_macro_attribute]
pub fn memoize(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
let func = syn::parse_macro_input!(stream as syn::Item);
memoize::expand(&func)
pub fn memoize(stream: BoundaryStream, item: BoundaryStream) -> BoundaryStream {
let func = syn::parse_macro_input!(item as syn::Item);
memoize::expand(stream.into(), &func)
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
Expand Down
105 changes: 84 additions & 21 deletions macros/src/memoize.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,40 @@
use syn::{parse::{Parse, ParseStream}, token::Token};

use super::*;

/// Memoize a function.
pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
pub fn expand(stream: TokenStream, item: &syn::Item) -> Result<proc_macro2::TokenStream> {
let meta: Meta = syn::parse2(stream)?;
let syn::Item::Fn(item) = item else {
bail!(item, "`memoize` can only be applied to functions and methods");
};

// Preprocess and validate the function.
let function = prepare(item)?;
let function = prepare(&meta, item)?;

// Rewrite the function's body to memoize it.
process(&function)
}

/// The `..` in `#[memoize(..)]`.
pub struct Meta {
pub local: bool,
}

impl Parse for Meta {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
local: parse_flag::<kw::local>(input)?,
})
}
}

/// Details about a function that should be memoized.
struct Function {
item: syn::ItemFn,
args: Vec<Argument>,
output: syn::Type,
local: bool,
}

/// An argument to a memoized function.
Expand All @@ -27,7 +44,7 @@ enum Argument {
}

/// Preprocess and validate a function.
fn prepare(function: &syn::ItemFn) -> Result<Function> {
fn prepare(meta: &Meta, function: &syn::ItemFn) -> Result<Function> {
let mut args = vec![];

for input in &function.sig.inputs {
Expand All @@ -39,7 +56,7 @@ fn prepare(function: &syn::ItemFn) -> Result<Function> {
syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
};

Ok(Function { item: function.clone(), args, output })
Ok(Function { item: function.clone(), args, output, local: meta.local })
}

/// Preprocess a function argument.
Expand Down Expand Up @@ -124,23 +141,69 @@ fn process(function: &Function) -> Result<TokenStream> {
ident.mutability = None;
}

wrapped.block = parse_quote! { {
static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
::core::default::Default::default()
});

#(#bounds;)*
::comemo::internal::memoized(
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&__CACHE,
#closure,
)
} };
if function.local {
wrapped.block = parse_quote! { {
type __ARGS<'local> = <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint;
::std::thread_local! {
static __CACHE: ::comemo::internal::Cache<
__ARGS<'static>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this 'static result in any restrictions in practice?

#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_local_evictor(|max_age| __CACHE.with(|cache| cache.evict(max_age)));
::core::default::Default::default()
});
}

#(#bounds;)*
__CACHE.with(|cache| {
::comemo::internal::memoized(
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&cache,
#closure,
)
})
} };
} else {
wrapped.block = parse_quote! { {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assignment could be moved out of the if/else.

static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
::core::default::Default::default()
});

#(#bounds;)*
::comemo::internal::memoized(
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&__CACHE,
#closure,
)
} };
}

Ok(quote! { #wrapped })
}

/// Parse a metadata flag that can be present or not.
pub fn parse_flag<K: Token + Default + Parse>(input: ParseStream) -> Result<bool> {
if input.peek(|_| K::default()) {
let _: K = input.parse()?;
eat_comma(input);
return Ok(true);
}
Ok(false)
}

/// Parse a comma if there is one.
fn eat_comma(input: ParseStream) {
if input.peek(syn::Token![,]) {
let _: syn::Token![,] = input.parse().unwrap();
}
}

pub mod kw {
syn::custom_keyword!(local);
}
32 changes: 32 additions & 0 deletions src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

Expand All @@ -9,6 +10,11 @@ use crate::accelerate;
use crate::constraint::Join;
use crate::input::Input;

thread_local! {
/// The thread-local list of eviction functions.
static LOCAL_EVICTORS: RefCell<Vec<fn(usize)>> = const { RefCell::new(Vec::new()) };
}

/// The global list of eviction functions.
static EVICTORS: RwLock<Vec<fn(usize)>> = RwLock::new(Vec::new());

Expand Down Expand Up @@ -90,11 +96,37 @@ pub fn evict(max_age: usize) {
accelerate::evict();
}

/// Evict the thread local cache.
///
/// This removes all memoized results from the cache whose age is larger than or
/// equal to `max_age`. The age of a result grows by one during each eviction
/// and is reset to zero when the result produces a cache hit. Set `max_age` to
/// zero to completely clear the cache.
///
/// Comemo's cache is thread-local, meaning that this only evicts this thread's
/// cache.
Comment on lines +99 to +107
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Evict the thread local cache.
///
/// This removes all memoized results from the cache whose age is larger than or
/// equal to `max_age`. The age of a result grows by one during each eviction
/// and is reset to zero when the result produces a cache hit. Set `max_age` to
/// zero to completely clear the cache.
///
/// Comemo's cache is thread-local, meaning that this only evicts this thread's
/// cache.
/// Evict the thread-local cache.
///
/// This only affects functions that are annotated with `#[memoize(local)]`.

I also noticed that the copy-pasted comment from the other function is outdated. The thread-local note can be removed from the other one.

I wonder whether the default evict should also evict the thread local cache. The overhead would be quite low if it's unused. I'm also not sure about evicting the accelerator here.

We should think about what the use case of this is vs. evict. If evict does everything, then even if you only use local comemo, evict would do the job. In which use cases does it make sense to have a second evict_local function?

pub fn local_evict(max_age: usize) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer the name evict_local.

LOCAL_EVICTORS.with_borrow(|cell| {
for subevict in cell.iter() {
subevict(max_age);
}
});

accelerate::evict();
}

/// Register an eviction function in the global list.
pub fn register_evictor(evict: fn(usize)) {
EVICTORS.write().push(evict);
}

/// Register an eviction function in the global list.
pub fn register_local_evictor(evict: fn(usize)) {
LOCAL_EVICTORS.with_borrow_mut(|cell| {
cell.push(evict);
})
}

/// Whether the last call was a hit.
#[cfg(feature = "testing")]
pub fn last_was_hit() -> bool {
Expand Down
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ mod input;
mod prehashed;
mod track;

pub use crate::cache::evict;
pub use crate::cache::{evict, local_evict};
pub use crate::prehashed::Prehashed;
pub use crate::track::{Track, Tracked, TrackedMut, Validate};
pub use comemo_macros::{memoize, track};
Expand All @@ -99,7 +99,9 @@ pub use comemo_macros::{memoize, track};
pub mod internal {
pub use parking_lot::RwLock;

pub use crate::cache::{memoized, register_evictor, Cache, CacheData};
pub use crate::cache::{
memoized, register_evictor, register_local_evictor, Cache, CacheData,
};
pub use crate::constraint::{hash, Call, ImmutableConstraint, MutableConstraint};
pub use crate::input::{assert_hashable_or_trackable, Args, Input};
pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces};
Expand Down
Loading