From 15c8d5bbb8f61546f247faa26ab59c7157f925de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 29 Dec 2023 14:08:05 +0100 Subject: [PATCH] Local memoization --- macros/src/lib.rs | 31 +++++++++-- macros/src/memoize.rs | 105 +++++++++++++++++++++++++++++-------- src/cache.rs | 32 ++++++++++++ src/lib.rs | 6 ++- tests/tests.rs | 117 ++++++++++++++++++++++++++++++++---------- 5 files changed, 237 insertions(+), 54 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 4ec5633..27c35d6 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -56,6 +56,17 @@ use syn::{parse_quote, Error, Result}; /// /// Furthermore, memoized functions cannot use destructuring patterns in their /// arguments. +/// +/// # Local memoization +/// +/// In the case where you would want to explicitly restrict usage to a single +/// thread, you can use the [`#[memoize(local)]`](macro@memoize) attribute. +/// This will use thread-local storage for the cache and requires a call to the +/// [`local_evict`](comemo::local_evict) function to clear the cache. +/// +/// Additionally, if you wish to pass borrowed arguments to a memoized function +/// you will need to manually annotate it with the `'local` lifetime. This allows +/// the macro to specify the lifetime of the tracked value correctly. /// /// # Example /// ``` @@ -71,12 +82,26 @@ use syn::{parse_quote, Error, Result}; /// }) /// .sum() /// } +/// +/// // Evaluate a `.calc` script in a thread-local cache. +/// // /!\ Notice the `'local` lifetime annotation. 
+/// #[comemo::memoize(local)] +/// fn evaluate(script: &'local str, files: &comemo::Tracked<'local, Files>) -> i32 { +/// script +/// .split('+') +/// .map(str::trim) +/// .map(|part| match part.strip_prefix("eval ") { +/// Some(path) => evaluate(&files.read(path), files), +/// None => part.parse::().unwrap(), +/// }) +/// .sum() +/// } /// ``` /// #[proc_macro_attribute] -pub fn memoize(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream { - let func = syn::parse_macro_input!(stream as syn::Item); - memoize::expand(&func) +pub fn memoize(stream: BoundaryStream, item: BoundaryStream) -> BoundaryStream { + let func = syn::parse_macro_input!(item as syn::Item); + memoize::expand(stream.into(), &func) .unwrap_or_else(|err| err.to_compile_error()) .into() } diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index 09afbce..db0aa4d 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -1,23 +1,40 @@ +use syn::{parse::{Parse, ParseStream}, token::Token}; + use super::*; /// Memoize a function. -pub fn expand(item: &syn::Item) -> Result { +pub fn expand(stream: TokenStream, item: &syn::Item) -> Result { + let meta: Meta = syn::parse2(stream)?; let syn::Item::Fn(item) = item else { bail!(item, "`memoize` can only be applied to functions and methods"); }; // Preprocess and validate the function. - let function = prepare(item)?; + let function = prepare(&meta, item)?; // Rewrite the function's body to memoize it. process(&function) } +/// The `..` in `#[memoize(..)]`. +pub struct Meta { + pub local: bool, +} + +impl Parse for Meta { + fn parse(input: ParseStream) -> Result { + Ok(Self { + local: parse_flag::(input)?, + }) + } +} + /// Details about a function that should be memoized. struct Function { item: syn::ItemFn, args: Vec, output: syn::Type, + local: bool, } /// An argument to a memoized function. @@ -27,7 +44,7 @@ enum Argument { } /// Preprocess and validate a function. 
-fn prepare(function: &syn::ItemFn) -> Result { +fn prepare(meta: &Meta, function: &syn::ItemFn) -> Result { let mut args = vec![]; for input in &function.sig.inputs { @@ -39,7 +56,7 @@ fn prepare(function: &syn::ItemFn) -> Result { syn::ReturnType::Type(_, ty) => ty.as_ref().clone(), }; - Ok(Function { item: function.clone(), args, output }) + Ok(Function { item: function.clone(), args, output, local: meta.local }) } /// Preprocess a function argument. @@ -124,23 +141,69 @@ fn process(function: &Function) -> Result { ident.mutability = None; } - wrapped.block = parse_quote! { { - static __CACHE: ::comemo::internal::Cache< - <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, - #output, - > = ::comemo::internal::Cache::new(|| { - ::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age)); - ::core::default::Default::default() - }); - - #(#bounds;)* - ::comemo::internal::memoized( - ::comemo::internal::Args(#arg_tuple), - &::core::default::Default::default(), - &__CACHE, - #closure, - ) - } }; + if function.local { + wrapped.block = parse_quote! { { + type __ARGS<'local> = <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint; + ::std::thread_local! { + static __CACHE: ::comemo::internal::Cache< + __ARGS<'static>, + #output, + > = ::comemo::internal::Cache::new(|| { + ::comemo::internal::register_local_evictor(|max_age| __CACHE.with(|cache| cache.evict(max_age))); + ::core::default::Default::default() + }); + } + + #(#bounds;)* + __CACHE.with(|cache| { + ::comemo::internal::memoized( + ::comemo::internal::Args(#arg_tuple), + &::core::default::Default::default(), + &cache, + #closure, + ) + }) + } }; + } else { + wrapped.block = parse_quote! 
{ { + static __CACHE: ::comemo::internal::Cache< + <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, + #output, + > = ::comemo::internal::Cache::new(|| { + ::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age)); + ::core::default::Default::default() + }); + + #(#bounds;)* + ::comemo::internal::memoized( + ::comemo::internal::Args(#arg_tuple), + &::core::default::Default::default(), + &__CACHE, + #closure, + ) + } }; + } Ok(quote! { #wrapped }) } + +/// Parse a metadata flag that can be present or not. +pub fn parse_flag(input: ParseStream) -> Result { + if input.peek(|_| K::default()) { + let _: K = input.parse()?; + eat_comma(input); + return Ok(true); + } + Ok(false) +} + +/// Parse a comma if there is one. +fn eat_comma(input: ParseStream) { + if input.peek(syn::Token![,]) { + let _: syn::Token![,] = input.parse().unwrap(); + } +} + +pub mod kw { + syn::custom_keyword!(local); +} \ No newline at end of file diff --git a/src/cache.rs b/src/cache.rs index b8fec40..aeb77bd 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -9,6 +10,11 @@ use crate::accelerate; use crate::constraint::Join; use crate::input::Input; +thread_local! { + /// The thread-local list of eviction functions. + static LOCAL_EVICTORS: RefCell> = const { RefCell::new(Vec::new()) }; +} + /// The global list of eviction functions. static EVICTORS: RwLock> = RwLock::new(Vec::new()); @@ -90,11 +96,37 @@ pub fn evict(max_age: usize) { accelerate::evict(); } +/// Evict the thread local cache. +/// +/// This removes all memoized results from the cache whose age is larger than or +/// equal to `max_age`. The age of a result grows by one during each eviction +/// and is reset to zero when the result produces a cache hit. Set `max_age` to +/// zero to completely clear the cache. 
+/// +/// The local cache is thread-local, meaning that this only evicts this thread's +/// cache. +pub fn local_evict(max_age: usize) { + LOCAL_EVICTORS.with_borrow(|cell| { + for subevict in cell.iter() { + subevict(max_age); + } + }); + + accelerate::evict(); +} + /// Register an eviction function in the global list. pub fn register_evictor(evict: fn(usize)) { EVICTORS.write().push(evict); } +/// Register an eviction function in the thread-local list. +pub fn register_local_evictor(evict: fn(usize)) { + LOCAL_EVICTORS.with_borrow_mut(|cell| { + cell.push(evict); + }) +} + /// Whether the last call was a hit. #[cfg(feature = "testing")] pub fn last_was_hit() -> bool { diff --git a/src/lib.rs b/src/lib.rs index 5921b7a..50570b9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -89,7 +89,7 @@ mod input; mod prehashed; mod track; -pub use crate::cache::evict; +pub use crate::cache::{evict, local_evict}; pub use crate::prehashed::Prehashed; pub use crate::track::{Track, Tracked, TrackedMut, Validate}; pub use comemo_macros::{memoize, track}; @@ -99,7 +99,9 @@ pub use comemo_macros::{memoize, track}; pub mod internal { pub use parking_lot::RwLock; - pub use crate::cache::{memoized, register_evictor, Cache, CacheData}; + pub use crate::cache::{ + memoized, register_evictor, register_local_evictor, Cache, CacheData, + }; pub use crate::constraint::{hash, Call, ImmutableConstraint, MutableConstraint}; pub use crate::input::{assert_hashable_or_trackable, Args, Input}; pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces}; diff --git a/tests/tests.rs b/tests/tests.rs index db6bf31..13ed81c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::hash::Hash; use std::path::{Path, PathBuf}; -use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate}; +use comemo::{evict, local_evict, memoize, track, Track, Tracked, TrackedMut, Validate}; use serial_test::serial; macro_rules! 
test { @@ -377,35 +377,35 @@ impl<'a> Chain<'a> { } /// Test mutable tracking. - #[test] - #[serial] - #[rustfmt::skip] - fn test_mutable() { - #[comemo::memoize] - fn dump(mut sink: TrackedMut) { - sink.emit("a"); - sink.emit("b"); - let c = sink.len_or_ten().to_string(); - sink.emit(&c); - } - - let mut emitter = Emitter(vec![]); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); - assert_eq!(emitter.0, [ - "a", "b", "2", - "a", "b", "5", - "a", "b", "8", - "a", "b", "10", - "a", "b", "10", - "a", "b", "10", - ]) +#[test] +#[serial] +#[rustfmt::skip] +fn test_mutable() { + #[comemo::memoize] + fn dump(mut sink: TrackedMut) { + sink.emit("a"); + sink.emit("b"); + let c = sink.len_or_ten().to_string(); + sink.emit(&c); } + let mut emitter = Emitter(vec![]); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + assert_eq!(emitter.0, [ + "a", "b", "2", + "a", "b", "5", + "a", "b", "8", + "a", "b", "10", + "a", "b", "10", + "a", "b", "10", + ]) +} + /// A tracked type with a mutable and an immutable method. 
#[derive(Clone)] struct Emitter(Vec); @@ -421,6 +421,67 @@ impl Emitter { } } +#[test] +fn test_local() { + #[comemo::memoize(local)] + fn add(a: u32, b: u32) -> u32 { + a + b + } + + test!(miss: add(1, 2), 3); + test!(hit: add(1, 2), 3); + test!(miss: add(2, 3), 5); + test!(hit: add(2, 3), 5); +} + +#[test] +fn test_unsend() { + #[comemo::memoize(local)] + fn add(_a: u32, _b: u32) -> *mut () { + std::ptr::null_mut() + } + + test!(miss: add(1, 2), std::ptr::null_mut()); + test!(hit: add(1, 2), std::ptr::null_mut()); + test!(miss: add(2, 3), std::ptr::null_mut()); + test!(hit: add(2, 3), std::ptr::null_mut()); + + local_evict(0); + + test!(miss: add(1, 2), std::ptr::null_mut()); + test!(miss: add(2, 3), std::ptr::null_mut()); +} + +/// Test mutable tracking in a local context. +#[test] +#[serial] +#[rustfmt::skip] +fn test_mutable_local() { + #[comemo::memoize(local)] + fn dump<'local>(mut sink: TrackedMut<'local, Emitter>) { + sink.emit("a"); + sink.emit("b"); + let c = sink.len_or_ten().to_string(); + sink.emit(&c); + } + + let mut emitter = Emitter(vec![]); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + assert_eq!(emitter.0, [ + "a", "b", "2", + "a", "b", "5", + "a", "b", "8", + "a", "b", "10", + "a", "b", "10", + "a", "b", "10", + ]) +} + /// A non-copy struct that is passed by value to a tracked method. #[derive(Clone, PartialEq, Hash)] struct Heavy(String);