From ae82700cd61d774d80d1fadedc40d7ae6181ffa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Mon, 4 Dec 2023 18:22:06 +0100 Subject: [PATCH 01/28] Working on par comemo --- Cargo.toml | 3 + macros/src/memoize.rs | 37 +++- macros/src/track.rs | 7 +- src/cache.rs | 410 +++++++++++++++++++++++++++--------------- src/lib.rs | 8 +- tests/tests.rs | 19 -- 6 files changed, 315 insertions(+), 169 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2bad481..d0a3117 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,4 +12,7 @@ keywords = ["incremental", "memoization", "tracked", "constraints"] [dependencies] comemo-macros = { version = "0.3.1", path = "macros" } +indexmap = "2.1.0" +once_cell = "1.18.0" +rustc-hash = "1.1.0" siphasher = "1" diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index 0e24f82..00044ef 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -23,7 +23,7 @@ struct Function { /// An argument to a memoized function. enum Argument { Receiver(syn::Token![self]), - Ident(Option, syn::Ident), + Ident(Box, Option, syn::Ident), } /// Preprocess and validate a function. @@ -71,7 +71,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result { bail!(typed.ty, "memoized functions cannot have mutable parameters") } - Argument::Ident(mutability.clone(), ident.clone()) + Argument::Ident(typed.ty.clone(), mutability.clone(), ident.clone()) } }) } @@ -82,7 +82,7 @@ fn process(function: &Function) -> Result { let bounds = function.args.iter().map(|arg| { let val = match arg { Argument::Receiver(token) => quote! { #token }, - Argument::Ident(_, ident) => quote! { #ident }, + Argument::Ident(_, _, ident) => quote! { #ident }, }; quote_spanned! { function.item.span() => ::comemo::internal::assert_hashable_or_trackable(&#val); @@ -94,14 +94,20 @@ fn process(function: &Function) -> Result { Argument::Receiver(token) => quote! { ::comemo::internal::hash(&#token) }, - Argument::Ident(_, ident) => quote! 
{ #ident }, + Argument::Ident(_, _, ident) => quote! { #ident }, }); let arg_tuple = quote! { (#(#args,)*) }; + let arg_tys = function.args.iter().map(|arg| match arg { + Argument::Receiver(_) => quote! { () }, + Argument::Ident(ty, _, _) => quote! { #ty }, + }); + let arg_ty_tuple = quote! { (#(#arg_tys,)*) }; + // Construct a tuple for all parameters. let params = function.args.iter().map(|arg| match arg { Argument::Receiver(_) => quote! { _ }, - Argument::Ident(mutability, ident) => quote! { #mutability #ident }, + Argument::Ident(_, mutability, ident) => quote! { #mutability #ident }, }); let param_tuple = quote! { (#(#params,)*) }; @@ -120,12 +126,31 @@ fn process(function: &Function) -> Result { let unique = quote! { __ComemoUnique }; wrapped.block = parse_quote! { { + static __CACHE: ::comemo::internal::Lazy< + ::std::sync::RwLock< + ::comemo::internal::Cache< + <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, + #output, + > + > + > = + ::comemo::internal::Lazy::new( + || { + ::comemo::internal::register_cache(evict); + ::std::sync::RwLock::new(::comemo::internal::Cache::new()) + } + ); + + fn evict(max_age: usize) { + __CACHE.write().unwrap().evict(max_age); + } + struct #unique; #(#bounds;)* ::comemo::internal::memoized( - ::core::any::TypeId::of::<#unique>(), ::comemo::internal::Args(#arg_tuple), &::core::default::Default::default(), + &*__CACHE, #closure, ) } }; diff --git a/macros/src/track.rs b/macros/src/track.rs index 60e80ec..cefa926 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -244,11 +244,16 @@ fn create( .map(|m| create_wrapper(m, false)); let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true)); + let constraint = if methods.iter().all(|m| !m.mutable) { + quote! { ImmutableConstraint } + } else { + quote! { Constraint } + }; Ok(quote! 
{ impl #impl_params ::comemo::Track for #ty #where_clause {} impl #impl_params ::comemo::Validate for #ty #where_clause { - type Constraint = ::comemo::internal::Constraint<__ComemoCall>; + type Constraint = ::comemo::internal::#constraint<__ComemoCall>; #[inline] fn validate(&self, constraint: &Self::Constraint) -> bool { diff --git a/src/cache.rs b/src/cache.rs index df7e702..058787f 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,31 +1,41 @@ -use std::any::{Any, TypeId}; -use std::cell::{Cell, RefCell}; +use std::borrow::Cow; +use std::cell::Cell; use std::collections::HashMap; -use std::hash::Hash; +use std::hash::{Hash, BuildHasherDefault}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Mutex, RwLock}; +use once_cell::sync::Lazy; +use rustc_hash::{FxHashMap, FxHasher}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; -thread_local! { - /// The global, dynamic cache shared by all memoized functions. - static CACHE: RefCell = RefCell::new(Cache::default()); +type FxIndexMap = indexmap::IndexMap>; + +static CACHES: RwLock> = RwLock::new(Vec::new()); - /// The global ID counter for tracked values. Each tracked value gets a - /// unqiue ID based on which its validations are cached in the accelerator. - /// IDs may only be reused upon eviction of the accelerator. - static ID: Cell = const { Cell::new(0) }; +static ACCELERATOR: Lazy>> = + Lazy::new(|| Mutex::new(FxHashMap::default())); + +pub fn register_cache(fun: fn(usize)) { + CACHES.write().unwrap().push(fun); +} - /// The global, dynamic accelerator shared by all cached values. - static ACCELERATOR: RefCell> - = RefCell::new(HashMap::default()); +thread_local! { + static LAST_WAS_HIT: Cell = const { Cell::new(false) }; } +/// The global ID counter for tracked values. Each tracked value gets a +/// unqiue ID based on which its validations are cached in the accelerator. +/// IDs may only be reused upon eviction of the accelerator. 
+static ID: AtomicUsize = AtomicUsize::new(0); + /// Execute a function or use a cached result for it. pub fn memoized<'c, In, Out, F>( - id: TypeId, mut input: In, constraint: &'c In::Constraint, + cache: &RwLock>, func: F, ) -> Out where @@ -33,61 +43,54 @@ where Out: Clone + 'static, F: FnOnce(In::Tracked<'c>) -> Out, { - CACHE.with(|cache| { - // Compute the hash of the input's key part. - let key = { - let mut state = SipHasher13::new(); - input.key(&mut state); - let hash = state.finish128().as_u128(); - (id, hash) - }; - - // Check if there is a cached output. - let mut borrow = cache.borrow_mut(); - if let Some(constrained) = borrow.lookup::(key, &input) { - // Replay the mutations. - input.replay(&constrained.constraint); - - // Add the cached constraints to the outer ones. - input.retrack(constraint).1.join(&constrained.constraint); - - let value = constrained.output.clone(); - borrow.last_was_hit = true; - return value; - } + // Compute the hash of the input's key part. + let key = { + let mut state = SipHasher13::new(); + input.key(&mut state); + state.finish128().as_u128() + }; + + // Check if there is a cached output. + let mut borrow = cache.write().unwrap(); + if let Some(constrained) = borrow.lookup::(key, &input) { + // Replay the mutations. + input.replay(&constrained.constraint); + + // Add the cached constraints to the outer ones. + input.retrack(constraint).1.join(&constrained.constraint); + + let value = constrained.output.clone(); + LAST_WAS_HIT.with(|cell| cell.set(true)); + return value; + } - // Release the borrow so that nested memoized calls can access the - // cache without panicking. - drop(borrow); + // Release the borrow so that nested memoized calls can access the + // cache without panicking. + drop(borrow); - // Execute the function with the new constraints hooked in. - let (input, outer) = input.retrack(constraint); - let output = func(input); + // Execute the function with the new constraints hooked in. 
+ let (input, outer) = input.retrack(constraint); + let output = func(input); - // Add the new constraints to the outer ones. - outer.join(constraint); + // Add the new constraints to the outer ones. + outer.join(constraint); - // Insert the result into the cache. - borrow = cache.borrow_mut(); - borrow.insert::(key, constraint.take(), output.clone()); - borrow.last_was_hit = false; + // Insert the result into the cache. + borrow = cache.write().unwrap(); + borrow.insert::(key, constraint.take(), output.clone()); + LAST_WAS_HIT.with(|cell| cell.set(false)); - output - }) + output } /// Whether the last call was a hit. pub fn last_was_hit() -> bool { - CACHE.with(|cache| cache.borrow().last_was_hit) + LAST_WAS_HIT.with(|cell| cell.get()) } /// Get the next ID. pub fn id() -> usize { - ID.with(|cell| { - let current = cell.get(); - cell.set(current.wrapping_add(1)); - current - }) + ID.fetch_add(1, Ordering::SeqCst) } /// Evict the cache. @@ -100,69 +103,71 @@ pub fn id() -> usize { /// Comemo's cache is thread-local, meaning that this only evicts this thread's /// cache. pub fn evict(max_age: usize) { - CACHE.with(|cache| { - let mut cache = cache.borrow_mut(); - cache.map.retain(|_, entries| { + CACHES.read().unwrap().iter().for_each(|fun| fun(max_age)); + ACCELERATOR.lock().unwrap().clear(); +} + +/// The global cache. +pub struct Cache { + /// Maps from hashes to memoized results. + entries: HashMap>>, +} + +impl Default for Cache { + fn default() -> Self { + Self { entries: HashMap::new() } + } +} + +impl Cache { + pub fn new() -> Self { + Self::default() + } + + pub fn evict(&mut self, max_age: usize) { + self.entries.retain(|_, entries| { entries.retain_mut(|entry| { entry.age += 1; entry.age <= max_age }); !entries.is_empty() }); - }); - ACCELERATOR.with(|accelerator| accelerator.borrow_mut().clear()); -} - -/// The global cache. -#[derive(Default)] -struct Cache { - /// Maps from function IDs + hashes to memoized results. 
- map: HashMap<(TypeId, u128), Vec>, - /// Whether the last call was a hit. - last_was_hit: bool, -} + } -impl Cache { /// Look for a matching entry in the cache. - fn lookup( + fn lookup( &mut self, - key: (TypeId, u128), + key: u128, input: &In, ) -> Option<&Constrained> where - In: Input, - Out: Clone + 'static, + In: Input, { - self.map + self.entries .get_mut(&key)? .iter_mut() .rev() - .find_map(|entry| entry.lookup::(input)) + .find_map(|entry| entry.lookup::(input)) } /// Insert an entry into the cache. - fn insert( - &mut self, - key: (TypeId, u128), - constraint: In::Constraint, - output: Out, - ) where - In: Input, - Out: 'static, + fn insert(&mut self, key: u128, constraint: In::Constraint, output: Out) + where + In: Input, { - self.map + self.entries .entry(key) .or_default() - .push(CacheEntry::new::(constraint, output)); + .push(CacheEntry::new::(constraint, output)); } } /// A memoized result. -struct CacheEntry { +struct CacheEntry { /// The memoized function's constrained output. /// /// This is of type `Constrained`. - constrained: Box, + constrained: Constrained, /// How many evictions have passed since the entry has been last used. age: usize, } @@ -175,38 +180,52 @@ struct Constrained { output: T, } -impl CacheEntry { +impl CacheEntry { /// Create a new entry. - fn new(constraint: In::Constraint, output: Out) -> Self + fn new(constraint: In::Constraint, output: Out) -> Self where - In: Input, - Out: 'static, + In: Input, { Self { - constrained: Box::new(Constrained { constraint, output }), + constrained: Constrained { constraint, output }, age: 0, } } /// Return the entry's output if it is valid for the given input. 
- fn lookup(&mut self, input: &In) -> Option<&Constrained> + fn lookup(&mut self, input: &In) -> Option<&Constrained> where - In: Input, - Out: Clone + 'static, + In: Input, { - let constrained: &Constrained = - self.constrained.downcast_ref().expect("wrong entry type"); - - input.validate(&constrained.constraint).then(|| { + input.validate(&self.constrained.constraint).then(|| { self.age = 0; - constrained + &self.constrained }) } } /// Defines a constraint for a tracked type. -#[derive(Clone)] -pub struct Constraint(RefCell>>); +pub struct Constraint(RwLock>); + +struct Inner { + calls: Vec>, + immutable: FxHashMap, +} + +impl Clone for Constraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().unwrap().clone())) + } +} + +impl Clone for Inner { + fn clone(&self) -> Self { + Self { + calls: self.calls.clone(), + immutable: self.immutable.clone(), + } + } +} /// A call entry. #[derive(Clone)] @@ -217,7 +236,34 @@ struct Call { mutable: bool, } -impl Constraint { +impl Inner { + /// Enter a constraint for a call to an immutable function. + #[inline] + fn push_inner(&mut self, call: Cow>) { + if !call.mutable { + if let Some(_prev) = self.immutable.get(&call.both) { + #[cfg(debug_assertions)] + { + let prev = &self.calls[*_prev]; + if prev.args == call.args { + check(prev.ret, call.ret); + } + } + + return; + } + } + + if call.mutable { + self.immutable.clear(); + } else { + self.immutable.insert(call.both, self.calls.len()); + } + self.calls.push(call.into_owned()); + } +} + +impl Constraint { /// Create empty constraints. pub fn new() -> Self { Self::default() @@ -227,32 +273,100 @@ impl Constraint { #[inline] pub fn push(&self, args: T, ret: u128, mutable: bool) { let both = hash(&(&args, ret)); - self.push_inner(Call { args, ret, both, mutable }); + self.0.write().unwrap().push_inner(Cow::Owned(Call { args, ret, both, mutable })); } - /// Enter a constraint for a call to an immutable function. 
+ /// Whether the method satisfies as all input-output pairs. #[inline] - fn push_inner(&self, call: Call) { - let mut calls = self.0.borrow_mut(); + pub fn validate(&self, mut f: F) -> bool + where + F: FnMut(&T) -> u128, + { + self.0 + .read() + .unwrap() + .calls + .iter() + .all(|entry| f(&entry.args) == entry.ret) + } - if !call.mutable { - for prev in calls.iter().rev() { - if prev.mutable { - break; - } + /// Whether the method satisfies as all input-output pairs. + #[inline] + pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + where + F: FnMut(&T) -> u128, + { + let inner = self.0.read().unwrap(); + let mut map = ACCELERATOR.lock().unwrap(); + inner.calls.iter().all(|entry| { + *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) == entry.ret + }) + } - #[cfg(debug_assertions)] - if prev.args == call.args { - check(prev.ret, call.ret); - } + /// Replay all input-output pairs. + #[inline] + pub fn replay(&self, mut f: F) + where + F: FnMut(&T), + { + for entry in self.0.read().unwrap().calls.iter() { + if entry.mutable { + f(&entry.args); + } + } + } +} - if prev.both == call.both { - return; - } +impl Default for Constraint { + fn default() -> Self { + Self(RwLock::new(Inner { calls: Vec::new(), immutable: FxHashMap::default() })) + } +} + +impl Default for Inner { + fn default() -> Self { + Self { calls: Vec::new(), immutable: FxHashMap::default() } + } +} + +/// Defines a constraint for a tracked type. +pub struct ImmutableConstraint(RwLock>>); + +impl Clone for ImmutableConstraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().unwrap().clone())) + } +} + +impl ImmutableConstraint { + /// Create empty constraints. + pub fn new() -> Self { + Self::default() + } + + /// Enter a constraint for a call to an immutable function. 
+ #[inline] + pub fn push(&self, args: T, ret: u128, mutable: bool) { + let both = hash(&(&args, ret)); + self.push_inner(Cow::Owned(Call { args, ret, both, mutable })); + } + + /// Enter a constraint for a call to an immutable function. + #[inline] + fn push_inner(&self, call: Cow>) { + let mut calls = self.0.write().unwrap(); + debug_assert!(!call.mutable); + + if let Some(_prev) = calls.get(&call.both) { + #[cfg(debug_assertions)] + if _prev.args == call.args { + check(_prev.ret, call.ret); } + + return; } - calls.push(call); + calls.insert(call.both, call.into_owned()); } /// Whether the method satisfies as all input-output pairs. @@ -261,7 +375,11 @@ impl Constraint { where F: FnMut(&T) -> u128, { - self.0.borrow().iter().all(|entry| f(&entry.args) == entry.ret) + self.0 + .read() + .unwrap() + .values() + .all(|entry| f(&entry.args) == entry.ret) } /// Whether the method satisfies as all input-output pairs. @@ -270,33 +388,28 @@ impl Constraint { where F: FnMut(&T) -> u128, { - let calls = self.0.borrow(); - ACCELERATOR.with(|accelerator| { - let mut map = accelerator.borrow_mut(); - calls.iter().all(|entry| { - *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) - == entry.ret - }) + let calls = self.0.read().unwrap(); + let mut map = ACCELERATOR.lock().unwrap(); + calls.values().all(|entry| { + *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) == entry.ret }) } /// Replay all input-output pairs. 
#[inline] - pub fn replay(&self, mut f: F) + pub fn replay(&self, _: F) where F: FnMut(&T), { - for entry in self.0.borrow().iter() { - if entry.mutable { - f(&entry.args); - } + for entry in self.0.read().unwrap().values() { + debug_assert!(!entry.mutable); } } } -impl Default for Constraint { +impl Default for ImmutableConstraint { fn default() -> Self { - Self(RefCell::new(vec![])) + Self(RwLock::new(FxIndexMap::with_hasher(Default::default()))) } } @@ -326,14 +439,29 @@ impl Join for Option<&T> { impl Join for Constraint { #[inline] fn join(&self, inner: &Self) { - for call in inner.0.borrow().iter() { - self.push_inner(call.clone()); + let mut this = self.0.write().unwrap(); + for call in inner.0.read().unwrap().calls.iter() { + this.push_inner(Cow::Borrowed(call)); + } + } + + #[inline] + fn take(&self) -> Self { + Self(RwLock::new(std::mem::take(&mut *self.0.write().unwrap()))) + } +} + +impl Join for ImmutableConstraint { + #[inline] + fn join(&self, inner: &Self) { + for call in inner.0.read().unwrap().values() { + self.push_inner(Cow::Borrowed(call)); } } #[inline] fn take(&self) -> Self { - Self(RefCell::new(std::mem::take(&mut *self.0.borrow_mut()))) + Self(RwLock::new(std::mem::take(&mut *self.0.write().unwrap()))) } } diff --git a/src/lib.rs b/src/lib.rs index 9c5b77b..16968ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,7 +95,11 @@ pub use comemo_macros::{memoize, track}; /// These are implementation details. Do not rely on them! 
#[doc(hidden)] pub mod internal { - pub use crate::cache::{hash, last_was_hit, memoized, Constraint}; - pub use crate::input::{assert_hashable_or_trackable, Args}; + pub use crate::cache::{ + hash, last_was_hit, memoized, register_cache, Cache, Constraint, + ImmutableConstraint, + }; + pub use crate::input::{assert_hashable_or_trackable, Args, Input}; pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces}; + pub use once_cell::sync::Lazy; } diff --git a/tests/tests.rs b/tests/tests.rs index 9674272..6463f91 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -212,27 +212,12 @@ fn test_kinds() { } } - #[memoize] - fn generic(tester: Tracky, name: T) -> String - where - T: AsRef + Hash, - { - tester.double_ref(name.as_ref()).to_string() - } - - #[memoize] - fn ignorant(tester: Tracky, name: impl AsRef + Hash) -> String { - tester.arg_ref(name.as_ref()).to_string() - } - let mut tester = Tester { data: "Hi".to_string() }; let tracky = tester.track(); test!(miss: selfie(tracky), "Hi"); test!(miss: unconditional(tracky), "Short"); test!(hit: unconditional(tracky), "Short"); - test!(miss: generic(tracky, "World"), "World"); - test!(miss: ignorant(tracky, "Ignorant"), "Ignorant"); test!(hit: selfie(tracky), "Hi"); tester.data.push('!'); @@ -240,15 +225,11 @@ fn test_kinds() { let tracky = tester.track(); test!(miss: selfie(tracky), "Hi!"); test!(miss: unconditional(tracky), "Short"); - test!(hit: generic(tracky, "World"), "World"); - test!(hit: ignorant(tracky, "Ignorant"), "Ignorant"); tester.data.push_str(" Let's go."); let tracky = tester.track(); test!(miss: unconditional(tracky), "Long"); - test!(miss: generic(tracky, "World"), "Hi! Let's go."); - test!(hit: ignorant(tracky, "Ignorant"), "Ignorant"); } /// Test with type alias. 
From 852c933d041cbb9ecc92a7150cb7117088246a96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Mon, 4 Dec 2023 18:23:57 +0100 Subject: [PATCH 02/28] fmt --- src/cache.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 058787f..0fcfcc8 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use std::cell::Cell; use std::collections::HashMap; -use std::hash::{Hash, BuildHasherDefault}; +use std::hash::{BuildHasherDefault, Hash}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Mutex, RwLock}; @@ -273,7 +273,10 @@ impl Constraint { #[inline] pub fn push(&self, args: T, ret: u128, mutable: bool) { let both = hash(&(&args, ret)); - self.0.write().unwrap().push_inner(Cow::Owned(Call { args, ret, both, mutable })); + self.0 + .write() + .unwrap() + .push_inner(Cow::Owned(Call { args, ret, both, mutable })); } /// Whether the method satisfies as all input-output pairs. 
From 500508bea8df33767583c74f28f5e0c91a06bbe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 20:11:45 +0100 Subject: [PATCH 03/28] Cleanup & improvements --- Cargo.toml | 7 +- macros/src/track.rs | 3 +- src/cache.rs | 155 +++++++++++++++++++++++--------------------- src/lib.rs | 7 +- 4 files changed, 92 insertions(+), 80 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d0a3117..ab21719 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,12 @@ license = "MIT OR Apache-2.0" categories = ["caching"] keywords = ["incremental", "memoization", "tracked", "constraints"] +[features] +default = [ "last_was_hit" ] +last_was_hit = [] + [dependencies] comemo-macros = { version = "0.3.1", path = "macros" } -indexmap = "2.1.0" +hashbrown = "0.14.3" once_cell = "1.18.0" -rustc-hash = "1.1.0" siphasher = "1" diff --git a/macros/src/track.rs b/macros/src/track.rs index cefa926..5ca29ed 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -229,8 +229,9 @@ fn create( }; // Prepare replying. + let immutable = methods.iter().all(|m| !m.mutable); let replays = methods.iter().map(create_replay); - let replay = methods.iter().any(|m| m.mutable).then(|| { + let replay = (!immutable).then(|| { quote! { constraint.replay(|call| match &call.0 { #(#replays,)* }); } diff --git a/src/cache.rs b/src/cache.rs index 0fcfcc8..228f73f 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,29 +1,30 @@ use std::borrow::Cow; -use std::cell::Cell; -use std::collections::HashMap; -use std::hash::{BuildHasherDefault, Hash}; +use std::hash::Hash; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Mutex, RwLock}; +use hashbrown::HashMap; use once_cell::sync::Lazy; -use rustc_hash::{FxHashMap, FxHasher}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; -type FxIndexMap = indexmap::IndexMap>; - +/// The global list of caches. 
static CACHES: RwLock> = RwLock::new(Vec::new()); -static ACCELERATOR: Lazy>> = - Lazy::new(|| Mutex::new(FxHashMap::default())); +/// The global accelerator. +static ACCELERATOR: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::default())); +/// Register a cache in the global list. pub fn register_cache(fun: fn(usize)) { CACHES.write().unwrap().push(fun); } +#[cfg(feature = "last_was_hit")] thread_local! { - static LAST_WAS_HIT: Cell = const { Cell::new(false) }; + /// Whether the last call was a hit. + static LAST_WAS_HIT: std::cell::Cell = const { std::cell::Cell::new(false) }; } /// The global ID counter for tracked values. Each tracked value gets a @@ -52,16 +53,16 @@ where // Check if there is a cached output. let mut borrow = cache.write().unwrap(); - if let Some(constrained) = borrow.lookup::(key, &input) { + if let Some((constrained, value)) = borrow.lookup::(key, &input) { // Replay the mutations. - input.replay(&constrained.constraint); + input.replay(constrained); // Add the cached constraints to the outer ones. - input.retrack(constraint).1.join(&constrained.constraint); + input.retrack(constraint).1.join(constrained); - let value = constrained.output.clone(); + #[cfg(feature = "last_was_hit")] LAST_WAS_HIT.with(|cell| cell.set(true)); - return value; + return value.clone(); } // Release the borrow so that nested memoized calls can access the @@ -78,12 +79,14 @@ where // Insert the result into the cache. borrow = cache.write().unwrap(); borrow.insert::(key, constraint.take(), output.clone()); + #[cfg(feature = "last_was_hit")] LAST_WAS_HIT.with(|cell| cell.set(false)); output } /// Whether the last call was a hit. +#[cfg(feature = "last_was_hit")] pub fn last_was_hit() -> bool { LAST_WAS_HIT.with(|cell| cell.get()) } @@ -135,11 +138,7 @@ impl Cache { } /// Look for a matching entry in the cache. 
- fn lookup( - &mut self, - key: u128, - input: &In, - ) -> Option<&Constrained> + fn lookup(&mut self, key: u128, input: &In) -> Option<(&In::Constraint, &Out)> where In: Input, { @@ -164,20 +163,12 @@ impl Cache { /// A memoized result. struct CacheEntry { - /// The memoized function's constrained output. - /// - /// This is of type `Constrained`. - constrained: Constrained, - /// How many evictions have passed since the entry has been last used. - age: usize, -} - -/// A value with a constraint. -struct Constrained { - /// The constraint which must be fulfilled for the output to be used. + /// The memoized function's constraint. constraint: C, /// The memoized function's output. - output: T, + output: Out, + /// How many evictions have passed since the entry has been last used. + age: usize, } impl CacheEntry { @@ -186,30 +177,37 @@ impl CacheEntry { where In: Input, { - Self { - constrained: Constrained { constraint, output }, - age: 0, - } + Self { constraint, output, age: 0 } } /// Return the entry's output if it is valid for the given input. - fn lookup(&mut self, input: &In) -> Option<&Constrained> + fn lookup(&mut self, input: &In) -> Option<(&In::Constraint, &Out)> where In: Input, { - input.validate(&self.constrained.constraint).then(|| { + input.validate(&self.constraint).then(|| { self.age = 0; - &self.constrained + (&self.constraint, &self.output) }) } } +/// A call entry. +#[derive(Clone)] +struct Call { + args: T, + args_hash: u128, + ret: u128, + both: u128, + mutable: bool, +} + /// Defines a constraint for a tracked type. pub struct Constraint(RwLock>); struct Inner { calls: Vec>, - immutable: FxHashMap, + immutable: HashMap, } impl Clone for Constraint { @@ -227,21 +225,14 @@ impl Clone for Inner { } } -/// A call entry. -#[derive(Clone)] -struct Call { - args: T, - ret: u128, - both: u128, - mutable: bool, -} - impl Inner { /// Enter a constraint for a call to an immutable function. 
#[inline] fn push_inner(&mut self, call: Cow>) { + // If the call is not mutable check whether we already have a call + // with the same arguments and return value. if !call.mutable { - if let Some(_prev) = self.immutable.get(&call.both) { + if let Some(_prev) = self.immutable.get(&call.args_hash) { #[cfg(debug_assertions)] { let prev = &self.calls[*_prev]; @@ -255,10 +246,14 @@ impl Inner { } if call.mutable { + // If the call is mutable, clear all immutable calls. self.immutable.clear(); } else { - self.immutable.insert(call.both, self.calls.len()); + // Otherwise, insert the call into the immutable map. + self.immutable.insert(call.args_hash, self.calls.len()); } + + // Insert the call into the call list. self.calls.push(call.into_owned()); } } @@ -272,11 +267,15 @@ impl Constraint { /// Enter a constraint for a call to an immutable function. #[inline] pub fn push(&self, args: T, ret: u128, mutable: bool) { - let both = hash(&(&args, ret)); - self.0 - .write() - .unwrap() - .push_inner(Cow::Owned(Call { args, ret, both, mutable })); + let args_hash = hash(&args); + let both = hash(&(args_hash, ret)); + self.0.write().unwrap().push_inner(Cow::Owned(Call { + args, + args_hash, + ret, + both, + mutable, + })); } /// Whether the method satisfies as all input-output pairs. 
@@ -312,34 +311,32 @@ impl Constraint { where F: FnMut(&T), { - for entry in self.0.read().unwrap().calls.iter() { - if entry.mutable { - f(&entry.args); - } - } + self.0 + .read() + .unwrap() + .calls + .iter() + .filter(|call| call.mutable) + .for_each(|call| { + f(&call.args); + }); } } impl Default for Constraint { fn default() -> Self { - Self(RwLock::new(Inner { calls: Vec::new(), immutable: FxHashMap::default() })) + Self(RwLock::new(Inner { calls: Vec::new(), immutable: HashMap::default() })) } } impl Default for Inner { fn default() -> Self { - Self { calls: Vec::new(), immutable: FxHashMap::default() } + Self { calls: Vec::new(), immutable: HashMap::default() } } } /// Defines a constraint for a tracked type. -pub struct ImmutableConstraint(RwLock>>); - -impl Clone for ImmutableConstraint { - fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().unwrap().clone())) - } -} +pub struct ImmutableConstraint(RwLock>>); impl ImmutableConstraint { /// Create empty constraints. @@ -350,8 +347,9 @@ impl ImmutableConstraint { /// Enter a constraint for a call to an immutable function. #[inline] pub fn push(&self, args: T, ret: u128, mutable: bool) { - let both = hash(&(&args, ret)); - self.push_inner(Cow::Owned(Call { args, ret, both, mutable })); + let args_hash = hash(&args); + let both = hash(&(args_hash, ret)); + self.push_inner(Cow::Owned(Call { args, args_hash, ret, both, mutable })); } /// Enter a constraint for a call to an immutable function. @@ -360,7 +358,7 @@ impl ImmutableConstraint { let mut calls = self.0.write().unwrap(); debug_assert!(!call.mutable); - if let Some(_prev) = calls.get(&call.both) { + if let Some(_prev) = calls.get(&call.args_hash) { #[cfg(debug_assertions)] if _prev.args == call.args { check(_prev.ret, call.ret); @@ -369,7 +367,7 @@ impl ImmutableConstraint { return; } - calls.insert(call.both, call.into_owned()); + calls.insert(call.args_hash, call.into_owned()); } /// Whether the method satisfies as all input-output pairs. 
@@ -404,15 +402,22 @@ impl ImmutableConstraint { where F: FnMut(&T), { + #[cfg(debug_assertions)] for entry in self.0.read().unwrap().values() { - debug_assert!(!entry.mutable); + assert!(!entry.mutable); } } } +impl Clone for ImmutableConstraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().unwrap().clone())) + } +} + impl Default for ImmutableConstraint { fn default() -> Self { - Self(RwLock::new(FxIndexMap::with_hasher(Default::default()))) + Self(RwLock::new(HashMap::with_hasher(Default::default()))) } } diff --git a/src/lib.rs b/src/lib.rs index 16968ff..71e7c80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,10 +96,13 @@ pub use comemo_macros::{memoize, track}; #[doc(hidden)] pub mod internal { pub use crate::cache::{ - hash, last_was_hit, memoized, register_cache, Cache, Constraint, - ImmutableConstraint, + hash, memoized, register_cache, Cache, Constraint, ImmutableConstraint, }; + pub use crate::input::{assert_hashable_or_trackable, Args, Input}; pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces}; pub use once_cell::sync::Lazy; + + #[cfg(feature = "last_was_hit")] + pub use crate::cache::last_was_hit; } From e5129d0879615eb7ecfe43e9587506e32ea47f93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 20:12:04 +0100 Subject: [PATCH 04/28] Fix comment --- src/cache.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cache.rs b/src/cache.rs index 228f73f..3c6989f 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -66,7 +66,7 @@ where } // Release the borrow so that nested memoized calls can access the - // cache without panicking. + // cache without dead locking. drop(borrow); // Execute the function with the new constraints hooked in. 
From 3fbec115de38c6c9a6fa292adebc6b90f3cf1f6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 20:14:55 +0100 Subject: [PATCH 05/28] Added comments --- src/cache.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 3c6989f..efb5c26 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -123,10 +123,12 @@ impl Default for Cache { } impl Cache { + /// Create an empty cache. pub fn new() -> Self { Self::default() } + /// Evict all entries whose age is larger than or equal to `max_age`. pub fn evict(&mut self, max_age: usize) { self.entries.retain(|_, entries| { entries.retain_mut(|entry| { @@ -206,7 +208,13 @@ struct Call { pub struct Constraint(RwLock>); struct Inner { + /// The list of calls. + /// + /// Order matters here, as those are mutable & immutable calls. calls: Vec>, + /// The hash of the arguments and index of the call. + /// + /// Order does not matter here, as those are immutable calls. immutable: HashMap, } @@ -236,9 +244,7 @@ impl Inner { #[cfg(debug_assertions)] { let prev = &self.calls[*_prev]; - if prev.args == call.args { - check(prev.ret, call.ret); - } + check(prev.ret, call.ret); } return; From 2e8eb8f78f43d2c1845fa4ddc956240e228cf9a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 20:19:55 +0100 Subject: [PATCH 06/28] Further test & debugging --- src/cache.rs | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index efb5c26..d6c4690 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -237,15 +237,12 @@ impl Inner { /// Enter a constraint for a call to an immutable function. 
#[inline] fn push_inner(&mut self, call: Cow>) { - // If the call is not mutable check whether we already have a call + // If the call is immutable check whether we already have a call // with the same arguments and return value. if !call.mutable { if let Some(_prev) = self.immutable.get(&call.args_hash) { #[cfg(debug_assertions)] - { - let prev = &self.calls[*_prev]; - check(prev.ret, call.ret); - } + check(&self.calls[*_prev], &call); return; } @@ -366,9 +363,7 @@ impl ImmutableConstraint { if let Some(_prev) = calls.get(&call.args_hash) { #[cfg(debug_assertions)] - if _prev.args == call.args { - check(_prev.ret, call.ret); - } + check(_prev, &call); return; } @@ -491,11 +486,23 @@ pub fn hash(value: &T) -> u128 { #[inline] #[track_caller] #[allow(dead_code)] -fn check(left_hash: u128, right_hash: u128) { - if left_hash != right_hash { +fn check(lhs: &Call, rhs: &Call) { + if lhs.ret != rhs.ret { panic!( "comemo: found conflicting constraints. \ is this tracked function pure?" ) } + + // Additional checks for debugging. 
+ if lhs.args_hash != rhs.args_hash + || lhs.args != rhs.args + || lhs.both != rhs.both + || lhs.mutable != rhs.mutable + { + panic!( + "comemo: found conflicting arguments | + this is a bug in comemo" + ) + } } From 3f9cb30f7c93d2774f07f10ab656e0ebad7144b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 21:09:28 +0100 Subject: [PATCH 07/28] Cleanup & test fixes --- Cargo.toml | 3 +++ src/cache.rs | 12 ++---------- tests/tests.rs | 12 ++++++++++++ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ab21719..cd3bb87 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,6 @@ comemo-macros = { version = "0.3.1", path = "macros" } hashbrown = "0.14.3" once_cell = "1.18.0" siphasher = "1" + +[dev-dependencies] +serial_test = "2.0.0" diff --git a/src/cache.rs b/src/cache.rs index d6c4690..f175571 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -207,6 +207,7 @@ struct Call { /// Defines a constraint for a tracked type. pub struct Constraint(RwLock>); +#[derive(Clone)] struct Inner { /// The list of calls. /// @@ -224,15 +225,6 @@ impl Clone for Constraint { } } -impl Clone for Inner { - fn clone(&self) -> Self { - Self { - calls: self.calls.clone(), - immutable: self.immutable.clone(), - } - } -} - impl Inner { /// Enter a constraint for a call to an immutable function. #[inline] @@ -418,7 +410,7 @@ impl Clone for ImmutableConstraint { impl Default for ImmutableConstraint { fn default() -> Self { - Self(RwLock::new(HashMap::with_hasher(Default::default()))) + Self(RwLock::new(HashMap::default())) } } diff --git a/tests/tests.rs b/tests/tests.rs index 6463f91..5cf3975 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -3,6 +3,7 @@ use std::hash::Hash; use std::path::{Path, PathBuf}; use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate}; +use serial_test::serial; macro_rules! 
test { (miss: $call:expr, $result:expr) => {{ @@ -17,6 +18,7 @@ macro_rules! test { /// Test basic memoization. #[test] +#[serial] fn test_basic() { #[memoize] fn empty() -> String { @@ -71,6 +73,7 @@ fn test_basic() { /// Test the calc language. #[test] +#[serial] fn test_calc() { #[memoize] fn evaluate(script: &str, files: Tracked) -> i32 { @@ -116,6 +119,7 @@ impl Files { /// Test cache eviction. #[test] +#[serial] fn test_evict() { #[memoize] fn null() -> u8 { @@ -141,6 +145,7 @@ fn test_evict() { /// Test tracking a trait object. #[test] +#[serial] fn test_tracked_trait() { #[memoize] fn traity(loader: Tracked, path: &Path) -> Vec { @@ -172,6 +177,7 @@ impl Loader for StaticLoader { /// Test memoized methods. #[test] +#[serial] fn test_memoized_methods() { #[derive(Hash)] struct Taker(String); @@ -197,6 +203,7 @@ fn test_memoized_methods() { /// Test different kinds of arguments. #[test] +#[serial] fn test_kinds() { #[memoize] fn selfie(tester: Tracky) -> String { @@ -277,6 +284,7 @@ impl Empty {} /// Test tracking a type with a lifetime. #[test] +#[serial] fn test_lifetime() { #[comemo::memoize] fn contains_hello(lifeful: Tracked) -> bool { @@ -304,6 +312,7 @@ impl<'a> Lifeful<'a> { /// Test tracking a type with a chain of tracked values. #[test] +#[serial] fn test_chain() { #[comemo::memoize] fn process(chain: Tracked, value: u32) -> bool { @@ -329,6 +338,7 @@ fn test_chain() { /// Test that `Tracked` is covariant over `T`. #[test] +#[serial] #[allow(unused, clippy::needless_lifetimes)] fn test_variance() { fn foo<'a>(_: Tracked<'a, Chain<'a>>) {} @@ -366,6 +376,7 @@ impl<'a> Chain<'a> { /// Test mutable tracking. #[test] +#[serial] #[rustfmt::skip] fn test_mutable() { #[comemo::memoize] @@ -414,6 +425,7 @@ struct Heavy(String); /// Test a tracked method that is impure. #[test] +#[serial] #[cfg(debug_assertions)] #[should_panic( expected = "comemo: found conflicting constraints. is this tracked function pure?" 
From d61b0b5daee879d936f8a1413ff4a225e9dd30d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 23:25:07 +0100 Subject: [PATCH 08/28] Remove `unique` from `memoize` --- macros/src/memoize.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index 00044ef..2329825 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -124,7 +124,6 @@ fn process(function: &Function) -> Result { ident.mutability = None; } - let unique = quote! { __ComemoUnique }; wrapped.block = parse_quote! { { static __CACHE: ::comemo::internal::Lazy< ::std::sync::RwLock< @@ -145,7 +144,6 @@ fn process(function: &Function) -> Result { __CACHE.write().unwrap().evict(max_age); } - struct #unique; #(#bounds;)* ::comemo::internal::memoized( ::comemo::internal::Args(#arg_tuple), From 5dcab10b6d55bc8cffda7eb0d01823a00ff67c2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 23:26:23 +0100 Subject: [PATCH 09/28] Fixed potential `dyn Trait` issue --- macros/src/track.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/macros/src/track.rs b/macros/src/track.rs index 5ca29ed..e8e22bf 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -36,7 +36,7 @@ pub fn expand(item: &syn::Item) -> Result { } let name = &item.ident; - let ty = parse_quote! 
{ dyn #name + Send + Sync + '__comemo_dynamic }; (ty, &item.generics, Some(name.clone())) } _ => bail!(item, "`track` can only be applied to impl blocks and traits"), From 7796af4eaa24d7a087de5e153e5e6b0a634e46f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 23:28:41 +0100 Subject: [PATCH 10/28] Fix dyn test --- tests/tests.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests.rs b/tests/tests.rs index 5cf3975..57e5b1d 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -148,15 +148,15 @@ fn test_evict() { #[serial] fn test_tracked_trait() { #[memoize] - fn traity(loader: Tracked, path: &Path) -> Vec { + fn traity(loader: Tracked, path: &Path) -> Vec { loader.load(path).unwrap() } - fn wrapper(loader: &dyn Loader, path: &Path) -> Vec { + fn wrapper(loader: &(dyn Loader + Send + Sync), path: &Path) -> Vec { traity(loader.track(), path) } - let loader: &dyn Loader = &StaticLoader; + let loader: &(dyn Loader + Send + Sync) = &StaticLoader; test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(miss: traity(loader.track(), Path::new("bye.rs")), [1, 2, 3]); @@ -164,7 +164,7 @@ fn test_tracked_trait() { } #[track] -trait Loader { +trait Loader: Send + Sync { fn load(&self, path: &Path) -> Result, String>; } From 3968f1060f74d2ddd016b7826c9b6aa9b06a6547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 8 Dec 2023 23:37:30 +0100 Subject: [PATCH 11/28] Moved to `parking_lot` (no poison) --- Cargo.toml | 1 + macros/src/memoize.rs | 6 ++--- src/cache.rs | 54 ++++++++++++++++++------------------------- src/lib.rs | 2 ++ 4 files changed, 28 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cd3bb87..a2b596f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ last_was_hit = [] comemo-macros = { 
version = "0.3.1", path = "macros" } hashbrown = "0.14.3" once_cell = "1.18.0" +parking_lot = "0.12.1" siphasher = "1" [dev-dependencies] diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index 2329825..a517d8f 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -126,7 +126,7 @@ fn process(function: &Function) -> Result { wrapped.block = parse_quote! { { static __CACHE: ::comemo::internal::Lazy< - ::std::sync::RwLock< + ::comemo::internal::RwLock< ::comemo::internal::Cache< <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, #output, @@ -136,12 +136,12 @@ fn process(function: &Function) -> Result { ::comemo::internal::Lazy::new( || { ::comemo::internal::register_cache(evict); - ::std::sync::RwLock::new(::comemo::internal::Cache::new()) + ::comemo::internal::RwLock::new(::comemo::internal::Cache::new()) } ); fn evict(max_age: usize) { - __CACHE.write().unwrap().evict(max_age); + __CACHE.write().evict(max_age); } #(#bounds;)* diff --git a/src/cache.rs b/src/cache.rs index f175571..35d5798 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,10 +1,10 @@ use std::borrow::Cow; use std::hash::Hash; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Mutex, RwLock}; use hashbrown::HashMap; use once_cell::sync::Lazy; +use parking_lot::{Mutex, RwLock}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; @@ -18,7 +18,7 @@ static ACCELERATOR: Lazy>> = /// Register a cache in the global list. pub fn register_cache(fun: fn(usize)) { - CACHES.write().unwrap().push(fun); + CACHES.write().push(fun); } #[cfg(feature = "last_was_hit")] @@ -52,7 +52,7 @@ where }; // Check if there is a cached output. - let mut borrow = cache.write().unwrap(); + let mut borrow = cache.write(); if let Some((constrained, value)) = borrow.lookup::(key, &input) { // Replay the mutations. input.replay(constrained); @@ -77,7 +77,7 @@ where outer.join(constraint); // Insert the result into the cache. 
- borrow = cache.write().unwrap(); + borrow = cache.write(); borrow.insert::(key, constraint.take(), output.clone()); #[cfg(feature = "last_was_hit")] LAST_WAS_HIT.with(|cell| cell.set(false)); @@ -106,8 +106,8 @@ pub fn id() -> usize { /// Comemo's cache is thread-local, meaning that this only evicts this thread's /// cache. pub fn evict(max_age: usize) { - CACHES.read().unwrap().iter().for_each(|fun| fun(max_age)); - ACCELERATOR.lock().unwrap().clear(); + CACHES.read().iter().for_each(|fun| fun(max_age)); + ACCELERATOR.lock().clear(); } /// The global cache. @@ -221,7 +221,7 @@ struct Inner { impl Clone for Constraint { fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().unwrap().clone())) + Self(RwLock::new(self.0.read().clone())) } } @@ -264,7 +264,7 @@ impl Constraint { pub fn push(&self, args: T, ret: u128, mutable: bool) { let args_hash = hash(&args); let both = hash(&(args_hash, ret)); - self.0.write().unwrap().push_inner(Cow::Owned(Call { + self.0.write().push_inner(Cow::Owned(Call { args, args_hash, ret, @@ -279,12 +279,7 @@ impl Constraint { where F: FnMut(&T) -> u128, { - self.0 - .read() - .unwrap() - .calls - .iter() - .all(|entry| f(&entry.args) == entry.ret) + self.0.read().calls.iter().all(|entry| f(&entry.args) == entry.ret) } /// Whether the method satisfies as all input-output pairs. @@ -293,8 +288,8 @@ impl Constraint { where F: FnMut(&T) -> u128, { - let inner = self.0.read().unwrap(); - let mut map = ACCELERATOR.lock().unwrap(); + let inner = self.0.read(); + let mut map = ACCELERATOR.lock(); inner.calls.iter().all(|entry| { *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) == entry.ret }) @@ -308,7 +303,6 @@ impl Constraint { { self.0 .read() - .unwrap() .calls .iter() .filter(|call| call.mutable) @@ -350,7 +344,7 @@ impl ImmutableConstraint { /// Enter a constraint for a call to an immutable function. 
#[inline] fn push_inner(&self, call: Cow>) { - let mut calls = self.0.write().unwrap(); + let mut calls = self.0.write(); debug_assert!(!call.mutable); if let Some(_prev) = calls.get(&call.args_hash) { @@ -369,11 +363,7 @@ impl ImmutableConstraint { where F: FnMut(&T) -> u128, { - self.0 - .read() - .unwrap() - .values() - .all(|entry| f(&entry.args) == entry.ret) + self.0.read().values().all(|entry| f(&entry.args) == entry.ret) } /// Whether the method satisfies as all input-output pairs. @@ -382,8 +372,8 @@ impl ImmutableConstraint { where F: FnMut(&T) -> u128, { - let calls = self.0.read().unwrap(); - let mut map = ACCELERATOR.lock().unwrap(); + let calls = self.0.read(); + let mut map = ACCELERATOR.lock(); calls.values().all(|entry| { *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) == entry.ret }) @@ -396,7 +386,7 @@ impl ImmutableConstraint { F: FnMut(&T), { #[cfg(debug_assertions)] - for entry in self.0.read().unwrap().values() { + for entry in self.0.read().values() { assert!(!entry.mutable); } } @@ -404,7 +394,7 @@ impl ImmutableConstraint { impl Clone for ImmutableConstraint { fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().unwrap().clone())) + Self(RwLock::new(self.0.read().clone())) } } @@ -440,29 +430,29 @@ impl Join for Option<&T> { impl Join for Constraint { #[inline] fn join(&self, inner: &Self) { - let mut this = self.0.write().unwrap(); - for call in inner.0.read().unwrap().calls.iter() { + let mut this = self.0.write(); + for call in inner.0.read().calls.iter() { this.push_inner(Cow::Borrowed(call)); } } #[inline] fn take(&self) -> Self { - Self(RwLock::new(std::mem::take(&mut *self.0.write().unwrap()))) + Self(RwLock::new(std::mem::take(&mut *self.0.write()))) } } impl Join for ImmutableConstraint { #[inline] fn join(&self, inner: &Self) { - for call in inner.0.read().unwrap().values() { + for call in inner.0.read().values() { self.push_inner(Cow::Borrowed(call)); } } #[inline] fn take(&self) -> Self { - 
Self(RwLock::new(std::mem::take(&mut *self.0.write().unwrap()))) + Self(RwLock::new(std::mem::take(&mut *self.0.write()))) } } diff --git a/src/lib.rs b/src/lib.rs index 71e7c80..a1dc4cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,6 +95,8 @@ pub use comemo_macros::{memoize, track}; /// These are implementation details. Do not rely on them! #[doc(hidden)] pub mod internal { + pub use parking_lot::RwLock; + pub use crate::cache::{ hash, memoized, register_cache, Cache, Constraint, ImmutableConstraint, }; From c8322f5ca08d39023a01cebb06914394ef7bccf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 11:29:42 +0100 Subject: [PATCH 12/28] Local accelerator (optimization) --- examples/calc.rs | 12 +++---- macros/src/track.rs | 8 ++--- src/cache.rs | 39 ++++++++++------------- src/input.rs | 76 +++++++++++++++++++++++++++++++++++++++++++-- src/lib.rs | 3 +- src/track.rs | 35 ++++++++++++++------- tests/tests.rs | 27 ++++++++-------- 7 files changed, 139 insertions(+), 61 deletions(-) diff --git a/examples/calc.rs b/examples/calc.rs index 2ec6c4c..76c464c 100644 --- a/examples/calc.rs +++ b/examples/calc.rs @@ -16,33 +16,33 @@ fn main() { files.write("gamma.calc", "8 + 3"); // [Miss] The cache is empty. - assert_eq!(evaluate("eval alpha.calc", files.track()), 7); + assert_eq!(evaluate("eval alpha.calc", &files.track()), 7); // [Miss] This is not a top-level hit because this exact string was never // passed to `evaluate`, but this does not compute "2 + 3" again. - assert_eq!(evaluate("eval beta.calc", files.track()), 5); + assert_eq!(evaluate("eval beta.calc", &files.track()), 5); // Modify the gamma file. files.write("gamma.calc", "42"); // [Hit] This is a hit because `gamma.calc` isn't referenced by `alpha.calc`. - assert_eq!(evaluate("eval alpha.calc", files.track()), 7); + assert_eq!(evaluate("eval alpha.calc", &files.track()), 7); // Modify the beta file. 
files.write("beta.calc", "4 + eval gamma.calc"); // [Miss] This is a miss because `beta.calc` changed. - assert_eq!(evaluate("eval alpha.calc", files.track()), 48); + assert_eq!(evaluate("eval alpha.calc", &files.track()), 48); } /// Evaluate a `.calc` script. #[memoize] -fn evaluate(script: &str, files: Tracked) -> i32 { +fn evaluate(script: &str, files: &Tracked) -> i32 { script .split('+') .map(str::trim) .map(|part| match part.strip_prefix("eval ") { - Some(path) => evaluate(&files.read(path), files), + Some(path) => evaluate(&files.read(path), &files), None => part.parse::().unwrap(), }) .sum() diff --git a/macros/src/track.rs b/macros/src/track.rs index e8e22bf..ad1adc7 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -219,9 +219,9 @@ fn create( let validate_with_id = if !methods.is_empty() { quote! { let mut this = #maybe_cloned; - constraint.validate_with_id( + constraint.validate_with_accelerator( |call| match &call.0 { #(#validations,)* }, - id, + accelerator, ) } } else { @@ -262,7 +262,7 @@ fn create( } #[inline] - fn validate_with_id(&self, constraint: &Self::Constraint, id: usize) -> bool { + fn validate_with_accelerator(&self, constraint: &Self::Constraint, accelerator: &::comemo::internal::Accelerator) -> bool { #validate_with_id } @@ -378,7 +378,7 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream { let args = &method.args; let mutable = method.mutable; let to_parts = if !tracked_mut { - quote! { to_parts_ref(self.0) } + quote! { to_parts_ref(&self.0) } } else if !mutable { quote! 
{ to_parts_mut_ref(&self.0) } } else { diff --git a/src/cache.rs b/src/cache.rs index 35d5798..2aaf26f 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,21 +1,17 @@ use std::borrow::Cow; use std::hash::Hash; -use std::sync::atomic::{AtomicUsize, Ordering}; use hashbrown::HashMap; -use once_cell::sync::Lazy; use parking_lot::{Mutex, RwLock}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; +pub type Accelerator = Mutex>; + /// The global list of caches. static CACHES: RwLock> = RwLock::new(Vec::new()); -/// The global accelerator. -static ACCELERATOR: Lazy>> = - Lazy::new(|| Mutex::new(HashMap::default())); - /// Register a cache in the global list. pub fn register_cache(fun: fn(usize)) { CACHES.write().push(fun); @@ -27,11 +23,6 @@ thread_local! { static LAST_WAS_HIT: std::cell::Cell = const { std::cell::Cell::new(false) }; } -/// The global ID counter for tracked values. Each tracked value gets a -/// unqiue ID based on which its validations are cached in the accelerator. -/// IDs may only be reused upon eviction of the accelerator. -static ID: AtomicUsize = AtomicUsize::new(0); - /// Execute a function or use a cached result for it. pub fn memoized<'c, In, Out, F>( mut input: In, @@ -91,11 +82,6 @@ pub fn last_was_hit() -> bool { LAST_WAS_HIT.with(|cell| cell.get()) } -/// Get the next ID. -pub fn id() -> usize { - ID.fetch_add(1, Ordering::SeqCst) -} - /// Evict the cache. /// /// This removes all memoized results from the cache whose age is larger than or @@ -107,7 +93,6 @@ pub fn id() -> usize { /// cache. pub fn evict(max_age: usize) { CACHES.read().iter().for_each(|fun| fun(max_age)); - ACCELERATOR.lock().clear(); } /// The global cache. @@ -284,14 +269,18 @@ impl Constraint { /// Whether the method satisfies as all input-output pairs. 
#[inline] - pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + pub fn validate_with_accelerator( + &self, + mut f: F, + accelerator: &Accelerator, + ) -> bool where F: FnMut(&T) -> u128, { let inner = self.0.read(); - let mut map = ACCELERATOR.lock(); + let mut map = accelerator.lock(); inner.calls.iter().all(|entry| { - *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) == entry.ret + *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret }) } @@ -368,14 +357,18 @@ impl ImmutableConstraint { /// Whether the method satisfies as all input-output pairs. #[inline] - pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + pub fn validate_with_accelerator( + &self, + mut f: F, + accelerator: &Accelerator, + ) -> bool where F: FnMut(&T) -> u128, { let calls = self.0.read(); - let mut map = ACCELERATOR.lock(); + let mut map = accelerator.lock(); calls.values().all(|entry| { - *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) == entry.ret + *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret }) } diff --git a/src/input.rs b/src/input.rs index d4185ea..8c3f8c5 100644 --- a/src/input.rs +++ b/src/input.rs @@ -84,7 +84,7 @@ where #[inline] fn validate(&self, constraint: &Self::Constraint) -> bool { - self.value.validate_with_id(constraint, self.id) + self.value.validate_with_accelerator(constraint, &self.accelerator) } #[inline] @@ -101,7 +101,44 @@ where let tracked = Tracked { value: self.value, constraint: Some(constraint), - id: self.id, + accelerator: self.accelerator, + }; + (tracked, self.constraint) + } +} + +impl<'a: 'b, 'b, T> Input for &'b Tracked<'a, T> +where + T: Track + ?Sized, +{ + // Forward constraint from `Trackable` implementation. 
+ type Constraint = ::Constraint; + type Tracked<'r> = Tracked<'r, T> where Self: 'r; + type Outer = Option<&'a Self::Constraint>; + + #[inline] + fn key(&self, _: &mut H) {} + + #[inline] + fn validate(&self, constraint: &Self::Constraint) -> bool { + self.value.validate_with_accelerator(constraint, &self.accelerator) + } + + #[inline] + fn replay(&mut self, _: &Self::Constraint) {} + + #[inline] + fn retrack<'r>( + self, + constraint: &'r Self::Constraint, + ) -> (Self::Tracked<'r>, Self::Outer) + where + Self: 'r, + { + let tracked = Tracked { + value: self.value, + constraint: Some(constraint), + accelerator: self.accelerator.clone(), }; (tracked, self.constraint) } @@ -142,6 +179,41 @@ where } } +impl<'a: 'b, 'b, T> Input for &'b mut TrackedMut<'a, T> +where + T: Track + ?Sized, +{ + // Forward constraint from `Trackable` implementation. + type Constraint = T::Constraint; + type Tracked<'r> = TrackedMut<'r, T> where Self: 'r; + type Outer = Option<&'a Self::Constraint>; + + #[inline] + fn key(&self, _: &mut H) {} + + #[inline] + fn validate(&self, constraint: &Self::Constraint) -> bool { + self.value.validate(constraint) + } + + #[inline] + fn replay(&mut self, constraint: &Self::Constraint) { + self.value.replay(constraint); + } + + #[inline] + fn retrack<'r>( + self, + constraint: &'r Self::Constraint, + ) -> (Self::Tracked<'r>, Self::Outer) + where + Self: 'r, + { + let tracked = TrackedMut { value: self.value, constraint: Some(constraint) }; + (tracked, self.constraint) + } +} + /// Wrapper for multiple inputs. 
pub struct Args(pub T); diff --git a/src/lib.rs b/src/lib.rs index a1dc4cf..c31a4bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,7 +98,8 @@ pub mod internal { pub use parking_lot::RwLock; pub use crate::cache::{ - hash, memoized, register_cache, Cache, Constraint, ImmutableConstraint, + hash, memoized, register_cache, Accelerator, Cache, Constraint, + ImmutableConstraint, }; pub use crate::input::{assert_hashable_or_trackable, Args, Input}; diff --git a/src/track.rs b/src/track.rs index e707057..7df4a0c 100644 --- a/src/track.rs +++ b/src/track.rs @@ -1,7 +1,8 @@ use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; -use crate::cache::{id, Join}; +use crate::cache::{Accelerator, Join}; /// A trackable type. /// @@ -12,7 +13,11 @@ pub trait Track: Validate + Surfaces { /// Start tracking all accesses to a value. #[inline] fn track(&self) -> Tracked { - Tracked { value: self, constraint: None, id: id() } + Tracked { + value: self, + constraint: None, + accelerator: Arc::new(Accelerator::default()), + } } /// Start tracking all accesses and mutations to a value. @@ -27,7 +32,7 @@ pub trait Track: Validate + Surfaces { Tracked { value: self, constraint: Some(constraint), - id: id(), + accelerator: Arc::new(Accelerator::default()), } } @@ -66,7 +71,11 @@ pub trait Validate { /// equal constraints against the same value. If given the same `id` twice, /// `self` must also be identical, unless [`evict`](crate::evict) has been /// called in between. - fn validate_with_id(&self, constraint: &Self::Constraint, id: usize) -> bool; + fn validate_with_accelerator( + &self, + constraint: &Self::Constraint, + accelerator: &Accelerator, + ) -> bool; /// Replay recorded mutations to the value. fn replay(&mut self, constraint: &Self::Constraint); @@ -151,8 +160,8 @@ where /// Starts out as `None` and is set to a stack-stored constraint in the /// preamble of memoized functions. 
pub(crate) constraint: Option<&'a C>, - /// A unique ID for validation acceleration. - pub(crate) id: usize, + /// A reference to the local accelerator. + pub(crate) accelerator: Arc, } // The type `Tracked` automatically dereferences to T's generated surface @@ -180,15 +189,17 @@ where } } -impl<'a, T> Copy for Tracked<'a, T> where T: Track + ?Sized {} - impl<'a, T> Clone for Tracked<'a, T> where T: Track + ?Sized, { #[inline] fn clone(&self) -> Self { - *self + Self { + value: self.value, + constraint: self.constraint, + accelerator: Arc::clone(&self.accelerator), + } } } @@ -227,7 +238,7 @@ where Tracked { value: this.value, constraint: this.constraint, - id: id(), + accelerator: Arc::new(Accelerator::default()), } } @@ -240,7 +251,7 @@ where Tracked { value: this.value, constraint: this.constraint, - id: id(), + accelerator: Arc::new(Accelerator::default()), } } @@ -288,7 +299,7 @@ where /// Destructure a `Tracked<_>` into its parts. #[inline] -pub fn to_parts_ref(tracked: Tracked) -> (&T, Option<&T::Constraint>) +pub fn to_parts_ref<'a, T>(tracked: &Tracked<'a, T>) -> (&'a T, Option<&'a T::Constraint>) where T: Track + ?Sized, { diff --git a/tests/tests.rs b/tests/tests.rs index 57e5b1d..ffbd87c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -76,12 +76,12 @@ fn test_basic() { #[serial] fn test_calc() { #[memoize] - fn evaluate(script: &str, files: Tracked) -> i32 { + fn evaluate(script: &str, files: &Tracked) -> i32 { script .split('+') .map(str::trim) .map(|part| match part.strip_prefix("eval ") { - Some(path) => evaluate(&files.read(path), files), + Some(path) => evaluate(&files.read(path), &files), None => part.parse::().unwrap(), }) .sum() @@ -91,15 +91,15 @@ fn test_calc() { files.write("alpha.calc", "2 + eval beta.calc"); files.write("beta.calc", "2 + 3"); files.write("gamma.calc", "8 + 3"); - test!(miss: evaluate("eval alpha.calc", files.track()), 7); - test!(miss: evaluate("eval beta.calc", files.track()), 5); + test!(miss: evaluate("eval 
alpha.calc", &files.track()), 7); + test!(miss: evaluate("eval beta.calc", &files.track()), 5); files.write("gamma.calc", "42"); - test!(hit: evaluate("eval alpha.calc", files.track()), 7); + test!(hit: evaluate("eval alpha.calc", &files.track()), 7); files.write("beta.calc", "4 + eval gamma.calc"); - test!(miss: evaluate("eval beta.calc", files.track()), 46); - test!(miss: evaluate("eval alpha.calc", files.track()), 48); + test!(miss: evaluate("eval beta.calc", &files.track()), 46); + test!(miss: evaluate("eval alpha.calc", &files.track()), 48); files.write("gamma.calc", "80"); - test!(miss: evaluate("eval alpha.calc", files.track()), 86); + test!(miss: evaluate("eval alpha.calc", &files.track()), 86); } struct Files(HashMap); @@ -222,15 +222,15 @@ fn test_kinds() { let mut tester = Tester { data: "Hi".to_string() }; let tracky = tester.track(); - test!(miss: selfie(tracky), "Hi"); - test!(miss: unconditional(tracky), "Short"); - test!(hit: unconditional(tracky), "Short"); + test!(miss: selfie(tracky.clone()), "Hi"); + test!(miss: unconditional(tracky.clone()), "Short"); + test!(hit: unconditional(tracky.clone()), "Short"); test!(hit: selfie(tracky), "Hi"); tester.data.push('!'); let tracky = tester.track(); - test!(miss: selfie(tracky), "Hi!"); + test!(miss: selfie(tracky.clone()), "Hi!"); test!(miss: unconditional(tracky), "Short"); tester.data.push_str(" Let's go."); @@ -370,7 +370,8 @@ impl<'a> Chain<'a> { #[track] impl<'a> Chain<'a> { fn contains(&self, value: u32) -> bool { - self.value == value || self.outer.map_or(false, |outer| outer.contains(value)) + self.value == value + || self.outer.as_ref().map_or(false, |outer| outer.contains(value)) } } From 47ca59be1c600baa43843b26dc6554a12fa0c5fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 11:44:46 +0100 Subject: [PATCH 13/28] Atomic age (faster) --- src/cache.rs | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 
deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 2aaf26f..0877467 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,5 +1,6 @@ -use std::borrow::Cow; use std::hash::Hash; +use std::sync::atomic::Ordering; +use std::{borrow::Cow, sync::atomic::AtomicUsize}; use hashbrown::HashMap; use parking_lot::{Mutex, RwLock}; @@ -43,7 +44,7 @@ where }; // Check if there is a cached output. - let mut borrow = cache.write(); + let borrow = cache.read(); if let Some((constrained, value)) = borrow.lookup::(key, &input) { // Replay the mutations. input.replay(constrained); @@ -68,7 +69,7 @@ where outer.join(constraint); // Insert the result into the cache. - borrow = cache.write(); + let mut borrow = cache.write(); borrow.insert::(key, constraint.take(), output.clone()); #[cfg(feature = "last_was_hit")] LAST_WAS_HIT.with(|cell| cell.set(false)); @@ -116,22 +117,22 @@ impl Cache { /// Evict all entries whose age is larger than or equal to `max_age`. pub fn evict(&mut self, max_age: usize) { self.entries.retain(|_, entries| { - entries.retain_mut(|entry| { - entry.age += 1; - entry.age <= max_age + entries.retain(|entry| { + let age = entry.age.fetch_add(1, Ordering::Acquire); + (age + 1) <= max_age }); !entries.is_empty() }); } /// Look for a matching entry in the cache. - fn lookup(&mut self, key: u128, input: &In) -> Option<(&In::Constraint, &Out)> + fn lookup(&self, key: u128, input: &In) -> Option<(&In::Constraint, &Out)> where In: Input, { self.entries - .get_mut(&key)? - .iter_mut() + .get(&key)? + .iter() .rev() .find_map(|entry| entry.lookup::(input)) } @@ -155,7 +156,7 @@ struct CacheEntry { /// The memoized function's output. output: Out, /// How many evictions have passed since the entry has been last used. 
- age: usize, + age: AtomicUsize, } impl CacheEntry { @@ -164,16 +165,16 @@ impl CacheEntry { where In: Input, { - Self { constraint, output, age: 0 } + Self { constraint, output, age: AtomicUsize::new(0) } } /// Return the entry's output if it is valid for the given input. - fn lookup(&mut self, input: &In) -> Option<(&In::Constraint, &Out)> + fn lookup(&self, input: &In) -> Option<(&In::Constraint, &Out)> where In: Input, { input.validate(&self.constraint).then(|| { - self.age = 0; + self.age.store(0, Ordering::Release); (&self.constraint, &self.output) }) } From 67631d7f445ca80911cfe070aa83152f3b642e16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 12:02:18 +0100 Subject: [PATCH 14/28] Made `Tracked` copy again Using a sharded list of accelerators. --- macros/src/track.rs | 6 ++-- src/cache.rs | 82 ++++++++++++++++++++++++++++++++++----------- src/input.rs | 8 ++--- src/track.rs | 33 ++++++------------ 4 files changed, 80 insertions(+), 49 deletions(-) diff --git a/macros/src/track.rs b/macros/src/track.rs index ad1adc7..cea697c 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -219,9 +219,9 @@ fn create( let validate_with_id = if !methods.is_empty() { quote! 
{ let mut this = #maybe_cloned; - constraint.validate_with_accelerator( + constraint.validate_with_id( |call| match &call.0 { #(#validations,)* }, - accelerator, + id, ) } } else { @@ -262,7 +262,7 @@ fn create( } #[inline] - fn validate_with_accelerator(&self, constraint: &Self::Constraint, accelerator: &::comemo::internal::Accelerator) -> bool { + fn validate_with_id(&self, constraint: &Self::Constraint, id: usize) -> bool { #validate_with_id } diff --git a/src/cache.rs b/src/cache.rs index 0877467..53b6282 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,5 +1,6 @@ use std::hash::Hash; use std::sync::atomic::Ordering; +use std::sync::Arc; use std::{borrow::Cow, sync::atomic::AtomicUsize}; use hashbrown::HashMap; @@ -8,16 +9,50 @@ use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; -pub type Accelerator = Mutex>; +pub type Accelerator = Arc>>; /// The global list of caches. static CACHES: RwLock> = RwLock::new(Vec::new()); +/// The global list of currently alive accelerators. +static ACCELERATORS: RwLock<(usize, Vec)> = RwLock::new((0, Vec::new())); + +/// The current ID of the accelerator. +static ID: AtomicUsize = AtomicUsize::new(0); + /// Register a cache in the global list. pub fn register_cache(fun: fn(usize)) { CACHES.write().push(fun); } +/// Generate a new accelerator. +pub fn id() -> usize { + ID.fetch_add(1, Ordering::Release) +} + +/// Get an accelerator by ID. +fn accelerator(id: usize) -> Option { + // We always lock the accelerators, as we need to make sure that the + // accelerator is not removed while we are reading it. + let accelerators = ACCELERATORS.read(); + + // Force the ID to be loaded. + ID.load(Ordering::Acquire); + + // because + let offset = accelerators.0; + if id < offset { + return None; + } + + let i = id - offset; + if i >= accelerators.1.len() { + return None; + } + + Some(accelerators.1[i].clone()) +} + #[cfg(feature = "last_was_hit")] thread_local! { /// Whether the last call was a hit. 
@@ -94,6 +129,11 @@ pub fn last_was_hit() -> bool { /// cache. pub fn evict(max_age: usize) { CACHES.read().iter().for_each(|fun| fun(max_age)); + + // Evict all accelerators. + let mut accelerators = ACCELERATORS.write(); + accelerators.0 += accelerators.1.len(); + accelerators.1.clear(); } /// The global cache. @@ -270,19 +310,20 @@ impl Constraint { /// Whether the method satisfies as all input-output pairs. #[inline] - pub fn validate_with_accelerator( - &self, - mut f: F, - accelerator: &Accelerator, - ) -> bool + pub fn validate_with_id(&self, mut f: F, id: usize) -> bool where F: FnMut(&T) -> u128, { + let accelerator = accelerator(id); let inner = self.0.read(); - let mut map = accelerator.lock(); - inner.calls.iter().all(|entry| { - *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret - }) + if let Some(accelerator) = accelerator { + let mut map = accelerator.lock(); + inner.calls.iter().all(|entry| { + *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret + }) + } else { + inner.calls.iter().all(|entry| f(&entry.args) == entry.ret) + } } /// Replay all input-output pairs. @@ -358,19 +399,20 @@ impl ImmutableConstraint { /// Whether the method satisfies as all input-output pairs. #[inline] - pub fn validate_with_accelerator( - &self, - mut f: F, - accelerator: &Accelerator, - ) -> bool + pub fn validate_with_id(&self, mut f: F, id: usize) -> bool where F: FnMut(&T) -> u128, { - let calls = self.0.read(); - let mut map = accelerator.lock(); - calls.values().all(|entry| { - *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret - }) + let accelerator = accelerator(id); + let inner = self.0.read(); + if let Some(accelerator) = accelerator { + let mut map = accelerator.lock(); + inner.values().all(|entry| { + *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret + }) + } else { + inner.values().all(|entry| f(&entry.args) == entry.ret) + } } /// Replay all input-output pairs. 
diff --git a/src/input.rs b/src/input.rs index 8c3f8c5..7b52ef4 100644 --- a/src/input.rs +++ b/src/input.rs @@ -84,7 +84,7 @@ where #[inline] fn validate(&self, constraint: &Self::Constraint) -> bool { - self.value.validate_with_accelerator(constraint, &self.accelerator) + self.value.validate_with_id(constraint, self.id) } #[inline] @@ -101,7 +101,7 @@ where let tracked = Tracked { value: self.value, constraint: Some(constraint), - accelerator: self.accelerator, + id: self.id, }; (tracked, self.constraint) } @@ -121,7 +121,7 @@ where #[inline] fn validate(&self, constraint: &Self::Constraint) -> bool { - self.value.validate_with_accelerator(constraint, &self.accelerator) + self.value.validate_with_id(constraint, self.id) } #[inline] @@ -138,7 +138,7 @@ where let tracked = Tracked { value: self.value, constraint: Some(constraint), - accelerator: self.accelerator.clone(), + id: self.id, }; (tracked, self.constraint) } diff --git a/src/track.rs b/src/track.rs index 7df4a0c..9703703 100644 --- a/src/track.rs +++ b/src/track.rs @@ -1,8 +1,7 @@ use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; -use std::sync::Arc; -use crate::cache::{Accelerator, Join}; +use crate::cache::{id, Join}; /// A trackable type. /// @@ -13,11 +12,7 @@ pub trait Track: Validate + Surfaces { /// Start tracking all accesses to a value. #[inline] fn track(&self) -> Tracked { - Tracked { - value: self, - constraint: None, - accelerator: Arc::new(Accelerator::default()), - } + Tracked { value: self, constraint: None, id: id() } } /// Start tracking all accesses and mutations to a value. @@ -32,7 +27,7 @@ pub trait Track: Validate + Surfaces { Tracked { value: self, constraint: Some(constraint), - accelerator: Arc::new(Accelerator::default()), + id: id(), } } @@ -71,11 +66,7 @@ pub trait Validate { /// equal constraints against the same value. If given the same `id` twice, /// `self` must also be identical, unless [`evict`](crate::evict) has been /// called in between. 
- fn validate_with_accelerator( - &self, - constraint: &Self::Constraint, - accelerator: &Accelerator, - ) -> bool; + fn validate_with_id(&self, constraint: &Self::Constraint, id: usize) -> bool; /// Replay recorded mutations to the value. fn replay(&mut self, constraint: &Self::Constraint); @@ -160,8 +151,8 @@ where /// Starts out as `None` and is set to a stack-stored constraint in the /// preamble of memoized functions. pub(crate) constraint: Option<&'a C>, - /// A reference to the local accelerator. - pub(crate) accelerator: Arc, + /// The ID of the tracked value. + pub(crate) id: usize, } // The type `Tracked` automatically dereferences to T's generated surface @@ -195,14 +186,12 @@ where { #[inline] fn clone(&self) -> Self { - Self { - value: self.value, - constraint: self.constraint, - accelerator: Arc::clone(&self.accelerator), - } + *self } } +impl<'a, T> Copy for Tracked<'a, T> where T: Track + ?Sized {} + /// Tracks accesses and mutations to a value. /// /// Encapsulates a mutable reference to a value and tracks all accesses to it. 
@@ -238,7 +227,7 @@ where Tracked { value: this.value, constraint: this.constraint, - accelerator: Arc::new(Accelerator::default()), + id: id(), } } @@ -251,7 +240,7 @@ where Tracked { value: this.value, constraint: this.constraint, - accelerator: Arc::new(Accelerator::default()), + id: id(), } } From 0d38bc0c181795e1f8b823e2880035f535678a49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 12:48:18 +0100 Subject: [PATCH 15/28] Actually allocate the accelerators --- src/cache.rs | 72 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 53b6282..c36e26d 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,21 +1,25 @@ use std::hash::Hash; use std::sync::atomic::Ordering; -use std::sync::Arc; use std::{borrow::Cow, sync::atomic::AtomicUsize}; use hashbrown::HashMap; -use parking_lot::{Mutex, RwLock}; +use parking_lot::{ + MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard, +}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; -pub type Accelerator = Arc>>; +pub type Accelerator = Mutex>; /// The global list of caches. static CACHES: RwLock> = RwLock::new(Vec::new()); /// The global list of currently alive accelerators. -static ACCELERATORS: RwLock<(usize, Vec)> = RwLock::new((0, Vec::new())); +static ACCELERATORS: RwLock> = RwLock::new(Vec::new()); + +static OFFSET: AtomicUsize = AtomicUsize::new(0); +static CAPACITY: AtomicUsize = AtomicUsize::new(0); /// The current ID of the accelerator. static ID: AtomicUsize = AtomicUsize::new(0); @@ -25,32 +29,61 @@ pub fn register_cache(fun: fn(usize)) { CACHES.write().push(fun); } +fn offset() -> usize { + OFFSET.load(Ordering::Acquire) +} + /// Generate a new accelerator. pub fn id() -> usize { - ID.fetch_add(1, Ordering::Release) + #[cold] + fn allocate_accelerator(min_len: usize) { + // Allocate a new accelerator. 
+ let mut list = ACCELERATORS.write(); + + let len = (ID.load(Ordering::Acquire) - offset()).max(min_len); + + // If it was grown by another thread, we can just return. + if list.len() >= len { + return; + } + + list.resize_with(len, || Mutex::new(HashMap::new())); + CAPACITY.store(len, Ordering::SeqCst); + } + + // Get the next ID. + let id = ID.fetch_add(1, Ordering::SeqCst); + + // Make sure that the accelerator list is long enough. + if CAPACITY.load(Ordering::SeqCst) <= id - offset() { + allocate_accelerator(id - offset() + 1); + } + + id } /// Get an accelerator by ID. -fn accelerator(id: usize) -> Option { +fn accelerator(id: usize) -> Option> { + // We always lock the accelerators, as we need to make sure that the // accelerator is not removed while we are reading it. let accelerators = ACCELERATORS.read(); - // Force the ID to be loaded. - ID.load(Ordering::Acquire); - // because - let offset = accelerators.0; + let offset = offset(); if id < offset { return None; } - let i = id - offset; - if i >= accelerators.1.len() { + let i: usize = id - offset; + if i >= accelerators.len() { return None; } - Some(accelerators.1[i].clone()) + Some(RwLockReadGuard::map( + accelerators, + move |accelerators| &accelerators[i], + )) } #[cfg(feature = "last_was_hit")] @@ -132,8 +165,17 @@ pub fn evict(max_age: usize) { // Evict all accelerators. let mut accelerators = ACCELERATORS.write(); - accelerators.0 += accelerators.1.len(); - accelerators.1.clear(); + + // Force the ID to be loaded. + let id = ID.load(Ordering::SeqCst); + + // Update the offset. + OFFSET.store(id, Ordering::SeqCst); + + // Clear all accelerators while keeping the memory allocated. + accelerators.iter_mut().for_each(|accelerator| { + accelerator.lock().clear(); + }) } /// The global cache. 
From 941861c0b08e3b732bf7aaaaf3f2f61f460a7613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 12:50:30 +0100 Subject: [PATCH 16/28] cargo fmt & small opt --- src/cache.rs | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index c36e26d..e241237 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -3,9 +3,7 @@ use std::sync::atomic::Ordering; use std::{borrow::Cow, sync::atomic::AtomicUsize}; use hashbrown::HashMap; -use parking_lot::{ - MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard, -}; +use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; @@ -48,14 +46,14 @@ pub fn id() -> usize { } list.resize_with(len, || Mutex::new(HashMap::new())); - CAPACITY.store(len, Ordering::SeqCst); + CAPACITY.store(len, Ordering::Release); } // Get the next ID. - let id = ID.fetch_add(1, Ordering::SeqCst); + let id = ID.fetch_add(1, Ordering::AcqRel); // Make sure that the accelerator list is long enough. - if CAPACITY.load(Ordering::SeqCst) <= id - offset() { + if CAPACITY.load(Ordering::Acquire) <= id - offset() { allocate_accelerator(id - offset() + 1); } @@ -64,26 +62,21 @@ pub fn id() -> usize { /// Get an accelerator by ID. fn accelerator(id: usize) -> Option> { - - // We always lock the accelerators, as we need to make sure that the - // accelerator is not removed while we are reading it. - let accelerators = ACCELERATORS.read(); - - // because let offset = offset(); if id < offset { return None; } - let i: usize = id - offset; + // We always lock the accelerators, as we need to make sure that the + // accelerator is not removed while we are reading it. 
+ let accelerators = ACCELERATORS.read(); + + let i = id - offset; if i >= accelerators.len() { return None; } - Some(RwLockReadGuard::map( - accelerators, - move |accelerators| &accelerators[i], - )) + Some(RwLockReadGuard::map(accelerators, move |accelerators| &accelerators[i])) } #[cfg(feature = "last_was_hit")] @@ -167,10 +160,10 @@ pub fn evict(max_age: usize) { let mut accelerators = ACCELERATORS.write(); // Force the ID to be loaded. - let id = ID.load(Ordering::SeqCst); + let id = ID.load(Ordering::Acquire); // Update the offset. - OFFSET.store(id, Ordering::SeqCst); + OFFSET.store(id, Ordering::Release); // Clear all accelerators while keeping the memory allocated. accelerators.iter_mut().for_each(|accelerator| { From 66ea8ddbc29ca1ae5a26f7215b187a6b77a4c480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 13:44:56 +0100 Subject: [PATCH 17/28] Cleanup & attempt at opt --- examples/calc.rs | 12 ++++---- src/cache.rs | 44 ++++++++++++++--------------- src/input.rs | 72 ------------------------------------------------ tests/tests.rs | 16 +++++------ 4 files changed, 36 insertions(+), 108 deletions(-) diff --git a/examples/calc.rs b/examples/calc.rs index 76c464c..2ec6c4c 100644 --- a/examples/calc.rs +++ b/examples/calc.rs @@ -16,33 +16,33 @@ fn main() { files.write("gamma.calc", "8 + 3"); // [Miss] The cache is empty. - assert_eq!(evaluate("eval alpha.calc", &files.track()), 7); + assert_eq!(evaluate("eval alpha.calc", files.track()), 7); // [Miss] This is not a top-level hit because this exact string was never // passed to `evaluate`, but this does not compute "2 + 3" again. - assert_eq!(evaluate("eval beta.calc", &files.track()), 5); + assert_eq!(evaluate("eval beta.calc", files.track()), 5); // Modify the gamma file. files.write("gamma.calc", "42"); // [Hit] This is a hit because `gamma.calc` isn't referenced by `alpha.calc`. 
- assert_eq!(evaluate("eval alpha.calc", &files.track()), 7); + assert_eq!(evaluate("eval alpha.calc", files.track()), 7); // Modify the beta file. files.write("beta.calc", "4 + eval gamma.calc"); // [Miss] This is a miss because `beta.calc` changed. - assert_eq!(evaluate("eval alpha.calc", &files.track()), 48); + assert_eq!(evaluate("eval alpha.calc", files.track()), 48); } /// Evaluate a `.calc` script. #[memoize] -fn evaluate(script: &str, files: &Tracked) -> i32 { +fn evaluate(script: &str, files: Tracked) -> i32 { script .split('+') .map(str::trim) .map(|part| match part.strip_prefix("eval ") { - Some(path) => evaluate(&files.read(path), &files), + Some(path) => evaluate(&files.read(path), files), None => part.parse::().unwrap(), }) .sum() diff --git a/src/cache.rs b/src/cache.rs index e241237..d3bc1d6 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -3,7 +3,9 @@ use std::sync::atomic::Ordering; use std::{borrow::Cow, sync::atomic::AtomicUsize}; use hashbrown::HashMap; -use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; +use parking_lot::{ + MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard, +}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; @@ -15,12 +17,11 @@ static CACHES: RwLock> = RwLock::new(Vec::new()); /// The global list of currently alive accelerators. static ACCELERATORS: RwLock> = RwLock::new(Vec::new()); - -static OFFSET: AtomicUsize = AtomicUsize::new(0); static CAPACITY: AtomicUsize = AtomicUsize::new(0); /// The current ID of the accelerator. static ID: AtomicUsize = AtomicUsize::new(0); +static OFFSET: AtomicUsize = AtomicUsize::new(0); /// Register a cache in the global list. pub fn register_cache(fun: fn(usize)) { @@ -34,27 +35,27 @@ fn offset() -> usize { /// Generate a new accelerator. pub fn id() -> usize { #[cold] - fn allocate_accelerator(min_len: usize) { - // Allocate a new accelerator. 
- let mut list = ACCELERATORS.write(); - - let len = (ID.load(Ordering::Acquire) - offset()).max(min_len); - + fn allocate_accelerator( + len: usize, + ) { + let mut accelerators = ACCELERATORS.write(); // If it was grown by another thread, we can just return. - if list.len() >= len { + if accelerators.len() >= len { return; } - list.resize_with(len, || Mutex::new(HashMap::new())); - CAPACITY.store(len, Ordering::Release); + // We allocate exponentially to avoid too many reallocations. + accelerators.resize_with(len * 2, || Mutex::new(HashMap::new())); + CAPACITY.store(len * 2, Ordering::Release); } // Get the next ID. let id = ID.fetch_add(1, Ordering::AcqRel); // Make sure that the accelerator list is long enough. - if CAPACITY.load(Ordering::Acquire) <= id - offset() { - allocate_accelerator(id - offset() + 1); + let i = id - offset(); + if CAPACITY.load(Ordering::Acquire) <= i { + allocate_accelerator(i + 1); } id @@ -62,21 +63,20 @@ pub fn id() -> usize { /// Get an accelerator by ID. fn accelerator(id: usize) -> Option> { - let offset = offset(); - if id < offset { - return None; - } - // We always lock the accelerators, as we need to make sure that the // accelerator is not removed while we are reading it. let accelerators = ACCELERATORS.read(); - let i = id - offset; - if i >= accelerators.len() { + let offset = offset(); + if id < offset { return None; } - Some(RwLockReadGuard::map(accelerators, move |accelerators| &accelerators[i])) + let i = id - offset; + Some(RwLockReadGuard::map( + accelerators, + move |accelerators| &accelerators[i], + )) } #[cfg(feature = "last_was_hit")] diff --git a/src/input.rs b/src/input.rs index 7b52ef4..d4185ea 100644 --- a/src/input.rs +++ b/src/input.rs @@ -107,43 +107,6 @@ where } } -impl<'a: 'b, 'b, T> Input for &'b Tracked<'a, T> -where - T: Track + ?Sized, -{ - // Forward constraint from `Trackable` implementation. 
- type Constraint = ::Constraint; - type Tracked<'r> = Tracked<'r, T> where Self: 'r; - type Outer = Option<&'a Self::Constraint>; - - #[inline] - fn key(&self, _: &mut H) {} - - #[inline] - fn validate(&self, constraint: &Self::Constraint) -> bool { - self.value.validate_with_id(constraint, self.id) - } - - #[inline] - fn replay(&mut self, _: &Self::Constraint) {} - - #[inline] - fn retrack<'r>( - self, - constraint: &'r Self::Constraint, - ) -> (Self::Tracked<'r>, Self::Outer) - where - Self: 'r, - { - let tracked = Tracked { - value: self.value, - constraint: Some(constraint), - id: self.id, - }; - (tracked, self.constraint) - } -} - impl<'a, T> Input for TrackedMut<'a, T> where T: Track + ?Sized, @@ -179,41 +142,6 @@ where } } -impl<'a: 'b, 'b, T> Input for &'b mut TrackedMut<'a, T> -where - T: Track + ?Sized, -{ - // Forward constraint from `Trackable` implementation. - type Constraint = T::Constraint; - type Tracked<'r> = TrackedMut<'r, T> where Self: 'r; - type Outer = Option<&'a Self::Constraint>; - - #[inline] - fn key(&self, _: &mut H) {} - - #[inline] - fn validate(&self, constraint: &Self::Constraint) -> bool { - self.value.validate(constraint) - } - - #[inline] - fn replay(&mut self, constraint: &Self::Constraint) { - self.value.replay(constraint); - } - - #[inline] - fn retrack<'r>( - self, - constraint: &'r Self::Constraint, - ) -> (Self::Tracked<'r>, Self::Outer) - where - Self: 'r, - { - let tracked = TrackedMut { value: self.value, constraint: Some(constraint) }; - (tracked, self.constraint) - } -} - /// Wrapper for multiple inputs. 
pub struct Args(pub T); diff --git a/tests/tests.rs b/tests/tests.rs index ffbd87c..d5072ae 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -76,12 +76,12 @@ fn test_basic() { #[serial] fn test_calc() { #[memoize] - fn evaluate(script: &str, files: &Tracked) -> i32 { + fn evaluate(script: &str, files: Tracked) -> i32 { script .split('+') .map(str::trim) .map(|part| match part.strip_prefix("eval ") { - Some(path) => evaluate(&files.read(path), &files), + Some(path) => evaluate(&files.read(path), files), None => part.parse::().unwrap(), }) .sum() @@ -91,15 +91,15 @@ fn test_calc() { files.write("alpha.calc", "2 + eval beta.calc"); files.write("beta.calc", "2 + 3"); files.write("gamma.calc", "8 + 3"); - test!(miss: evaluate("eval alpha.calc", &files.track()), 7); - test!(miss: evaluate("eval beta.calc", &files.track()), 5); + test!(miss: evaluate("eval alpha.calc", files.track()), 7); + test!(miss: evaluate("eval beta.calc", files.track()), 5); files.write("gamma.calc", "42"); - test!(hit: evaluate("eval alpha.calc", &files.track()), 7); + test!(hit: evaluate("eval alpha.calc", files.track()), 7); files.write("beta.calc", "4 + eval gamma.calc"); - test!(miss: evaluate("eval beta.calc", &files.track()), 46); - test!(miss: evaluate("eval alpha.calc", &files.track()), 48); + test!(miss: evaluate("eval beta.calc", files.track()), 46); + test!(miss: evaluate("eval alpha.calc", files.track()), 48); files.write("gamma.calc", "80"); - test!(miss: evaluate("eval alpha.calc", &files.track()), 86); + test!(miss: evaluate("eval alpha.calc", files.track()), 86); } struct Files(HashMap); From cc0d37a44b3fd8c90367f368ad83feb7a62b5a14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 13:45:02 +0100 Subject: [PATCH 18/28] cargo fmt --- src/cache.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index d3bc1d6..17da3fc 100644 --- a/src/cache.rs +++ 
b/src/cache.rs @@ -3,9 +3,7 @@ use std::sync::atomic::Ordering; use std::{borrow::Cow, sync::atomic::AtomicUsize}; use hashbrown::HashMap; -use parking_lot::{ - MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard, -}; +use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; @@ -35,9 +33,7 @@ fn offset() -> usize { /// Generate a new accelerator. pub fn id() -> usize { #[cold] - fn allocate_accelerator( - len: usize, - ) { + fn allocate_accelerator(len: usize) { let mut accelerators = ACCELERATORS.write(); // If it was grown by another thread, we can just return. if accelerators.len() >= len { @@ -73,10 +69,7 @@ fn accelerator(id: usize) -> Option> } let i = id - offset; - Some(RwLockReadGuard::map( - accelerators, - move |accelerators| &accelerators[i], - )) + Some(RwLockReadGuard::map(accelerators, move |accelerators| &accelerators[i])) } #[cfg(feature = "last_was_hit")] From d189b5676fe49805eff610e3977a83825baf455c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 13:46:10 +0100 Subject: [PATCH 19/28] revert small misc changes --- src/track.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/track.rs b/src/track.rs index 9703703..0fac13c 100644 --- a/src/track.rs +++ b/src/track.rs @@ -151,7 +151,7 @@ where /// Starts out as `None` and is set to a stack-stored constraint in the /// preamble of memoized functions. pub(crate) constraint: Option<&'a C>, - /// The ID of the tracked value. + /// A unique ID for validation acceleration. pub(crate) id: usize, } @@ -180,6 +180,8 @@ where } } +impl<'a, T> Copy for Tracked<'a, T> where T: Track + ?Sized {} + impl<'a, T> Clone for Tracked<'a, T> where T: Track + ?Sized, @@ -190,8 +192,6 @@ where } } -impl<'a, T> Copy for Tracked<'a, T> where T: Track + ?Sized {} - /// Tracks accesses and mutations to a value. 
/// /// Encapsulates a mutable reference to a value and tracks all accesses to it. From eaedff969f34258b55da2ad72ff2d23bfa60aa5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 14:17:03 +0100 Subject: [PATCH 20/28] Removed exponential growth --- src/cache.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 17da3fc..43a6175 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -19,6 +19,7 @@ static CAPACITY: AtomicUsize = AtomicUsize::new(0); /// The current ID of the accelerator. static ID: AtomicUsize = AtomicUsize::new(0); +/// The current offset of the accelerator IDs. static OFFSET: AtomicUsize = AtomicUsize::new(0); /// Register a cache in the global list. @@ -26,11 +27,15 @@ pub fn register_cache(fun: fn(usize)) { CACHES.write().push(fun); } +/// Get the current offset. fn offset() -> usize { OFFSET.load(Ordering::Acquire) } /// Generate a new accelerator. +/// +/// Will allocate a new accelerator if the ID is larger than the current +/// capacity. pub fn id() -> usize { #[cold] fn allocate_accelerator(len: usize) { @@ -41,8 +46,8 @@ pub fn id() -> usize { } // We allocate exponentially to avoid too many reallocations. - accelerators.resize_with(len * 2, || Mutex::new(HashMap::new())); - CAPACITY.store(len * 2, Ordering::Release); + accelerators.resize_with(len, || Mutex::new(HashMap::new())); + CAPACITY.store(len, Ordering::Release); } // Get the next ID. 
From 976ebef6c5aa45399ffca1387959140e57c29442 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 15:48:37 +0100 Subject: [PATCH 21/28] More opts --- src/cache.rs | 56 +++++++++++++++++++++------------------------------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index 43a6175..c8cfba7 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -14,67 +14,55 @@ pub type Accelerator = Mutex>; static CACHES: RwLock> = RwLock::new(Vec::new()); /// The global list of currently alive accelerators. -static ACCELERATORS: RwLock> = RwLock::new(Vec::new()); -static CAPACITY: AtomicUsize = AtomicUsize::new(0); +static ACCELERATORS: RwLock<(usize, Vec)> = RwLock::new((0, Vec::new())); /// The current ID of the accelerator. static ID: AtomicUsize = AtomicUsize::new(0); -/// The current offset of the accelerator IDs. -static OFFSET: AtomicUsize = AtomicUsize::new(0); /// Register a cache in the global list. pub fn register_cache(fun: fn(usize)) { CACHES.write().push(fun); } -/// Get the current offset. -fn offset() -> usize { - OFFSET.load(Ordering::Acquire) -} - /// Generate a new accelerator. /// /// Will allocate a new accelerator if the ID is larger than the current /// capacity. pub fn id() -> usize { + // Get the next ID. + ID.fetch_add(1, Ordering::AcqRel) +} + +/// Get an accelerator by ID. +fn accelerator(id: usize) -> Option> { #[cold] - fn allocate_accelerator(len: usize) { + fn resize_accelerators(len: usize) { let mut accelerators = ACCELERATORS.write(); - // If it was grown by another thread, we can just return. - if accelerators.len() >= len { + + if len <= accelerators.1.len() { return; } - // We allocate exponentially to avoid too many reallocations. - accelerators.resize_with(len, || Mutex::new(HashMap::new())); - CAPACITY.store(len, Ordering::Release); - } - - // Get the next ID. 
- let id = ID.fetch_add(1, Ordering::AcqRel); - - // Make sure that the accelerator list is long enough. - let i = id - offset(); - if CAPACITY.load(Ordering::Acquire) <= i { - allocate_accelerator(i + 1); + accelerators.1.resize_with(len, || Mutex::new(HashMap::new())); } - id -} - -/// Get an accelerator by ID. -fn accelerator(id: usize) -> Option> { // We always lock the accelerators, as we need to make sure that the // accelerator is not removed while we are reading it. - let accelerators = ACCELERATORS.read(); + let mut accelerators = ACCELERATORS.read(); - let offset = offset(); + let offset = accelerators.0; if id < offset { return None; } + if id - offset >= accelerators.1.len() { + drop(accelerators); + resize_accelerators(id - offset + 1); + accelerators = ACCELERATORS.read(); + } + let i = id - offset; - Some(RwLockReadGuard::map(accelerators, move |accelerators| &accelerators[i])) + Some(RwLockReadGuard::map(accelerators, move |accelerators| &accelerators.1[i])) } #[cfg(feature = "last_was_hit")] @@ -161,10 +149,10 @@ pub fn evict(max_age: usize) { let id = ID.load(Ordering::Acquire); // Update the offset. - OFFSET.store(id, Ordering::Release); + accelerators.0 = id; // Clear all accelerators while keeping the memory allocated. - accelerators.iter_mut().for_each(|accelerator| { + accelerators.1.iter_mut().for_each(|accelerator| { accelerator.lock().clear(); }) } From e0a83bad06b4241296dc8da0e5c10468a9a91d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Sat, 9 Dec 2023 16:55:49 +0100 Subject: [PATCH 22/28] Small comment cleanup --- src/cache.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/cache.rs b/src/cache.rs index c8cfba7..32712cf 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -25,9 +25,7 @@ pub fn register_cache(fun: fn(usize)) { } /// Generate a new accelerator. -/// -/// Will allocate a new accelerator if the ID is larger than the current -/// capacity. 
+/// Will allocate a new accelerator if the ID is larger than the current capacity. pub fn id() -> usize { // Get the next ID. ID.fetch_add(1, Ordering::AcqRel) From 162030ae4f12d441e7030900568f001a6b10b66d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Thu, 14 Dec 2023 15:56:53 +0100 Subject: [PATCH 23/28] Fix code review items --- .gitignore | 1 + Cargo.toml | 3 +- macros/src/memoize.rs | 27 +- macros/src/track.rs | 115 +++++-- src/cache.rs | 305 +++++++++-------- src/lib.rs | 5 +- src/track.rs | 2 +- tests/tests.rs | 743 +++++++++++++++++++++--------------------- 8 files changed, 632 insertions(+), 569 deletions(-) diff --git a/.gitignore b/.gitignore index 360ab70..dbb723c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .vscode .DS_Store /target +macros/target Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index a2b596f..db2cb4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,12 +11,11 @@ categories = ["caching"] keywords = ["incremental", "memoization", "tracked", "constraints"] [features] -default = [ "last_was_hit" ] +default = [ ] last_was_hit = [] [dependencies] comemo-macros = { version = "0.3.1", path = "macros" } -hashbrown = "0.14.3" once_cell = "1.18.0" parking_lot = "0.12.1" siphasher = "1" diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index a517d8f..a26768c 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -7,7 +7,7 @@ pub fn expand(item: &syn::Item) -> Result { }; // Preprocess and validate the function. - let function = prepare(&item)?; + let function = prepare(item)?; // Rewrite the function's body to memoize it. 
process(&function) @@ -71,7 +71,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result { bail!(typed.ty, "memoized functions cannot have mutable parameters") } - Argument::Ident(typed.ty.clone(), mutability.clone(), ident.clone()) + Argument::Ident(typed.ty.clone(), *mutability, ident.clone()) } }) } @@ -125,20 +125,13 @@ fn process(function: &Function) -> Result { } wrapped.block = parse_quote! { { - static __CACHE: ::comemo::internal::Lazy< - ::comemo::internal::RwLock< - ::comemo::internal::Cache< - <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, - #output, - > - > - > = - ::comemo::internal::Lazy::new( - || { - ::comemo::internal::register_cache(evict); - ::comemo::internal::RwLock::new(::comemo::internal::Cache::new()) - } - ); + static __CACHE: ::comemo::internal::Cache< + <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, + #output, + > = ::comemo::internal::Cache::new(|| { + ::comemo::internal::register_cache(evict); + ::comemo::internal::RwLock::new(::comemo::internal::CacheData::new()) + }); fn evict(max_age: usize) { __CACHE.write().evict(max_age); @@ -148,7 +141,7 @@ fn process(function: &Function) -> Result { ::comemo::internal::memoized( ::comemo::internal::Args(#arg_tuple), &::core::default::Default::default(), - &*__CACHE, + &__CACHE, #closure, ) } }; diff --git a/macros/src/track.rs b/macros/src/track.rs index cea697c..a78e53e 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -5,7 +5,7 @@ pub fn expand(item: &syn::Item) -> Result { // Preprocess and validate the methods. 
let mut methods = vec![]; - let (ty, generics, trait_) = match item { + let (ty, generics, trait_, prefix) = match item { syn::Item::Impl(item) => { for param in item.generics.params.iter() { match param { @@ -20,34 +20,50 @@ pub fn expand(item: &syn::Item) -> Result { } for item in &item.items { - methods.push(prepare_impl_method(&item)?); + methods.push(prepare_impl_method(item)?); } let ty = item.self_ty.as_ref().clone(); - (ty, &item.generics, None) + (ty, &item.generics, None, None) } syn::Item::Trait(item) => { - for param in item.generics.params.iter() { - bail!(param, "tracked traits cannot be generic") + if let Some(first) = item.generics.params.first() { + bail!(first, "tracked traits cannot be generic") } for item in &item.items { - methods.push(prepare_trait_method(&item)?); + methods.push(prepare_trait_method(item)?); } let name = &item.ident; + let ty_send_sync = parse_quote! { dyn #name + '__comemo_dynamic }; let ty = parse_quote! { dyn #name + Send + Sync + '__comemo_dynamic }; - (ty, &item.generics, Some(name.clone())) + + // Produce the necessary item for the non-Send + Sync version of the trait. + let prefix = create( + &ty, + Some(quote::format_ident!("__ComemoSurfaceUnsync")), + &item.generics, + Some(item.ident.clone()), + &methods, + )?; + + (ty_send_sync, &item.generics, Some(item.ident.clone()), Some(prefix)) } _ => bail!(item, "`track` can only be applied to impl blocks and traits"), }; // Produce the necessary items for the type to become trackable. - let scope = create(&ty, generics, trait_, &methods)?; + let variants = create_variants(&methods); + let scope = create(&ty, None, generics, trait_, &methods)?; Ok(quote! { #item - const _: () = { #scope }; + const _: () = { + #variants + #prefix + #scope + }; }) } @@ -175,9 +191,48 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result }) } +/// Produces the variants for the constraint. 
+fn create_variants(methods: &[Method]) -> TokenStream { + let variants = methods.iter().map(create_variant); + let is_mutable_variants = methods.iter().map(|m| { + let name = &m.sig.ident; + let mutable = m.mutable; + quote! { __ComemoVariant::#name(..) => #mutable } + }); + + let is_mutable = (!methods.is_empty()) + .then(|| { + quote! { + match &self.0 { + #(#is_mutable_variants),* + } + } + }) + .unwrap_or_else(|| quote! { false }); + + quote! { + #[derive(Clone, PartialEq, Hash)] + pub struct __ComemoCall(__ComemoVariant); + + impl ::comemo::internal::Call for __ComemoCall { + fn is_mutable(&self) -> bool { + #is_mutable + } + } + + #[derive(Clone, PartialEq, Hash)] + #[allow(non_camel_case_types)] + enum __ComemoVariant { + #(#variants,)* + } + + } +} + /// Produce the necessary items for a type to become trackable. fn create( ty: &syn::Type, + surface: Option, generics: &syn::Generics, trait_: Option, methods: &[Method], @@ -238,22 +293,26 @@ fn create( }); // Prepare variants and wrapper methods. - let variants = methods.iter().map(create_variant); let wrapper_methods = methods .iter() .filter(|m| !m.mutable) .map(|m| create_wrapper(m, false)); let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true)); - let constraint = if methods.iter().all(|m| !m.mutable) { + let constraint = if immutable { quote! { ImmutableConstraint } } else { - quote! { Constraint } + quote! { MutableConstraint } }; + let surface_mut = surface + .clone() + .map(|s| quote::format_ident!("{s}Mut")) + .unwrap_or_else(|| parse_quote! { __ComemoSurfaceMut }); + let surface = surface.unwrap_or_else(|| parse_quote! { __ComemoSurface }); Ok(quote! 
{ - impl #impl_params ::comemo::Track for #ty #where_clause {} + impl #impl_params ::comemo::Track for #ty #where_clause {} - impl #impl_params ::comemo::Validate for #ty #where_clause { + impl #impl_params ::comemo::Validate for #ty #where_clause { type Constraint = ::comemo::internal::#constraint<__ComemoCall>; #[inline] @@ -273,19 +332,10 @@ fn create( } } - #[derive(Clone, PartialEq, Hash)] - pub struct __ComemoCall(__ComemoVariant); - - #[derive(Clone, PartialEq, Hash)] - #[allow(non_camel_case_types)] - enum __ComemoVariant { - #(#variants,)* - } - #[doc(hidden)] impl #impl_params ::comemo::internal::Surfaces for #ty #where_clause { - type Surface<#t> = __ComemoSurface #type_params_t where Self: #t; - type SurfaceMut<#t> = __ComemoSurfaceMut #type_params_t where Self: #t; + type Surface<#t> = #surface #type_params_t where Self: #t; + type SurfaceMut<#t> = #surface_mut #type_params_t where Self: #t; #[inline] fn surface_ref<#t, #r>( @@ -313,23 +363,22 @@ fn create( } #[repr(transparent)] - pub struct __ComemoSurface #impl_params_t(::comemo::Tracked<#t, #ty>) + pub struct #surface #impl_params_t(::comemo::Tracked<#t, #ty>) #where_clause; #[allow(dead_code)] - impl #impl_params_t #prefix __ComemoSurface #type_params_t { + impl #impl_params_t #prefix #surface #type_params_t { #(#wrapper_methods)* } #[repr(transparent)] - pub struct __ComemoSurfaceMut #impl_params_t(::comemo::TrackedMut<#t, #ty>) + pub struct #surface_mut #impl_params_t(::comemo::TrackedMut<#t, #ty>) #where_clause; #[allow(dead_code)] - impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t { + impl #impl_params_t #prefix #surface_mut #type_params_t { #(#wrapper_methods_mut)* } - }) } @@ -376,10 +425,9 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream { let vis = &method.vis; let sig = &method.sig; let args = &method.args; - let mutable = method.mutable; let to_parts = if !tracked_mut { - quote! { to_parts_ref(&self.0) } - } else if !mutable { + quote! 
{ to_parts_ref(self.0) } + } else if !method.mutable { quote! { to_parts_mut_ref(&self.0) } } else { quote! { to_parts_mut_mut(&mut self.0) } @@ -395,7 +443,6 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream { constraint.push( __ComemoCall(__comemo_variant), ::comemo::internal::hash(&output), - #mutable, ); } output diff --git a/src/cache.rs b/src/cache.rs index 32712cf..c8345ac 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,9 +1,13 @@ +use std::borrow::Cow; +use std::collections::HashMap; use std::hash::Hash; +use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; -use std::{borrow::Cow, sync::atomic::AtomicUsize}; -use hashbrown::HashMap; -use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; +use once_cell::sync::Lazy; +use parking_lot::{ + MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard, +}; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::input::Input; @@ -28,7 +32,7 @@ pub fn register_cache(fun: fn(usize)) { /// Will allocate a new accelerator if the ID is larger than the current capacity. pub fn id() -> usize { // Get the next ID. - ID.fetch_add(1, Ordering::AcqRel) + ID.fetch_add(1, Ordering::SeqCst) } /// Get an accelerator by ID. @@ -59,7 +63,14 @@ fn accelerator(id: usize) -> Option> accelerators = ACCELERATORS.read(); } - let i = id - offset; + // Because we release the lock before resizing the accelerator, + // we need to check again whether the ID is still valid because + // another thread might evicted the cache. + let i = id - accelerators.0; + if id < offset { + return None; + } + Some(RwLockReadGuard::map(accelerators, move |accelerators| &accelerators.1[i])) } @@ -73,7 +84,7 @@ thread_local! 
{ pub fn memoized<'c, In, Out, F>( mut input: In, constraint: &'c In::Constraint, - cache: &RwLock>, + cache: &Cache, func: F, ) -> Out where @@ -99,6 +110,7 @@ where #[cfg(feature = "last_was_hit")] LAST_WAS_HIT.with(|cell| cell.set(true)); + return value.clone(); } @@ -116,6 +128,7 @@ where // Insert the result into the cache. let mut borrow = cache.write(); borrow.insert::(key, constraint.take(), output.clone()); + #[cfg(feature = "last_was_hit")] LAST_WAS_HIT.with(|cell| cell.set(false)); @@ -138,16 +151,15 @@ pub fn last_was_hit() -> bool { /// Comemo's cache is thread-local, meaning that this only evicts this thread's /// cache. pub fn evict(max_age: usize) { - CACHES.read().iter().for_each(|fun| fun(max_age)); + for subevict in CACHES.read().iter() { + subevict(max_age); + } // Evict all accelerators. let mut accelerators = ACCELERATORS.write(); - // Force the ID to be loaded. - let id = ID.load(Ordering::Acquire); - // Update the offset. - accelerators.0 = id; + accelerators.0 = ID.load(Ordering::SeqCst); // Clear all accelerators while keeping the memory allocated. accelerators.1.iter_mut().for_each(|accelerator| { @@ -155,19 +167,42 @@ pub fn evict(max_age: usize) { }) } +pub struct Cache(Lazy>>); + +impl Cache { + /// Create an empty cache. + /// + /// It must take an initialization function because the `evict` fn + /// pointer cannot be passed as an argument otherwise the function + /// passed to `Lazy::new` is a closure and not a function pointer. + pub const fn new(init: fn() -> RwLock>) -> Self { + Self(Lazy::new(init)) + } + + /// Write to the inner cache. + pub fn write(&self) -> RwLockWriteGuard<'_, CacheData> { + self.0.write() + } + + /// Read from the inner cache. + fn read(&self) -> RwLockReadGuard<'_, CacheData> { + self.0.read() + } +} + /// The global cache. -pub struct Cache { +pub struct CacheData { /// Maps from hashes to memoized results. 
entries: HashMap>>, } -impl Default for Cache { +impl Default for CacheData { fn default() -> Self { Self { entries: HashMap::new() } } } -impl Cache { +impl CacheData { /// Create an empty cache. pub fn new() -> Self { Self::default() @@ -176,9 +211,10 @@ impl Cache { /// Evict all entries whose age is larger than or equal to `max_age`. pub fn evict(&mut self, max_age: usize) { self.entries.retain(|_, entries| { - entries.retain(|entry| { - let age = entry.age.fetch_add(1, Ordering::Acquire); - (age + 1) <= max_age + entries.retain_mut(|entry| { + let age = entry.age.get_mut(); + *age += 1; + *age <= max_age }); !entries.is_empty() }); @@ -233,89 +269,56 @@ impl CacheEntry { In: Input, { input.validate(&self.constraint).then(|| { - self.age.store(0, Ordering::Release); + self.age.store(0, Ordering::SeqCst); (&self.constraint, &self.output) }) } } +/// A call to a tracked function. +pub trait Call { + /// Whether the call is mutable. + fn is_mutable(&self) -> bool; +} + /// A call entry. #[derive(Clone)] -struct Call { +struct ConstraintEntry { args: T, args_hash: u128, ret: u128, - both: u128, - mutable: bool, } /// Defines a constraint for a tracked type. -pub struct Constraint(RwLock>); +pub struct ImmutableConstraint(RwLock>>); -#[derive(Clone)] -struct Inner { - /// The list of calls. - /// - /// Order matters here, as those are mutable & immutable calls. - calls: Vec>, - /// The hash of the arguments and index of the call. - /// - /// Order does not matter here, as those are immutable calls. - immutable: HashMap, -} +impl ImmutableConstraint { + /// Create empty constraints. + pub fn new() -> Self { + Self::default() + } -impl Clone for Constraint { - fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().clone())) + /// Enter a constraint for a call to an immutable function. 
+ #[inline] + pub fn push(&self, args: T, ret: u128) { + let args_hash = hash(&args); + self.push_inner(Cow::Owned(ConstraintEntry { args, args_hash, ret })); } -} -impl Inner { /// Enter a constraint for a call to an immutable function. #[inline] - fn push_inner(&mut self, call: Cow>) { - // If the call is immutable check whether we already have a call - // with the same arguments and return value. - if !call.mutable { - if let Some(_prev) = self.immutable.get(&call.args_hash) { - #[cfg(debug_assertions)] - check(&self.calls[*_prev], &call); + fn push_inner(&self, call: Cow>) { + let mut calls = self.0.write(); + debug_assert!(!call.args.is_mutable()); - return; - } - } + if let Some(_prev) = calls.get(&call.args_hash) { + #[cfg(debug_assertions)] + check(_prev, &call); - if call.mutable { - // If the call is mutable, clear all immutable calls. - self.immutable.clear(); - } else { - // Otherwise, insert the call into the immutable map. - self.immutable.insert(call.args_hash, self.calls.len()); + return; } - // Insert the call into the call list. - self.calls.push(call.into_owned()); - } -} - -impl Constraint { - /// Create empty constraints. - pub fn new() -> Self { - Self::default() - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - pub fn push(&self, args: T, ret: u128, mutable: bool) { - let args_hash = hash(&args); - let both = hash(&(args_hash, ret)); - self.0.write().push_inner(Cow::Owned(Call { - args, - args_hash, - ret, - both, - mutable, - })); + calls.insert(call.args_hash, call.into_owned()); } /// Whether the method satisfies as all input-output pairs. @@ -324,7 +327,7 @@ impl Constraint { where F: FnMut(&T) -> u128, { - self.0.read().calls.iter().all(|entry| f(&entry.args) == entry.ret) + self.0.read().values().all(|entry| f(&entry.args) == entry.ret) } /// Whether the method satisfies as all input-output pairs. 
@@ -337,47 +340,43 @@ impl Constraint { let inner = self.0.read(); if let Some(accelerator) = accelerator { let mut map = accelerator.lock(); - inner.calls.iter().all(|entry| { - *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret + inner.values().all(|entry| { + *map.entry(entry.args_hash).or_insert_with(|| f(&entry.args)) == entry.ret }) } else { - inner.calls.iter().all(|entry| f(&entry.args) == entry.ret) + inner.values().all(|entry| f(&entry.args) == entry.ret) } } /// Replay all input-output pairs. #[inline] - pub fn replay(&self, mut f: F) + pub fn replay(&self, _: F) where F: FnMut(&T), { - self.0 - .read() - .calls - .iter() - .filter(|call| call.mutable) - .for_each(|call| { - f(&call.args); - }); + #[cfg(debug_assertions)] + for entry in self.0.read().values() { + assert!(!entry.args.is_mutable()); + } } } -impl Default for Constraint { - fn default() -> Self { - Self(RwLock::new(Inner { calls: Vec::new(), immutable: HashMap::default() })) +impl Clone for ImmutableConstraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().clone())) } } -impl Default for Inner { +impl Default for ImmutableConstraint { fn default() -> Self { - Self { calls: Vec::new(), immutable: HashMap::default() } + Self(RwLock::new(HashMap::default())) } } /// Defines a constraint for a tracked type. -pub struct ImmutableConstraint(RwLock>>); +pub struct MutableConstraint(RwLock>); -impl ImmutableConstraint { +impl MutableConstraint { /// Create empty constraints. pub fn new() -> Self { Self::default() @@ -385,26 +384,11 @@ impl ImmutableConstraint { /// Enter a constraint for a call to an immutable function. #[inline] - pub fn push(&self, args: T, ret: u128, mutable: bool) { + pub fn push(&self, args: T, ret: u128) { let args_hash = hash(&args); - let both = hash(&(args_hash, ret)); - self.push_inner(Cow::Owned(Call { args, args_hash, ret, both, mutable })); - } - - /// Enter a constraint for a call to an immutable function. 
- #[inline] - fn push_inner(&self, call: Cow>) { - let mut calls = self.0.write(); - debug_assert!(!call.mutable); - - if let Some(_prev) = calls.get(&call.args_hash) { - #[cfg(debug_assertions)] - check(_prev, &call); - - return; - } - - calls.insert(call.args_hash, call.into_owned()); + self.0 + .write() + .push_inner(Cow::Owned(ConstraintEntry { args, args_hash, ret })); } /// Whether the method satisfies as all input-output pairs. @@ -413,49 +397,88 @@ impl ImmutableConstraint { where F: FnMut(&T) -> u128, { - self.0.read().values().all(|entry| f(&entry.args) == entry.ret) + self.0.read().calls.iter().all(|entry| f(&entry.args) == entry.ret) } /// Whether the method satisfies as all input-output pairs. + /// + /// On mutable tracked types, this does not use an accelerator as it is + /// rarely, if ever used. Therefore, it is not worth the overhead. #[inline] - pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + pub fn validate_with_id(&self, mut f: F, _: usize) -> bool where F: FnMut(&T) -> u128, { - let accelerator = accelerator(id); let inner = self.0.read(); - if let Some(accelerator) = accelerator { - let mut map = accelerator.lock(); - inner.values().all(|entry| { - *map.entry(entry.both).or_insert_with(|| f(&entry.args)) == entry.ret - }) - } else { - inner.values().all(|entry| f(&entry.args) == entry.ret) - } + inner.calls.iter().all(|entry| f(&entry.args) == entry.ret) } /// Replay all input-output pairs. 
#[inline] - pub fn replay(&self, _: F) + pub fn replay(&self, mut f: F) where F: FnMut(&T), { - #[cfg(debug_assertions)] - for entry in self.0.read().values() { - assert!(!entry.mutable); + for call in self.0.read().calls.iter().filter(|call| call.args.is_mutable()) { + f(&call.args); } } } -impl Clone for ImmutableConstraint { +impl Clone for MutableConstraint { fn clone(&self) -> Self { Self(RwLock::new(self.0.read().clone())) } } -impl Default for ImmutableConstraint { +impl Default for MutableConstraint { fn default() -> Self { - Self(RwLock::new(HashMap::default())) + Self(RwLock::new(Inner { calls: Vec::new() })) + } +} + +#[derive(Clone)] +struct Inner { + /// The list of calls. + /// + /// Order matters here, as those are mutable & immutable calls. + calls: Vec>, +} + +impl Inner { + /// Enter a constraint for a call to a function. + /// + /// If the function is immutable, it uses a fast-path based on a + /// `HashMap` to perform deduplication. Otherwise, it always + /// pushes the call to the list. + #[inline] + fn push_inner(&mut self, call: Cow>) { + // If the call is immutable check whether we already have a call + // with the same arguments and return value. + let mutable = call.args.is_mutable(); + if !mutable { + for entry in self.calls.iter().rev() { + if call.args.is_mutable() { + break; + } + + if call.args_hash == entry.args_hash && call.ret == entry.ret { + #[cfg(debug_assertions)] + check(&call, entry); + + return; + } + } + } + + // Insert the call into the call list. 
+ self.calls.push(call.into_owned()); + } +} + +impl Default for Inner { + fn default() -> Self { + Self { calls: Vec::new() } } } @@ -482,7 +505,7 @@ impl Join for Option<&T> { } } -impl Join for Constraint { +impl Join for MutableConstraint { #[inline] fn join(&self, inner: &Self) { let mut this = self.0.write(); @@ -497,7 +520,7 @@ impl Join for Constraint { } } -impl Join for ImmutableConstraint { +impl Join for ImmutableConstraint { #[inline] fn join(&self, inner: &Self) { for call in inner.0.read().values() { @@ -523,7 +546,7 @@ pub fn hash(value: &T) -> u128 { #[inline] #[track_caller] #[allow(dead_code)] -fn check(lhs: &Call, rhs: &Call) { +fn check(lhs: &ConstraintEntry, rhs: &ConstraintEntry) { if lhs.ret != rhs.ret { panic!( "comemo: found conflicting constraints. \ @@ -532,13 +555,9 @@ fn check(lhs: &Call, rhs: &Call) { } // Additional checks for debugging. - if lhs.args_hash != rhs.args_hash - || lhs.args != rhs.args - || lhs.both != rhs.both - || lhs.mutable != rhs.mutable - { + if lhs.args_hash != rhs.args_hash || lhs.args != rhs.args { panic!( - "comemo: found conflicting arguments | + "comemo: found conflicting `check` arguments. 
\ this is a bug in comemo" ) } diff --git a/src/lib.rs b/src/lib.rs index c31a4bf..37e8237 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,13 +98,12 @@ pub mod internal { pub use parking_lot::RwLock; pub use crate::cache::{ - hash, memoized, register_cache, Accelerator, Cache, Constraint, - ImmutableConstraint, + hash, memoized, register_cache, Accelerator, Cache, CacheData, Call, + ImmutableConstraint, MutableConstraint, }; pub use crate::input::{assert_hashable_or_trackable, Args, Input}; pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces}; - pub use once_cell::sync::Lazy; #[cfg(feature = "last_was_hit")] pub use crate::cache::last_was_hit; diff --git a/src/track.rs b/src/track.rs index 0fac13c..e707057 100644 --- a/src/track.rs +++ b/src/track.rs @@ -288,7 +288,7 @@ where /// Destructure a `Tracked<_>` into its parts. #[inline] -pub fn to_parts_ref<'a, T>(tracked: &Tracked<'a, T>) -> (&'a T, Option<&'a T::Constraint>) +pub fn to_parts_ref(tracked: Tracked) -> (&T, Option<&T::Constraint>) where T: Track + ?Sized, { diff --git a/tests/tests.rs b/tests/tests.rs index d5072ae..f4a92a5 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,453 +1,458 @@ -use std::collections::HashMap; -use std::hash::Hash; -use std::path::{Path, PathBuf}; - -use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate}; -use serial_test::serial; - -macro_rules! test { - (miss: $call:expr, $result:expr) => {{ - assert_eq!($call, $result); - assert!(!comemo::internal::last_was_hit()); - }}; - (hit: $call:expr, $result:expr) => {{ - assert_eq!($call, $result); - assert!(comemo::internal::last_was_hit()); - }}; -} - -/// Test basic memoization. 
-#[test] -#[serial] -fn test_basic() { - #[memoize] - fn empty() -> String { - format!("The world is {}", "big") - } - - #[memoize] - fn double(x: u32) -> u32 { - 2 * x - } - - #[memoize] - fn sum(a: u32, b: u32) -> u32 { - a + b +#[cfg(feature = "last_was_hit")] +mod tests { + use std::collections::HashMap; + use std::hash::Hash; + use std::path::{Path, PathBuf}; + + use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate}; + use serial_test::serial; + + macro_rules! test { + (miss: $call:expr, $result:expr) => {{ + assert_eq!($call, $result); + assert!(!comemo::internal::last_was_hit()); + }}; + (hit: $call:expr, $result:expr) => {{ + assert_eq!($call, $result); + assert!(comemo::internal::last_was_hit()); + }}; } - #[memoize] - fn fib(n: u32) -> u32 { - if n <= 2 { - 1 - } else { - fib(n - 1) + fib(n - 2) + /// Test basic memoization. + #[test] + #[serial] + fn test_basic() { + #[memoize] + fn empty() -> String { + format!("The world is {}", "big") } - } - #[memoize] - fn sum_iter(n: u32) -> u32 { - (0..n).sum() - } - - test!(miss: empty(), "The world is big"); - test!(hit: empty(), "The world is big"); - test!(hit: empty(), "The world is big"); + #[memoize] + fn double(x: u32) -> u32 { + 2 * x + } - test!(miss: double(2), 4); - test!(miss: double(4), 8); - test!(hit: double(2), 4); + #[memoize] + fn sum(a: u32, b: u32) -> u32 { + a + b + } - test!(miss: sum(2, 4), 6); - test!(miss: sum(2, 3), 5); - test!(hit: sum(2, 3), 5); - test!(miss: sum(4, 2), 6); + #[memoize] + fn fib(n: u32) -> u32 { + if n <= 2 { + 1 + } else { + fib(n - 1) + fib(n - 2) + } + } - test!(miss: fib(5), 5); - test!(hit: fib(3), 2); - test!(miss: fib(8), 21); - test!(hit: fib(7), 13); + #[memoize] + fn sum_iter(n: u32) -> u32 { + (0..n).sum() + } - test!(miss: sum_iter(1000), 499500); - test!(hit: sum_iter(1000), 499500); -} + test!(miss: empty(), "The world is big"); + test!(hit: empty(), "The world is big"); + test!(hit: empty(), "The world is big"); -/// Test the calc 
language. -#[test] -#[serial] -fn test_calc() { - #[memoize] - fn evaluate(script: &str, files: Tracked) -> i32 { - script - .split('+') - .map(str::trim) - .map(|part| match part.strip_prefix("eval ") { - Some(path) => evaluate(&files.read(path), files), - None => part.parse::().unwrap(), - }) - .sum() - } + test!(miss: double(2), 4); + test!(miss: double(4), 8); + test!(hit: double(2), 4); - let mut files = Files(HashMap::new()); - files.write("alpha.calc", "2 + eval beta.calc"); - files.write("beta.calc", "2 + 3"); - files.write("gamma.calc", "8 + 3"); - test!(miss: evaluate("eval alpha.calc", files.track()), 7); - test!(miss: evaluate("eval beta.calc", files.track()), 5); - files.write("gamma.calc", "42"); - test!(hit: evaluate("eval alpha.calc", files.track()), 7); - files.write("beta.calc", "4 + eval gamma.calc"); - test!(miss: evaluate("eval beta.calc", files.track()), 46); - test!(miss: evaluate("eval alpha.calc", files.track()), 48); - files.write("gamma.calc", "80"); - test!(miss: evaluate("eval alpha.calc", files.track()), 86); -} + test!(miss: sum(2, 4), 6); + test!(miss: sum(2, 3), 5); + test!(hit: sum(2, 3), 5); + test!(miss: sum(4, 2), 6); -struct Files(HashMap); + test!(miss: fib(5), 5); + test!(hit: fib(3), 2); + test!(miss: fib(8), 21); + test!(hit: fib(7), 13); -#[track] -impl Files { - fn read(&self, path: &str) -> String { - self.0.get(Path::new(path)).cloned().unwrap_or_default() + test!(miss: sum_iter(1000), 499500); + test!(hit: sum_iter(1000), 499500); } -} -impl Files { - fn write(&mut self, path: &str, text: &str) { - self.0.insert(path.into(), text.into()); - } -} + /// Test the calc language. + #[test] + #[serial] + fn test_calc() { + #[memoize] + fn evaluate(script: &str, files: Tracked) -> i32 { + script + .split('+') + .map(str::trim) + .map(|part| match part.strip_prefix("eval ") { + Some(path) => evaluate(&files.read(path), files), + None => part.parse::().unwrap(), + }) + .sum() + } -/// Test cache eviction. 
-#[test] -#[serial] -fn test_evict() { - #[memoize] - fn null() -> u8 { - 0 + let mut files = Files(HashMap::new()); + files.write("alpha.calc", "2 + eval beta.calc"); + files.write("beta.calc", "2 + 3"); + files.write("gamma.calc", "8 + 3"); + test!(miss: evaluate("eval alpha.calc", files.track()), 7); + test!(miss: evaluate("eval beta.calc", files.track()), 5); + files.write("gamma.calc", "42"); + test!(hit: evaluate("eval alpha.calc", files.track()), 7); + files.write("beta.calc", "4 + eval gamma.calc"); + test!(miss: evaluate("eval beta.calc", files.track()), 46); + test!(miss: evaluate("eval alpha.calc", files.track()), 48); + files.write("gamma.calc", "80"); + test!(miss: evaluate("eval alpha.calc", files.track()), 86); } - test!(miss: null(), 0); - test!(hit: null(), 0); - evict(2); - test!(hit: null(), 0); - evict(2); - evict(2); - test!(hit: null(), 0); - evict(2); - evict(2); - evict(2); - test!(miss: null(), 0); - test!(hit: null(), 0); - evict(0); - test!(miss: null(), 0); - test!(hit: null(), 0); -} + struct Files(HashMap); -/// Test tracking a trait object. 
-#[test] -#[serial] -fn test_tracked_trait() { - #[memoize] - fn traity(loader: Tracked, path: &Path) -> Vec { - loader.load(path).unwrap() + #[track] + impl Files { + fn read(&self, path: &str) -> String { + self.0.get(Path::new(path)).cloned().unwrap_or_default() + } } - fn wrapper(loader: &(dyn Loader + Send + Sync), path: &Path) -> Vec { - traity(loader.track(), path) + impl Files { + fn write(&mut self, path: &str, text: &str) { + self.0.insert(path.into(), text.into()); + } } - let loader: &(dyn Loader + Send + Sync) = &StaticLoader; - test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); - test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); - test!(miss: traity(loader.track(), Path::new("bye.rs")), [1, 2, 3]); - wrapper(loader, Path::new("hi.rs")); -} - -#[track] -trait Loader: Send + Sync { - fn load(&self, path: &Path) -> Result, String>; -} + /// Test cache eviction. + #[test] + #[serial] + fn test_evict() { + #[memoize] + fn null() -> u8 { + 0 + } -struct StaticLoader; -impl Loader for StaticLoader { - fn load(&self, _: &Path) -> Result, String> { - Ok(vec![1, 2, 3]) + test!(miss: null(), 0); + test!(hit: null(), 0); + evict(2); + test!(hit: null(), 0); + evict(2); + evict(2); + test!(hit: null(), 0); + evict(2); + evict(2); + evict(2); + test!(miss: null(), 0); + test!(hit: null(), 0); + evict(0); + test!(miss: null(), 0); + test!(hit: null(), 0); } -} - -/// Test memoized methods. -#[test] -#[serial] -fn test_memoized_methods() { - #[derive(Hash)] - struct Taker(String); - /// Has memoized methods. - impl Taker { + /// Test tracking a trait object. 
+ #[test] + #[serial] + fn test_tracked_trait() { #[memoize] - fn copy(&self) -> String { - self.0.clone() + fn traity( + loader: Tracked, + path: &Path, + ) -> Vec { + loader.load(path).unwrap() } - #[memoize] - fn take(self) -> String { - self.0 + fn wrapper(loader: &(dyn Loader + Send + Sync), path: &Path) -> Vec { + traity(loader.track(), path) } - } - test!(miss: Taker("Hello".into()).take(), "Hello"); - test!(miss: Taker("Hello".into()).copy(), "Hello"); - test!(miss: Taker("World".into()).take(), "World"); - test!(hit: Taker("Hello".into()).take(), "Hello"); -} + let loader: &(dyn Loader + Send + Sync) = &StaticLoader; + test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); + test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); + test!(miss: traity(loader.track(), Path::new("bye.rs")), [1, 2, 3]); + wrapper(loader, Path::new("hi.rs")); + } -/// Test different kinds of arguments. -#[test] -#[serial] -fn test_kinds() { - #[memoize] - fn selfie(tester: Tracky) -> String { - tester.self_ref().into() + #[track] + trait Loader: Send + Sync { + fn load(&self, path: &Path) -> Result, String>; } - #[memoize] - fn unconditional(tester: Tracky) -> &'static str { - if tester.by_value(Heavy("HEAVY".into())) > 10 { - "Long" - } else { - "Short" + struct StaticLoader; + impl Loader for StaticLoader { + fn load(&self, _: &Path) -> Result, String> { + Ok(vec![1, 2, 3]) } } - let mut tester = Tester { data: "Hi".to_string() }; + /// Test memoized methods. + #[test] + #[serial] + fn test_memoized_methods() { + #[derive(Hash)] + struct Taker(String); + + /// Has memoized methods. 
+ impl Taker { + #[memoize] + fn copy(&self) -> String { + self.0.clone() + } + + #[memoize] + fn take(self) -> String { + self.0 + } + } - let tracky = tester.track(); - test!(miss: selfie(tracky.clone()), "Hi"); - test!(miss: unconditional(tracky.clone()), "Short"); - test!(hit: unconditional(tracky.clone()), "Short"); - test!(hit: selfie(tracky), "Hi"); + test!(miss: Taker("Hello".into()).take(), "Hello"); + test!(miss: Taker("Hello".into()).copy(), "Hello"); + test!(miss: Taker("World".into()).take(), "World"); + test!(hit: Taker("Hello".into()).take(), "Hello"); + } - tester.data.push('!'); + /// Test different kinds of arguments. + #[test] + #[serial] + fn test_kinds() { + #[memoize] + fn selfie(tester: Tracky) -> String { + tester.self_ref().into() + } - let tracky = tester.track(); - test!(miss: selfie(tracky.clone()), "Hi!"); - test!(miss: unconditional(tracky), "Short"); + #[memoize] + fn unconditional(tester: Tracky) -> &'static str { + if tester.by_value(Heavy("HEAVY".into())) > 10 { + "Long" + } else { + "Short" + } + } - tester.data.push_str(" Let's go."); + let mut tester = Tester { data: "Hi".to_string() }; - let tracky = tester.track(); - test!(miss: unconditional(tracky), "Long"); -} + let tracky = tester.track(); + test!(miss: selfie(tracky), "Hi"); + test!(miss: unconditional(tracky), "Short"); + test!(hit: unconditional(tracky), "Short"); + test!(hit: selfie(tracky), "Hi"); -/// Test with type alias. -type Tracky<'a> = comemo::Tracked<'a, Tester>; + tester.data.push('!'); -/// A struct with some data. -struct Tester { - data: String, -} + let tracky = tester.track(); + test!(miss: selfie(tracky), "Hi!"); + test!(miss: unconditional(tracky), "Short"); + + tester.data.push_str(" Let's go."); -/// Tests different kinds of arguments. -#[track] -impl Tester { - /// Return value can borrow from self. 
- #[allow(clippy::needless_lifetimes)] - fn self_ref<'a>(&'a self) -> &'a str { - &self.data + let tracky = tester.track(); + test!(miss: unconditional(tracky), "Long"); } - /// Return value can borrow from argument. - fn arg_ref<'a>(&self, name: &'a str) -> &'a str { - name + /// Test with type alias. + type Tracky<'a> = comemo::Tracked<'a, Tester>; + + /// A struct with some data. + struct Tester { + data: String, } - /// Return value can borrow from both. - fn double_ref<'a>(&'a self, name: &'a str) -> &'a str { - if name.len() > self.data.len() { - name - } else { + /// Tests different kinds of arguments. + #[track] + impl Tester { + /// Return value can borrow from self. + #[allow(clippy::needless_lifetimes)] + fn self_ref<'a>(&'a self) -> &'a str { &self.data } - } - - /// Normal method with owned argument. - fn by_value(&self, heavy: Heavy) -> usize { - self.data.len() + heavy.0.len() - } -} -/// Test empty type without methods. -struct Empty; + /// Return value can borrow from argument. + fn arg_ref<'a>(&self, name: &'a str) -> &'a str { + name + } -#[track] -impl Empty {} + /// Return value can borrow from both. + fn double_ref<'a>(&'a self, name: &'a str) -> &'a str { + if name.len() > self.data.len() { + name + } else { + &self.data + } + } -/// Test tracking a type with a lifetime. -#[test] -#[serial] -fn test_lifetime() { - #[comemo::memoize] - fn contains_hello(lifeful: Tracked) -> bool { - lifeful.contains("hello") + /// Normal method with owned argument. + fn by_value(&self, heavy: Heavy) -> usize { + self.data.len() + heavy.0.len() + } } - let lifeful = Lifeful("hey"); - test!(miss: contains_hello(lifeful.track()), false); - test!(hit: contains_hello(lifeful.track()), false); + /// Test empty type without methods. + struct Empty; - let lifeful = Lifeful("hello"); - test!(miss: contains_hello(lifeful.track()), true); - test!(hit: contains_hello(lifeful.track()), true); -} + #[track] + impl Empty {} -/// Test tracked with lifetime. 
-struct Lifeful<'a>(&'a str); + /// Test tracking a type with a lifetime. + #[test] + #[serial] + fn test_lifetime() { + #[comemo::memoize] + fn contains_hello(lifeful: Tracked) -> bool { + lifeful.contains("hello") + } -#[track] -impl<'a> Lifeful<'a> { - fn contains(&self, text: &str) -> bool { - self.0 == text + let lifeful = Lifeful("hey"); + test!(miss: contains_hello(lifeful.track()), false); + test!(hit: contains_hello(lifeful.track()), false); + + let lifeful = Lifeful("hello"); + test!(miss: contains_hello(lifeful.track()), true); + test!(hit: contains_hello(lifeful.track()), true); } -} -/// Test tracking a type with a chain of tracked values. -#[test] -#[serial] -fn test_chain() { - #[comemo::memoize] - fn process(chain: Tracked, value: u32) -> bool { - chain.contains(value) + /// Test tracked with lifetime. + struct Lifeful<'a>(&'a str); + + #[track] + impl<'a> Lifeful<'a> { + fn contains(&self, text: &str) -> bool { + self.0 == text + } } - let chain1 = Chain::new(1); - let chain3 = Chain::new(3); - let chain12 = Chain::insert(chain1.track(), 2); - let chain123 = Chain::insert(chain12.track(), 3); - let chain124 = Chain::insert(chain12.track(), 4); - let chain1245 = Chain::insert(chain124.track(), 5); - - test!(miss: process(chain1.track(), 0), false); - test!(miss: process(chain1.track(), 1), true); - test!(miss: process(chain123.track(), 2), true); - test!(hit: process(chain124.track(), 2), true); - test!(hit: process(chain12.track(), 2), true); - test!(hit: process(chain1245.track(), 2), true); - test!(miss: process(chain1.track(), 2), false); - test!(hit: process(chain3.track(), 2), false); -} + /// Test tracking a type with a chain of tracked values. + #[test] + #[serial] + fn test_chain() { + #[comemo::memoize] + fn process(chain: Tracked, value: u32) -> bool { + chain.contains(value) + } -/// Test that `Tracked` is covariant over `T`. 
-#[test] -#[serial] -#[allow(unused, clippy::needless_lifetimes)] -fn test_variance() { - fn foo<'a>(_: Tracked<'a, Chain<'a>>) {} - fn bar<'a>(chain: Tracked<'a, Chain<'static>>) { - foo(chain); + let chain1 = Chain::new(1); + let chain3 = Chain::new(3); + let chain12 = Chain::insert(chain1.track(), 2); + let chain123 = Chain::insert(chain12.track(), 3); + let chain124 = Chain::insert(chain12.track(), 4); + let chain1245 = Chain::insert(chain124.track(), 5); + + test!(miss: process(chain1.track(), 0), false); + test!(miss: process(chain1.track(), 1), true); + test!(miss: process(chain123.track(), 2), true); + test!(hit: process(chain124.track(), 2), true); + test!(hit: process(chain12.track(), 2), true); + test!(hit: process(chain1245.track(), 2), true); + test!(miss: process(chain1.track(), 2), false); + test!(hit: process(chain3.track(), 2), false); } -} -/// Test tracked with lifetime. -struct Chain<'a> { - // Need to override the lifetime here so that a `Tracked` is covariant over - // `Chain`. - outer: Option as Validate>::Constraint>>, - value: u32, -} + /// Test that `Tracked` is covariant over `T`. + #[test] + #[serial] + #[allow(unused, clippy::needless_lifetimes)] + fn test_variance() { + fn foo<'a>(_: Tracked<'a, Chain<'a>>) {} + fn bar<'a>(chain: Tracked<'a, Chain<'static>>) { + foo(chain); + } + } -impl<'a> Chain<'a> { - /// Create a new chain entry point. - fn new(value: u32) -> Self { - Self { outer: None, value } + /// Test tracked with lifetime. + struct Chain<'a> { + // Need to override the lifetime here so that a `Tracked` is covariant over + // `Chain`. + outer: Option as Validate>::Constraint>>, + value: u32, } - /// Insert a link into the chain. - fn insert(outer: Tracked<'a, Self>, value: u32) -> Self { - Chain { outer: Some(outer), value } + impl<'a> Chain<'a> { + /// Create a new chain entry point. + fn new(value: u32) -> Self { + Self { outer: None, value } + } + + /// Insert a link into the chain. 
+ fn insert(outer: Tracked<'a, Self>, value: u32) -> Self { + Chain { outer: Some(outer), value } + } } -} -#[track] -impl<'a> Chain<'a> { - fn contains(&self, value: u32) -> bool { - self.value == value - || self.outer.as_ref().map_or(false, |outer| outer.contains(value)) + #[track] + impl<'a> Chain<'a> { + fn contains(&self, value: u32) -> bool { + self.value == value || self.outer.map_or(false, |outer| outer.contains(value)) + } } -} -/// Test mutable tracking. -#[test] -#[serial] -#[rustfmt::skip] -fn test_mutable() { - #[comemo::memoize] - fn dump(mut sink: TrackedMut) { - sink.emit("a"); - sink.emit("b"); - let c = sink.len_or_ten().to_string(); - sink.emit(&c); + /// Test mutable tracking. + #[test] + #[serial] + #[rustfmt::skip] + fn test_mutable() { + #[comemo::memoize] + fn dump(mut sink: TrackedMut) { + sink.emit("a"); + sink.emit("b"); + let c = sink.len_or_ten().to_string(); + sink.emit(&c); + } + + let mut emitter = Emitter(vec![]); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + assert_eq!(emitter.0, [ + "a", "b", "2", + "a", "b", "5", + "a", "b", "8", + "a", "b", "10", + "a", "b", "10", + "a", "b", "10", + ]) } - let mut emitter = Emitter(vec![]); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); - assert_eq!(emitter.0, [ - "a", "b", "2", - "a", "b", "5", - "a", "b", "8", - "a", "b", "10", - "a", "b", "10", - "a", "b", "10", - ]) -} + /// A tracked type with a mutable and an immutable method. + #[derive(Clone)] + struct Emitter(Vec); -/// A tracked type with a mutable and an immutable method. 
-#[derive(Clone)] -struct Emitter(Vec); + #[track] + impl Emitter { + fn emit(&mut self, msg: &str) { + self.0.push(msg.into()); + } -#[track] -impl Emitter { - fn emit(&mut self, msg: &str) { - self.0.push(msg.into()); + fn len_or_ten(&self) -> usize { + self.0.len().min(10) + } } - fn len_or_ten(&self) -> usize { - self.0.len().min(10) - } -} + /// A non-copy struct that is passed by value to a tracked method. + #[derive(Clone, PartialEq, Hash)] + struct Heavy(String); + + /// Test a tracked method that is impure. + #[test] + #[serial] + #[cfg(debug_assertions)] + #[should_panic( + expected = "comemo: found conflicting constraints. is this tracked function pure?" + )] + fn test_impure_tracked_method() { + #[comemo::memoize] + fn call(impure: Tracked) -> u32 { + impure.impure(); + impure.impure() + } -/// A non-copy struct that is passed by value to a tracked method. -#[derive(Clone, PartialEq, Hash)] -struct Heavy(String); - -/// Test a tracked method that is impure. -#[test] -#[serial] -#[cfg(debug_assertions)] -#[should_panic( - expected = "comemo: found conflicting constraints. is this tracked function pure?" 
-)] -fn test_impure_tracked_method() { - #[comemo::memoize] - fn call(impure: Tracked) -> u32 { - impure.impure(); - impure.impure() + call(Impure.track()); } - call(Impure.track()); -} + struct Impure; -struct Impure; - -#[track] -impl Impure { - fn impure(&self) -> u32 { - use std::sync::atomic::{AtomicU32, Ordering}; - static VAL: AtomicU32 = AtomicU32::new(0); - VAL.fetch_add(1, Ordering::SeqCst) + #[track] + impl Impure { + fn impure(&self) -> u32 { + use std::sync::atomic::{AtomicU32, Ordering}; + static VAL: AtomicU32 = AtomicU32::new(0); + VAL.fetch_add(1, Ordering::SeqCst) + } } } From 95e14ffb0e698a554e00fddd515152c1e6b40f8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Thu, 14 Dec 2023 18:23:14 +0100 Subject: [PATCH 24/28] Simplified trait requirements --- macros/src/track.rs | 40 +++++++++++----------------------------- tests/tests.rs | 9 +++------ 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/macros/src/track.rs b/macros/src/track.rs index a78e53e..e0db944 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -5,7 +5,7 @@ pub fn expand(item: &syn::Item) -> Result { // Preprocess and validate the methods. let mut methods = vec![]; - let (ty, generics, trait_, prefix) = match item { + let (ty, generics, trait_) = match item { syn::Item::Impl(item) => { for param in item.generics.params.iter() { match param { @@ -24,7 +24,7 @@ pub fn expand(item: &syn::Item) -> Result { } let ty = item.self_ty.as_ref().clone(); - (ty, &item.generics, None, None) + (ty, &item.generics, None) } syn::Item::Trait(item) => { if let Some(first) = item.generics.params.first() { @@ -36,32 +36,20 @@ pub fn expand(item: &syn::Item) -> Result { } let name = &item.ident; - let ty_send_sync = parse_quote! { dyn #name + '__comemo_dynamic }; - let ty = parse_quote! { dyn #name + Send + Sync + '__comemo_dynamic }; - - // Produce the necessary item for the non-Send + Sync version of the trait. 
- let prefix = create( - &ty, - Some(quote::format_ident!("__ComemoSurfaceUnsync")), - &item.generics, - Some(item.ident.clone()), - &methods, - )?; - - (ty_send_sync, &item.generics, Some(item.ident.clone()), Some(prefix)) + let ty = parse_quote! { dyn #name + '__comemo_dynamic }; + (ty, &item.generics, Some(item.ident.clone())) } _ => bail!(item, "`track` can only be applied to impl blocks and traits"), }; // Produce the necessary items for the type to become trackable. let variants = create_variants(&methods); - let scope = create(&ty, None, generics, trait_, &methods)?; + let scope = create(&ty, generics, trait_, &methods)?; Ok(quote! { #item const _: () = { #variants - #prefix #scope }; }) @@ -232,7 +220,6 @@ fn create_variants(methods: &[Method]) -> TokenStream { /// Produce the necessary items for a type to become trackable. fn create( ty: &syn::Type, - surface: Option, generics: &syn::Generics, trait_: Option, methods: &[Method], @@ -304,11 +291,6 @@ fn create( } else { quote! { MutableConstraint } }; - let surface_mut = surface - .clone() - .map(|s| quote::format_ident!("{s}Mut")) - .unwrap_or_else(|| parse_quote! { __ComemoSurfaceMut }); - let surface = surface.unwrap_or_else(|| parse_quote! { __ComemoSurface }); Ok(quote! 
{ impl #impl_params ::comemo::Track for #ty #where_clause {} @@ -334,8 +316,8 @@ fn create( #[doc(hidden)] impl #impl_params ::comemo::internal::Surfaces for #ty #where_clause { - type Surface<#t> = #surface #type_params_t where Self: #t; - type SurfaceMut<#t> = #surface_mut #type_params_t where Self: #t; + type Surface<#t> = __ComemoSurface #type_params_t where Self: #t; + type SurfaceMut<#t> = __ComemoSurfaceMut #type_params_t where Self: #t; #[inline] fn surface_ref<#t, #r>( @@ -363,20 +345,20 @@ fn create( } #[repr(transparent)] - pub struct #surface #impl_params_t(::comemo::Tracked<#t, #ty>) + pub struct __ComemoSurface #impl_params_t(::comemo::Tracked<#t, #ty>) #where_clause; #[allow(dead_code)] - impl #impl_params_t #prefix #surface #type_params_t { + impl #impl_params_t #prefix __ComemoSurface #type_params_t { #(#wrapper_methods)* } #[repr(transparent)] - pub struct #surface_mut #impl_params_t(::comemo::TrackedMut<#t, #ty>) + pub struct __ComemoSurfaceMut #impl_params_t(::comemo::TrackedMut<#t, #ty>) #where_clause; #[allow(dead_code)] - impl #impl_params_t #prefix #surface_mut #type_params_t { + impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t { #(#wrapper_methods_mut)* } }) diff --git a/tests/tests.rs b/tests/tests.rs index f4a92a5..d19d129 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -150,18 +150,15 @@ mod tests { #[serial] fn test_tracked_trait() { #[memoize] - fn traity( - loader: Tracked, - path: &Path, - ) -> Vec { + fn traity(loader: Tracked, path: &Path) -> Vec { loader.load(path).unwrap() } - fn wrapper(loader: &(dyn Loader + Send + Sync), path: &Path) -> Vec { + fn wrapper(loader: &(dyn Loader), path: &Path) -> Vec { traity(loader.track(), path) } - let loader: &(dyn Loader + Send + Sync) = &StaticLoader; + let loader: &(dyn Loader) = &StaticLoader; test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(miss: traity(loader.track(), 
Path::new("bye.rs")), [1, 2, 3]); From c5ce6440bb89fcfd52ef29faa2a984ef23be960f Mon Sep 17 00:00:00 2001 From: Laurenz Date: Thu, 14 Dec 2023 23:58:38 +0100 Subject: [PATCH 25/28] Adjust test setup --- .github/workflows/ci.yml | 4 +- Cargo.toml | 13 +- tests/tests.rs | 691 +++++++++++++++++++-------------------- 3 files changed, 356 insertions(+), 352 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e6f0a2..98b71ca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,5 +7,5 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - - run: cargo build - - run: cargo test + - run: cargo build --all-features + - run: cargo test --all-features diff --git a/Cargo.toml b/Cargo.toml index db2cb4b..4cd252f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,19 @@ categories = ["caching"] keywords = ["incremental", "memoization", "tracked", "constraints"] [features] -default = [ ] -last_was_hit = [] +default = [] +testing = [] [dependencies] comemo-macros = { version = "0.3.1", path = "macros" } -once_cell = "1.18.0" -parking_lot = "0.12.1" +once_cell = "1.18" +parking_lot = "0.12" siphasher = "1" [dev-dependencies] serial_test = "2.0.0" + +[[test]] +name = "tests" +path = "tests/tests.rs" +required-features = ["testing"] diff --git a/tests/tests.rs b/tests/tests.rs index d19d129..db6bf31 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,382 +1,382 @@ -#[cfg(feature = "last_was_hit")] -mod tests { - use std::collections::HashMap; - use std::hash::Hash; - use std::path::{Path, PathBuf}; - - use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate}; - use serial_test::serial; - - macro_rules! test { - (miss: $call:expr, $result:expr) => {{ - assert_eq!($call, $result); - assert!(!comemo::internal::last_was_hit()); - }}; - (hit: $call:expr, $result:expr) => {{ - assert_eq!($call, $result); - assert!(comemo::internal::last_was_hit()); - }}; +//! 
Run with `cargo test --all-features`. + +use std::collections::HashMap; +use std::hash::Hash; +use std::path::{Path, PathBuf}; + +use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate}; +use serial_test::serial; + +macro_rules! test { + (miss: $call:expr, $result:expr) => {{ + assert_eq!($call, $result); + assert!(!comemo::internal::last_was_hit()); + }}; + (hit: $call:expr, $result:expr) => {{ + assert_eq!($call, $result); + assert!(comemo::internal::last_was_hit()); + }}; +} + +/// Test basic memoization. +#[test] +#[serial] +fn test_basic() { + #[memoize] + fn empty() -> String { + format!("The world is {}", "big") } - /// Test basic memoization. - #[test] - #[serial] - fn test_basic() { - #[memoize] - fn empty() -> String { - format!("The world is {}", "big") - } + #[memoize] + fn double(x: u32) -> u32 { + 2 * x + } - #[memoize] - fn double(x: u32) -> u32 { - 2 * x - } + #[memoize] + fn sum(a: u32, b: u32) -> u32 { + a + b + } - #[memoize] - fn sum(a: u32, b: u32) -> u32 { - a + b + #[memoize] + fn fib(n: u32) -> u32 { + if n <= 2 { + 1 + } else { + fib(n - 1) + fib(n - 2) } + } - #[memoize] - fn fib(n: u32) -> u32 { - if n <= 2 { - 1 - } else { - fib(n - 1) + fib(n - 2) - } - } + #[memoize] + fn sum_iter(n: u32) -> u32 { + (0..n).sum() + } - #[memoize] - fn sum_iter(n: u32) -> u32 { - (0..n).sum() - } + test!(miss: empty(), "The world is big"); + test!(hit: empty(), "The world is big"); + test!(hit: empty(), "The world is big"); - test!(miss: empty(), "The world is big"); - test!(hit: empty(), "The world is big"); - test!(hit: empty(), "The world is big"); + test!(miss: double(2), 4); + test!(miss: double(4), 8); + test!(hit: double(2), 4); - test!(miss: double(2), 4); - test!(miss: double(4), 8); - test!(hit: double(2), 4); + test!(miss: sum(2, 4), 6); + test!(miss: sum(2, 3), 5); + test!(hit: sum(2, 3), 5); + test!(miss: sum(4, 2), 6); - test!(miss: sum(2, 4), 6); - test!(miss: sum(2, 3), 5); - test!(hit: sum(2, 3), 5); - test!(miss: 
sum(4, 2), 6); + test!(miss: fib(5), 5); + test!(hit: fib(3), 2); + test!(miss: fib(8), 21); + test!(hit: fib(7), 13); - test!(miss: fib(5), 5); - test!(hit: fib(3), 2); - test!(miss: fib(8), 21); - test!(hit: fib(7), 13); + test!(miss: sum_iter(1000), 499500); + test!(hit: sum_iter(1000), 499500); +} - test!(miss: sum_iter(1000), 499500); - test!(hit: sum_iter(1000), 499500); +/// Test the calc language. +#[test] +#[serial] +fn test_calc() { + #[memoize] + fn evaluate(script: &str, files: Tracked) -> i32 { + script + .split('+') + .map(str::trim) + .map(|part| match part.strip_prefix("eval ") { + Some(path) => evaluate(&files.read(path), files), + None => part.parse::().unwrap(), + }) + .sum() } - /// Test the calc language. - #[test] - #[serial] - fn test_calc() { - #[memoize] - fn evaluate(script: &str, files: Tracked) -> i32 { - script - .split('+') - .map(str::trim) - .map(|part| match part.strip_prefix("eval ") { - Some(path) => evaluate(&files.read(path), files), - None => part.parse::().unwrap(), - }) - .sum() - } + let mut files = Files(HashMap::new()); + files.write("alpha.calc", "2 + eval beta.calc"); + files.write("beta.calc", "2 + 3"); + files.write("gamma.calc", "8 + 3"); + test!(miss: evaluate("eval alpha.calc", files.track()), 7); + test!(miss: evaluate("eval beta.calc", files.track()), 5); + files.write("gamma.calc", "42"); + test!(hit: evaluate("eval alpha.calc", files.track()), 7); + files.write("beta.calc", "4 + eval gamma.calc"); + test!(miss: evaluate("eval beta.calc", files.track()), 46); + test!(miss: evaluate("eval alpha.calc", files.track()), 48); + files.write("gamma.calc", "80"); + test!(miss: evaluate("eval alpha.calc", files.track()), 86); +} - let mut files = Files(HashMap::new()); - files.write("alpha.calc", "2 + eval beta.calc"); - files.write("beta.calc", "2 + 3"); - files.write("gamma.calc", "8 + 3"); - test!(miss: evaluate("eval alpha.calc", files.track()), 7); - test!(miss: evaluate("eval beta.calc", files.track()), 5); - 
files.write("gamma.calc", "42"); - test!(hit: evaluate("eval alpha.calc", files.track()), 7); - files.write("beta.calc", "4 + eval gamma.calc"); - test!(miss: evaluate("eval beta.calc", files.track()), 46); - test!(miss: evaluate("eval alpha.calc", files.track()), 48); - files.write("gamma.calc", "80"); - test!(miss: evaluate("eval alpha.calc", files.track()), 86); +struct Files(HashMap); + +#[track] +impl Files { + fn read(&self, path: &str) -> String { + self.0.get(Path::new(path)).cloned().unwrap_or_default() } +} - struct Files(HashMap); +impl Files { + fn write(&mut self, path: &str, text: &str) { + self.0.insert(path.into(), text.into()); + } +} - #[track] - impl Files { - fn read(&self, path: &str) -> String { - self.0.get(Path::new(path)).cloned().unwrap_or_default() - } +/// Test cache eviction. +#[test] +#[serial] +fn test_evict() { + #[memoize] + fn null() -> u8 { + 0 } - impl Files { - fn write(&mut self, path: &str, text: &str) { - self.0.insert(path.into(), text.into()); - } + test!(miss: null(), 0); + test!(hit: null(), 0); + evict(2); + test!(hit: null(), 0); + evict(2); + evict(2); + test!(hit: null(), 0); + evict(2); + evict(2); + evict(2); + test!(miss: null(), 0); + test!(hit: null(), 0); + evict(0); + test!(miss: null(), 0); + test!(hit: null(), 0); +} + +/// Test tracking a trait object. +#[test] +#[serial] +fn test_tracked_trait() { + #[memoize] + fn traity(loader: Tracked, path: &Path) -> Vec { + loader.load(path).unwrap() } - /// Test cache eviction. 
- #[test] - #[serial] - fn test_evict() { - #[memoize] - fn null() -> u8 { - 0 - } + fn wrapper(loader: &(dyn Loader), path: &Path) -> Vec { + traity(loader.track(), path) + } + + let loader: &(dyn Loader) = &StaticLoader; + test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); + test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); + test!(miss: traity(loader.track(), Path::new("bye.rs")), [1, 2, 3]); + wrapper(loader, Path::new("hi.rs")); +} - test!(miss: null(), 0); - test!(hit: null(), 0); - evict(2); - test!(hit: null(), 0); - evict(2); - evict(2); - test!(hit: null(), 0); - evict(2); - evict(2); - evict(2); - test!(miss: null(), 0); - test!(hit: null(), 0); - evict(0); - test!(miss: null(), 0); - test!(hit: null(), 0); +#[track] +trait Loader: Send + Sync { + fn load(&self, path: &Path) -> Result, String>; +} + +struct StaticLoader; +impl Loader for StaticLoader { + fn load(&self, _: &Path) -> Result, String> { + Ok(vec![1, 2, 3]) } +} - /// Test tracking a trait object. - #[test] - #[serial] - fn test_tracked_trait() { +/// Test memoized methods. +#[test] +#[serial] +fn test_memoized_methods() { + #[derive(Hash)] + struct Taker(String); + + /// Has memoized methods. 
+ impl Taker { #[memoize] - fn traity(loader: Tracked, path: &Path) -> Vec { - loader.load(path).unwrap() + fn copy(&self) -> String { + self.0.clone() } - fn wrapper(loader: &(dyn Loader), path: &Path) -> Vec { - traity(loader.track(), path) + #[memoize] + fn take(self) -> String { + self.0 } - - let loader: &(dyn Loader) = &StaticLoader; - test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); - test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); - test!(miss: traity(loader.track(), Path::new("bye.rs")), [1, 2, 3]); - wrapper(loader, Path::new("hi.rs")); } - #[track] - trait Loader: Send + Sync { - fn load(&self, path: &Path) -> Result, String>; - } + test!(miss: Taker("Hello".into()).take(), "Hello"); + test!(miss: Taker("Hello".into()).copy(), "Hello"); + test!(miss: Taker("World".into()).take(), "World"); + test!(hit: Taker("Hello".into()).take(), "Hello"); +} - struct StaticLoader; - impl Loader for StaticLoader { - fn load(&self, _: &Path) -> Result, String> { - Ok(vec![1, 2, 3]) - } +/// Test different kinds of arguments. +#[test] +#[serial] +fn test_kinds() { + #[memoize] + fn selfie(tester: Tracky) -> String { + tester.self_ref().into() } - /// Test memoized methods. - #[test] - #[serial] - fn test_memoized_methods() { - #[derive(Hash)] - struct Taker(String); - - /// Has memoized methods. - impl Taker { - #[memoize] - fn copy(&self) -> String { - self.0.clone() - } - - #[memoize] - fn take(self) -> String { - self.0 - } + #[memoize] + fn unconditional(tester: Tracky) -> &'static str { + if tester.by_value(Heavy("HEAVY".into())) > 10 { + "Long" + } else { + "Short" } - - test!(miss: Taker("Hello".into()).take(), "Hello"); - test!(miss: Taker("Hello".into()).copy(), "Hello"); - test!(miss: Taker("World".into()).take(), "World"); - test!(hit: Taker("Hello".into()).take(), "Hello"); } - /// Test different kinds of arguments. 
- #[test] - #[serial] - fn test_kinds() { - #[memoize] - fn selfie(tester: Tracky) -> String { - tester.self_ref().into() - } + let mut tester = Tester { data: "Hi".to_string() }; - #[memoize] - fn unconditional(tester: Tracky) -> &'static str { - if tester.by_value(Heavy("HEAVY".into())) > 10 { - "Long" - } else { - "Short" - } - } + let tracky = tester.track(); + test!(miss: selfie(tracky), "Hi"); + test!(miss: unconditional(tracky), "Short"); + test!(hit: unconditional(tracky), "Short"); + test!(hit: selfie(tracky), "Hi"); - let mut tester = Tester { data: "Hi".to_string() }; + tester.data.push('!'); - let tracky = tester.track(); - test!(miss: selfie(tracky), "Hi"); - test!(miss: unconditional(tracky), "Short"); - test!(hit: unconditional(tracky), "Short"); - test!(hit: selfie(tracky), "Hi"); + let tracky = tester.track(); + test!(miss: selfie(tracky), "Hi!"); + test!(miss: unconditional(tracky), "Short"); - tester.data.push('!'); + tester.data.push_str(" Let's go."); - let tracky = tester.track(); - test!(miss: selfie(tracky), "Hi!"); - test!(miss: unconditional(tracky), "Short"); + let tracky = tester.track(); + test!(miss: unconditional(tracky), "Long"); +} - tester.data.push_str(" Let's go."); +/// Test with type alias. +type Tracky<'a> = comemo::Tracked<'a, Tester>; - let tracky = tester.track(); - test!(miss: unconditional(tracky), "Long"); - } +/// A struct with some data. +struct Tester { + data: String, +} - /// Test with type alias. - type Tracky<'a> = comemo::Tracked<'a, Tester>; +/// Tests different kinds of arguments. +#[track] +impl Tester { + /// Return value can borrow from self. + #[allow(clippy::needless_lifetimes)] + fn self_ref<'a>(&'a self) -> &'a str { + &self.data + } - /// A struct with some data. - struct Tester { - data: String, + /// Return value can borrow from argument. + fn arg_ref<'a>(&self, name: &'a str) -> &'a str { + name } - /// Tests different kinds of arguments. 
- #[track] - impl Tester { - /// Return value can borrow from self. - #[allow(clippy::needless_lifetimes)] - fn self_ref<'a>(&'a self) -> &'a str { + /// Return value can borrow from both. + fn double_ref<'a>(&'a self, name: &'a str) -> &'a str { + if name.len() > self.data.len() { + name + } else { &self.data } + } - /// Return value can borrow from argument. - fn arg_ref<'a>(&self, name: &'a str) -> &'a str { - name - } + /// Normal method with owned argument. + fn by_value(&self, heavy: Heavy) -> usize { + self.data.len() + heavy.0.len() + } +} - /// Return value can borrow from both. - fn double_ref<'a>(&'a self, name: &'a str) -> &'a str { - if name.len() > self.data.len() { - name - } else { - &self.data - } - } +/// Test empty type without methods. +struct Empty; - /// Normal method with owned argument. - fn by_value(&self, heavy: Heavy) -> usize { - self.data.len() + heavy.0.len() - } - } +#[track] +impl Empty {} - /// Test empty type without methods. - struct Empty; +/// Test tracking a type with a lifetime. +#[test] +#[serial] +fn test_lifetime() { + #[comemo::memoize] + fn contains_hello(lifeful: Tracked) -> bool { + lifeful.contains("hello") + } - #[track] - impl Empty {} + let lifeful = Lifeful("hey"); + test!(miss: contains_hello(lifeful.track()), false); + test!(hit: contains_hello(lifeful.track()), false); - /// Test tracking a type with a lifetime. - #[test] - #[serial] - fn test_lifetime() { - #[comemo::memoize] - fn contains_hello(lifeful: Tracked) -> bool { - lifeful.contains("hello") - } + let lifeful = Lifeful("hello"); + test!(miss: contains_hello(lifeful.track()), true); + test!(hit: contains_hello(lifeful.track()), true); +} - let lifeful = Lifeful("hey"); - test!(miss: contains_hello(lifeful.track()), false); - test!(hit: contains_hello(lifeful.track()), false); +/// Test tracked with lifetime. 
+struct Lifeful<'a>(&'a str); - let lifeful = Lifeful("hello"); - test!(miss: contains_hello(lifeful.track()), true); - test!(hit: contains_hello(lifeful.track()), true); +#[track] +impl<'a> Lifeful<'a> { + fn contains(&self, text: &str) -> bool { + self.0 == text } +} - /// Test tracked with lifetime. - struct Lifeful<'a>(&'a str); - - #[track] - impl<'a> Lifeful<'a> { - fn contains(&self, text: &str) -> bool { - self.0 == text - } +/// Test tracking a type with a chain of tracked values. +#[test] +#[serial] +fn test_chain() { + #[comemo::memoize] + fn process(chain: Tracked, value: u32) -> bool { + chain.contains(value) } - /// Test tracking a type with a chain of tracked values. - #[test] - #[serial] - fn test_chain() { - #[comemo::memoize] - fn process(chain: Tracked, value: u32) -> bool { - chain.contains(value) - } + let chain1 = Chain::new(1); + let chain3 = Chain::new(3); + let chain12 = Chain::insert(chain1.track(), 2); + let chain123 = Chain::insert(chain12.track(), 3); + let chain124 = Chain::insert(chain12.track(), 4); + let chain1245 = Chain::insert(chain124.track(), 5); + + test!(miss: process(chain1.track(), 0), false); + test!(miss: process(chain1.track(), 1), true); + test!(miss: process(chain123.track(), 2), true); + test!(hit: process(chain124.track(), 2), true); + test!(hit: process(chain12.track(), 2), true); + test!(hit: process(chain1245.track(), 2), true); + test!(miss: process(chain1.track(), 2), false); + test!(hit: process(chain3.track(), 2), false); +} - let chain1 = Chain::new(1); - let chain3 = Chain::new(3); - let chain12 = Chain::insert(chain1.track(), 2); - let chain123 = Chain::insert(chain12.track(), 3); - let chain124 = Chain::insert(chain12.track(), 4); - let chain1245 = Chain::insert(chain124.track(), 5); - - test!(miss: process(chain1.track(), 0), false); - test!(miss: process(chain1.track(), 1), true); - test!(miss: process(chain123.track(), 2), true); - test!(hit: process(chain124.track(), 2), true); - test!(hit: 
process(chain12.track(), 2), true); - test!(hit: process(chain1245.track(), 2), true); - test!(miss: process(chain1.track(), 2), false); - test!(hit: process(chain3.track(), 2), false); +/// Test that `Tracked` is covariant over `T`. +#[test] +#[serial] +#[allow(unused, clippy::needless_lifetimes)] +fn test_variance() { + fn foo<'a>(_: Tracked<'a, Chain<'a>>) {} + fn bar<'a>(chain: Tracked<'a, Chain<'static>>) { + foo(chain); } +} - /// Test that `Tracked` is covariant over `T`. - #[test] - #[serial] - #[allow(unused, clippy::needless_lifetimes)] - fn test_variance() { - fn foo<'a>(_: Tracked<'a, Chain<'a>>) {} - fn bar<'a>(chain: Tracked<'a, Chain<'static>>) { - foo(chain); - } - } +/// Test tracked with lifetime. +struct Chain<'a> { + // Need to override the lifetime here so that a `Tracked` is covariant over + // `Chain`. + outer: Option as Validate>::Constraint>>, + value: u32, +} - /// Test tracked with lifetime. - struct Chain<'a> { - // Need to override the lifetime here so that a `Tracked` is covariant over - // `Chain`. - outer: Option as Validate>::Constraint>>, - value: u32, +impl<'a> Chain<'a> { + /// Create a new chain entry point. + fn new(value: u32) -> Self { + Self { outer: None, value } } - impl<'a> Chain<'a> { - /// Create a new chain entry point. - fn new(value: u32) -> Self { - Self { outer: None, value } - } - - /// Insert a link into the chain. - fn insert(outer: Tracked<'a, Self>, value: u32) -> Self { - Chain { outer: Some(outer), value } - } + /// Insert a link into the chain. + fn insert(outer: Tracked<'a, Self>, value: u32) -> Self { + Chain { outer: Some(outer), value } } +} - #[track] - impl<'a> Chain<'a> { - fn contains(&self, value: u32) -> bool { - self.value == value || self.outer.map_or(false, |outer| outer.contains(value)) - } +#[track] +impl<'a> Chain<'a> { + fn contains(&self, value: u32) -> bool { + self.value == value || self.outer.map_or(false, |outer| outer.contains(value)) } +} - /// Test mutable tracking. 
+/// Test mutable tracking. #[test] #[serial] #[rustfmt::skip] @@ -388,7 +388,7 @@ mod tests { let c = sink.len_or_ten().to_string(); sink.emit(&c); } - + let mut emitter = Emitter(vec![]); test!(miss: dump(emitter.track_mut()), ()); test!(miss: dump(emitter.track_mut()), ()); @@ -406,50 +406,49 @@ mod tests { ]) } - /// A tracked type with a mutable and an immutable method. - #[derive(Clone)] - struct Emitter(Vec); +/// A tracked type with a mutable and an immutable method. +#[derive(Clone)] +struct Emitter(Vec); - #[track] - impl Emitter { - fn emit(&mut self, msg: &str) { - self.0.push(msg.into()); - } - - fn len_or_ten(&self) -> usize { - self.0.len().min(10) - } +#[track] +impl Emitter { + fn emit(&mut self, msg: &str) { + self.0.push(msg.into()); } - /// A non-copy struct that is passed by value to a tracked method. - #[derive(Clone, PartialEq, Hash)] - struct Heavy(String); - - /// Test a tracked method that is impure. - #[test] - #[serial] - #[cfg(debug_assertions)] - #[should_panic( - expected = "comemo: found conflicting constraints. is this tracked function pure?" - )] - fn test_impure_tracked_method() { - #[comemo::memoize] - fn call(impure: Tracked) -> u32 { - impure.impure(); - impure.impure() - } + fn len_or_ten(&self) -> usize { + self.0.len().min(10) + } +} - call(Impure.track()); +/// A non-copy struct that is passed by value to a tracked method. +#[derive(Clone, PartialEq, Hash)] +struct Heavy(String); + +/// Test a tracked method that is impure. +#[test] +#[serial] +#[cfg(debug_assertions)] +#[should_panic( + expected = "comemo: found conflicting constraints. is this tracked function pure?" 
+)] +fn test_impure_tracked_method() { + #[comemo::memoize] + fn call(impure: Tracked) -> u32 { + impure.impure(); + impure.impure() } - struct Impure; + call(Impure.track()); +} - #[track] - impl Impure { - fn impure(&self) -> u32 { - use std::sync::atomic::{AtomicU32, Ordering}; - static VAL: AtomicU32 = AtomicU32::new(0); - VAL.fetch_add(1, Ordering::SeqCst) - } +struct Impure; + +#[track] +impl Impure { + fn impure(&self) -> u32 { + use std::sync::atomic::{AtomicU32, Ordering}; + static VAL: AtomicU32 = AtomicU32::new(0); + VAL.fetch_add(1, Ordering::SeqCst) } } From af711c9cda29370a8d50586081ea3b851495e62a Mon Sep 17 00:00:00 2001 From: Laurenz Date: Fri, 15 Dec 2023 00:06:19 +0100 Subject: [PATCH 26/28] Refactor constraints and accelerator --- macros/src/memoize.rs | 8 +- src/accelerate.rs | 63 ++++++ src/cache.rs | 431 ++++-------------------------------------- src/constraint.rs | 301 +++++++++++++++++++++++++++++ src/input.rs | 2 +- src/lib.rs | 11 +- src/track.rs | 15 +- 7 files changed, 416 insertions(+), 415 deletions(-) create mode 100644 src/accelerate.rs create mode 100644 src/constraint.rs diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index a26768c..d6c0b10 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -129,14 +129,10 @@ fn process(function: &Function) -> Result { <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, #output, > = ::comemo::internal::Cache::new(|| { - ::comemo::internal::register_cache(evict); - ::comemo::internal::RwLock::new(::comemo::internal::CacheData::new()) + ::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age)); + ::std::default::Default::default() }); - fn evict(max_age: usize) { - __CACHE.write().evict(max_age); - } - #(#bounds;)* ::comemo::internal::memoized( ::comemo::internal::Args(#arg_tuple), diff --git a/src/accelerate.rs b/src/accelerate.rs new file mode 100644 index 0000000..611d3a0 --- /dev/null +++ b/src/accelerate.rs @@ -0,0 +1,63 @@ 
+use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; + +/// The global list of currently alive accelerators. +static ACCELERATORS: RwLock<(usize, Vec)> = RwLock::new((0, Vec::new())); + +/// The current ID of the accelerator. +static ID: AtomicUsize = AtomicUsize::new(0); + +/// The type of each individual accelerator. +/// +/// Maps from call hashes to return hashes. +type Accelerator = Mutex>; + +/// Generate a new accelerator. +pub fn id() -> usize { + // Get the next ID. + ID.fetch_add(1, Ordering::SeqCst) +} + +/// Evict the accelerators. +pub fn evict() { + let mut accelerators = ACCELERATORS.write(); + let (offset, vec) = &mut *accelerators; + + // Update the offset. + *offset = ID.load(Ordering::SeqCst); + + // Clear all accelerators while keeping the memory allocated. + vec.iter_mut().for_each(|accelerator| accelerator.lock().clear()) +} + +/// Get an accelerator by ID. +pub fn get(id: usize) -> Option> { + // We always lock the accelerators, as we need to make sure that the + // accelerator is not removed while we are reading it. + let mut accelerators = ACCELERATORS.read(); + + let mut i = id.checked_sub(accelerators.0)?; + if i >= accelerators.1.len() { + drop(accelerators); + resize(i + 1); + accelerators = ACCELERATORS.read(); + + // Because we release the lock before resizing the accelerator, we need + // to check again whether the ID is still valid because another thread + // might evicted the cache. + i = id.checked_sub(accelerators.0)?; + } + + Some(RwLockReadGuard::map(accelerators, move |(_, vec)| &vec[i])) +} + +/// Adjusts the amount of accelerators. 
+#[cold] +fn resize(len: usize) { + let mut pair = ACCELERATORS.write(); + if len > pair.1.len() { + pair.1.resize_with(len, || Mutex::new(HashMap::new())); + } +} diff --git a/src/cache.rs b/src/cache.rs index c8345ac..b8fec40 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,80 +1,18 @@ -use std::borrow::Cow; use std::collections::HashMap; -use std::hash::Hash; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering; +use std::sync::atomic::{AtomicUsize, Ordering}; use once_cell::sync::Lazy; -use parking_lot::{ - MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard, -}; +use parking_lot::RwLock; use siphasher::sip128::{Hasher128, SipHasher13}; +use crate::accelerate; +use crate::constraint::Join; use crate::input::Input; -pub type Accelerator = Mutex>; +/// The global list of eviction functions. +static EVICTORS: RwLock> = RwLock::new(Vec::new()); -/// The global list of caches. -static CACHES: RwLock> = RwLock::new(Vec::new()); - -/// The global list of currently alive accelerators. -static ACCELERATORS: RwLock<(usize, Vec)> = RwLock::new((0, Vec::new())); - -/// The current ID of the accelerator. -static ID: AtomicUsize = AtomicUsize::new(0); - -/// Register a cache in the global list. -pub fn register_cache(fun: fn(usize)) { - CACHES.write().push(fun); -} - -/// Generate a new accelerator. -/// Will allocate a new accelerator if the ID is larger than the current capacity. -pub fn id() -> usize { - // Get the next ID. - ID.fetch_add(1, Ordering::SeqCst) -} - -/// Get an accelerator by ID. -fn accelerator(id: usize) -> Option> { - #[cold] - fn resize_accelerators(len: usize) { - let mut accelerators = ACCELERATORS.write(); - - if len <= accelerators.1.len() { - return; - } - - accelerators.1.resize_with(len, || Mutex::new(HashMap::new())); - } - - // We always lock the accelerators, as we need to make sure that the - // accelerator is not removed while we are reading it. 
- let mut accelerators = ACCELERATORS.read(); - - let offset = accelerators.0; - if id < offset { - return None; - } - - if id - offset >= accelerators.1.len() { - drop(accelerators); - resize_accelerators(id - offset + 1); - accelerators = ACCELERATORS.read(); - } - - // Because we release the lock before resizing the accelerator, - // we need to check again whether the ID is still valid because - // another thread might evicted the cache. - let i = id - accelerators.0; - if id < offset { - return None; - } - - Some(RwLockReadGuard::map(accelerators, move |accelerators| &accelerators.1[i])) -} - -#[cfg(feature = "last_was_hit")] +#[cfg(feature = "testing")] thread_local! { /// Whether the last call was a hit. static LAST_WAS_HIT: std::cell::Cell = const { std::cell::Cell::new(false) }; @@ -100,7 +38,7 @@ where }; // Check if there is a cached output. - let borrow = cache.read(); + let borrow = cache.0.read(); if let Some((constrained, value)) = borrow.lookup::(key, &input) { // Replay the mutations. input.replay(constrained); @@ -108,7 +46,7 @@ where // Add the cached constraints to the outer ones. input.retrack(constraint).1.join(constrained); - #[cfg(feature = "last_was_hit")] + #[cfg(feature = "testing")] LAST_WAS_HIT.with(|cell| cell.set(true)); return value.clone(); @@ -126,21 +64,15 @@ where outer.join(constraint); // Insert the result into the cache. - let mut borrow = cache.write(); + let mut borrow = cache.0.write(); borrow.insert::(key, constraint.take(), output.clone()); - #[cfg(feature = "last_was_hit")] + #[cfg(feature = "testing")] LAST_WAS_HIT.with(|cell| cell.set(false)); output } -/// Whether the last call was a hit. -#[cfg(feature = "last_was_hit")] -pub fn last_was_hit() -> bool { - LAST_WAS_HIT.with(|cell| cell.get()) -} - /// Evict the cache. 
/// /// This removes all memoized results from the cache whose age is larger than or @@ -151,22 +83,25 @@ pub fn last_was_hit() -> bool { /// Comemo's cache is thread-local, meaning that this only evicts this thread's /// cache. pub fn evict(max_age: usize) { - for subevict in CACHES.read().iter() { + for subevict in EVICTORS.read().iter() { subevict(max_age); } - // Evict all accelerators. - let mut accelerators = ACCELERATORS.write(); + accelerate::evict(); +} - // Update the offset. - accelerators.0 = ID.load(Ordering::SeqCst); +/// Register an eviction function in the global list. +pub fn register_evictor(evict: fn(usize)) { + EVICTORS.write().push(evict); +} - // Clear all accelerators while keeping the memory allocated. - accelerators.1.iter_mut().for_each(|accelerator| { - accelerator.lock().clear(); - }) +/// Whether the last call was a hit. +#[cfg(feature = "testing")] +pub fn last_was_hit() -> bool { + LAST_WAS_HIT.with(|cell| cell.get()) } +/// A cache for a single memoized function. pub struct Cache(Lazy>>); impl Cache { @@ -179,37 +114,21 @@ impl Cache { Self(Lazy::new(init)) } - /// Write to the inner cache. - pub fn write(&self) -> RwLockWriteGuard<'_, CacheData> { - self.0.write() - } - - /// Read from the inner cache. - fn read(&self) -> RwLockReadGuard<'_, CacheData> { - self.0.read() + /// Evict all entries whose age is larger than or equal to `max_age`. + pub fn evict(&self, max_age: usize) { + self.0.write().evict(max_age) } } -/// The global cache. +/// The internal data for a cache. pub struct CacheData { /// Maps from hashes to memoized results. entries: HashMap>>, } -impl Default for CacheData { - fn default() -> Self { - Self { entries: HashMap::new() } - } -} - impl CacheData { - /// Create an empty cache. - pub fn new() -> Self { - Self::default() - } - /// Evict all entries whose age is larger than or equal to `max_age`. 
- pub fn evict(&mut self, max_age: usize) { + fn evict(&mut self, max_age: usize) { self.entries.retain(|_, entries| { entries.retain_mut(|entry| { let age = entry.age.get_mut(); @@ -244,6 +163,12 @@ impl CacheData { } } +impl Default for CacheData { + fn default() -> Self { + Self { entries: HashMap::new() } + } +} + /// A memoized result. struct CacheEntry { /// The memoized function's constraint. @@ -274,291 +199,3 @@ impl CacheEntry { }) } } - -/// A call to a tracked function. -pub trait Call { - /// Whether the call is mutable. - fn is_mutable(&self) -> bool; -} - -/// A call entry. -#[derive(Clone)] -struct ConstraintEntry { - args: T, - args_hash: u128, - ret: u128, -} - -/// Defines a constraint for a tracked type. -pub struct ImmutableConstraint(RwLock>>); - -impl ImmutableConstraint { - /// Create empty constraints. - pub fn new() -> Self { - Self::default() - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - pub fn push(&self, args: T, ret: u128) { - let args_hash = hash(&args); - self.push_inner(Cow::Owned(ConstraintEntry { args, args_hash, ret })); - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - fn push_inner(&self, call: Cow>) { - let mut calls = self.0.write(); - debug_assert!(!call.args.is_mutable()); - - if let Some(_prev) = calls.get(&call.args_hash) { - #[cfg(debug_assertions)] - check(_prev, &call); - - return; - } - - calls.insert(call.args_hash, call.into_owned()); - } - - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate(&self, mut f: F) -> bool - where - F: FnMut(&T) -> u128, - { - self.0.read().values().all(|entry| f(&entry.args) == entry.ret) - } - - /// Whether the method satisfies as all input-output pairs. 
- #[inline] - pub fn validate_with_id(&self, mut f: F, id: usize) -> bool - where - F: FnMut(&T) -> u128, - { - let accelerator = accelerator(id); - let inner = self.0.read(); - if let Some(accelerator) = accelerator { - let mut map = accelerator.lock(); - inner.values().all(|entry| { - *map.entry(entry.args_hash).or_insert_with(|| f(&entry.args)) == entry.ret - }) - } else { - inner.values().all(|entry| f(&entry.args) == entry.ret) - } - } - - /// Replay all input-output pairs. - #[inline] - pub fn replay(&self, _: F) - where - F: FnMut(&T), - { - #[cfg(debug_assertions)] - for entry in self.0.read().values() { - assert!(!entry.args.is_mutable()); - } - } -} - -impl Clone for ImmutableConstraint { - fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().clone())) - } -} - -impl Default for ImmutableConstraint { - fn default() -> Self { - Self(RwLock::new(HashMap::default())) - } -} - -/// Defines a constraint for a tracked type. -pub struct MutableConstraint(RwLock>); - -impl MutableConstraint { - /// Create empty constraints. - pub fn new() -> Self { - Self::default() - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - pub fn push(&self, args: T, ret: u128) { - let args_hash = hash(&args); - self.0 - .write() - .push_inner(Cow::Owned(ConstraintEntry { args, args_hash, ret })); - } - - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate(&self, mut f: F) -> bool - where - F: FnMut(&T) -> u128, - { - self.0.read().calls.iter().all(|entry| f(&entry.args) == entry.ret) - } - - /// Whether the method satisfies as all input-output pairs. - /// - /// On mutable tracked types, this does not use an accelerator as it is - /// rarely, if ever used. Therefore, it is not worth the overhead. 
- #[inline] - pub fn validate_with_id(&self, mut f: F, _: usize) -> bool - where - F: FnMut(&T) -> u128, - { - let inner = self.0.read(); - inner.calls.iter().all(|entry| f(&entry.args) == entry.ret) - } - - /// Replay all input-output pairs. - #[inline] - pub fn replay(&self, mut f: F) - where - F: FnMut(&T), - { - for call in self.0.read().calls.iter().filter(|call| call.args.is_mutable()) { - f(&call.args); - } - } -} - -impl Clone for MutableConstraint { - fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().clone())) - } -} - -impl Default for MutableConstraint { - fn default() -> Self { - Self(RwLock::new(Inner { calls: Vec::new() })) - } -} - -#[derive(Clone)] -struct Inner { - /// The list of calls. - /// - /// Order matters here, as those are mutable & immutable calls. - calls: Vec>, -} - -impl Inner { - /// Enter a constraint for a call to a function. - /// - /// If the function is immutable, it uses a fast-path based on a - /// `HashMap` to perform deduplication. Otherwise, it always - /// pushes the call to the list. - #[inline] - fn push_inner(&mut self, call: Cow>) { - // If the call is immutable check whether we already have a call - // with the same arguments and return value. - let mutable = call.args.is_mutable(); - if !mutable { - for entry in self.calls.iter().rev() { - if call.args.is_mutable() { - break; - } - - if call.args_hash == entry.args_hash && call.ret == entry.ret { - #[cfg(debug_assertions)] - check(&call, entry); - - return; - } - } - } - - // Insert the call into the call list. - self.calls.push(call.into_owned()); - } -} - -impl Default for Inner { - fn default() -> Self { - Self { calls: Vec::new() } - } -} - -/// Extend an outer constraint by an inner one. -pub trait Join { - /// Join this constraint with the `inner` one. - fn join(&self, inner: &T); - - /// Take out the constraint. 
- fn take(&self) -> Self; -} - -impl Join for Option<&T> { - #[inline] - fn join(&self, inner: &T) { - if let Some(outer) = self { - outer.join(inner); - } - } - - #[inline] - fn take(&self) -> Self { - unimplemented!("cannot call `Join::take` on optional constraint") - } -} - -impl Join for MutableConstraint { - #[inline] - fn join(&self, inner: &Self) { - let mut this = self.0.write(); - for call in inner.0.read().calls.iter() { - this.push_inner(Cow::Borrowed(call)); - } - } - - #[inline] - fn take(&self) -> Self { - Self(RwLock::new(std::mem::take(&mut *self.0.write()))) - } -} - -impl Join for ImmutableConstraint { - #[inline] - fn join(&self, inner: &Self) { - for call in inner.0.read().values() { - self.push_inner(Cow::Borrowed(call)); - } - } - - #[inline] - fn take(&self) -> Self { - Self(RwLock::new(std::mem::take(&mut *self.0.write()))) - } -} - -/// Produce a 128-bit hash of a value. -#[inline] -pub fn hash(value: &T) -> u128 { - let mut state = SipHasher13::new(); - value.hash(&mut state); - state.finish128().as_u128() -} - -/// Check for a constraint violation. -#[inline] -#[track_caller] -#[allow(dead_code)] -fn check(lhs: &ConstraintEntry, rhs: &ConstraintEntry) { - if lhs.ret != rhs.ret { - panic!( - "comemo: found conflicting constraints. \ - is this tracked function pure?" - ) - } - - // Additional checks for debugging. - if lhs.args_hash != rhs.args_hash || lhs.args != rhs.args { - panic!( - "comemo: found conflicting `check` arguments. \ - this is a bug in comemo" - ) - } -} diff --git a/src/constraint.rs b/src/constraint.rs new file mode 100644 index 0000000..bec98aa --- /dev/null +++ b/src/constraint.rs @@ -0,0 +1,301 @@ +use std::borrow::Cow; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::hash::Hash; + +use parking_lot::RwLock; +use siphasher::sip128::{Hasher128, SipHasher13}; + +use crate::accelerate; + +/// A call to a tracked function. 
+pub trait Call: Hash + PartialEq + Clone { + /// Whether the call is mutable. + fn is_mutable(&self) -> bool; +} + +/// A constraint entry for a single call. +#[derive(Clone)] +struct ConstraintEntry { + call: T, + call_hash: u128, + ret_hash: u128, +} + +/// Defines a constraint for an immutably tracked type. +pub struct ImmutableConstraint(RwLock>); + +impl ImmutableConstraint { + /// Create an empty constraint. + pub fn new() -> Self { + Self::default() + } + + /// Enter a constraint for a call to an immutable function. + #[inline] + pub fn push(&self, call: T, ret_hash: u128) { + let call_hash = hash(&call); + let entry = ConstraintEntry { call, call_hash, ret_hash }; + self.0.write().push_inner(Cow::Owned(entry)); + } + + /// Whether the method satisfies as all input-output pairs. + #[inline] + pub fn validate(&self, mut f: F) -> bool + where + F: FnMut(&T) -> u128, + { + self.0.read().0.values().all(|entry| f(&entry.call) == entry.ret_hash) + } + + /// Whether the method satisfies as all input-output pairs. + #[inline] + pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + where + F: FnMut(&T) -> u128, + { + let guard = self.0.read(); + if let Some(accelerator) = accelerate::get(id) { + let mut map = accelerator.lock(); + guard.0.values().all(|entry| { + *map.entry(entry.call_hash).or_insert_with(|| f(&entry.call)) + == entry.ret_hash + }) + } else { + guard.0.values().all(|entry| f(&entry.call) == entry.ret_hash) + } + } + + /// Replay all input-output pairs. + #[inline] + pub fn replay(&self, _: F) + where + F: FnMut(&T), + { + #[cfg(debug_assertions)] + for entry in self.0.read().0.values() { + assert!(!entry.call.is_mutable()); + } + } +} + +impl Clone for ImmutableConstraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().clone())) + } +} + +impl Default for ImmutableConstraint { + fn default() -> Self { + Self(RwLock::new(EntryMap::default())) + } +} + +/// Defines a constraint for a mutably tracked type. 
+pub struct MutableConstraint(RwLock>); + +impl MutableConstraint { + /// Create an empty constraint. + pub fn new() -> Self { + Self::default() + } + + /// Enter a constraint for a call to a mutable function. + #[inline] + pub fn push(&self, call: T, ret_hash: u128) { + let call_hash = hash(&call); + let entry = ConstraintEntry { call, call_hash, ret_hash }; + self.0.write().push_inner(Cow::Owned(entry)); + } + + /// Whether the method satisfies as all input-output pairs. + #[inline] + pub fn validate(&self, mut f: F) -> bool + where + F: FnMut(&T) -> u128, + { + self.0.read().0.iter().all(|entry| f(&entry.call) == entry.ret_hash) + } + + /// Whether the method satisfies as all input-output pairs. + /// + /// On mutable tracked types, this does not use an accelerator as it is + /// rarely, if ever used. Therefore, it is not worth the overhead. + #[inline] + pub fn validate_with_id(&self, f: F, _: usize) -> bool + where + F: FnMut(&T) -> u128, + { + self.validate(f) + } + + /// Replay all input-output pairs. + #[inline] + pub fn replay(&self, mut f: F) + where + F: FnMut(&T), + { + for call in &self.0.read().0 { + if call.call.is_mutable() { + f(&call.call); + } + } + } +} + +impl Clone for MutableConstraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().clone())) + } +} + +impl Default for MutableConstraint { + fn default() -> Self { + Self(RwLock::new(EntryVec::default())) + } +} + +/// A map of calls. +#[derive(Clone)] +struct EntryMap(HashMap>); + +impl EntryMap { + /// Enter a constraint for a call to a function. + #[inline] + fn push_inner(&mut self, entry: Cow>) { + match self.0.entry(entry.call_hash) { + Entry::Occupied(occupied) => { + #[cfg(debug_assertions)] + check(occupied.get(), &entry); + } + Entry::Vacant(vacant) => { + vacant.insert(entry.into_owned()); + } + } + } +} + +impl Default for EntryMap { + fn default() -> Self { + Self(HashMap::new()) + } +} + +/// A list of calls. 
+/// +/// Order matters here, as those are mutable & immutable calls. +#[derive(Clone)] +struct EntryVec(Vec>); + +impl EntryVec { + /// Enter a constraint for a call to a function. + #[inline] + fn push_inner(&mut self, entry: Cow>) { + // If the call is immutable check whether we already have a call + // with the same arguments and return value. + if !entry.call.is_mutable() { + for prev in self.0.iter().rev() { + if entry.call.is_mutable() { + break; + } + + if entry.call_hash == prev.call_hash && entry.ret_hash == prev.ret_hash { + #[cfg(debug_assertions)] + check(&entry, prev); + return; + } + } + } + + // Insert the call into the call list. + self.0.push(entry.into_owned()); + } +} + +impl Default for EntryVec { + fn default() -> Self { + Self(Vec::new()) + } +} + +/// Extend an outer constraint by an inner one. +pub trait Join { + /// Join this constraint with the `inner` one. + fn join(&self, inner: &T); + + /// Take out the constraint. + fn take(&self) -> Self; +} + +impl Join for Option<&T> { + #[inline] + fn join(&self, inner: &T) { + if let Some(outer) = self { + outer.join(inner); + } + } + + #[inline] + fn take(&self) -> Self { + unimplemented!("cannot call `Join::take` on optional constraint") + } +} + +impl Join for ImmutableConstraint { + #[inline] + fn join(&self, inner: &Self) { + let mut this = self.0.write(); + for call in inner.0.read().0.values() { + this.push_inner(Cow::Borrowed(call)); + } + } + + #[inline] + fn take(&self) -> Self { + Self(RwLock::new(std::mem::take(&mut *self.0.write()))) + } +} + +impl Join for MutableConstraint { + #[inline] + fn join(&self, inner: &Self) { + let mut this = self.0.write(); + for call in inner.0.read().0.iter() { + this.push_inner(Cow::Borrowed(call)); + } + } + + #[inline] + fn take(&self) -> Self { + Self(RwLock::new(std::mem::take(&mut *self.0.write()))) + } +} + +/// Produce a 128-bit hash of a value. 
+#[inline] +pub fn hash(value: &T) -> u128 { + let mut state = SipHasher13::new(); + value.hash(&mut state); + state.finish128().as_u128() +} + +/// Check for a constraint violation. +#[inline] +#[track_caller] +#[allow(dead_code)] +fn check(lhs: &ConstraintEntry, rhs: &ConstraintEntry) { + if lhs.ret_hash != rhs.ret_hash { + panic!( + "comemo: found conflicting constraints. \ + is this tracked function pure?" + ) + } + + // Additional checks for debugging. + if lhs.call_hash != rhs.call_hash || lhs.call != rhs.call { + panic!( + "comemo: found conflicting `check` arguments. \ + this is a bug in comemo" + ) + } +} diff --git a/src/input.rs b/src/input.rs index d4185ea..a535ef2 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,6 +1,6 @@ use std::hash::{Hash, Hasher}; -use crate::cache::Join; +use crate::constraint::Join; use crate::track::{Track, Tracked, TrackedMut, Validate}; /// Ensure a type is suitable as input. diff --git a/src/lib.rs b/src/lib.rs index 37e8237..5921b7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,7 +82,9 @@ For the full example see [`examples/calc.rs`][calc]. 
[calc]: https://github.com/typst/comemo/blob/main/examples/calc.rs */ +mod accelerate; mod cache; +mod constraint; mod input; mod prehashed; mod track; @@ -97,14 +99,11 @@ pub use comemo_macros::{memoize, track}; pub mod internal { pub use parking_lot::RwLock; - pub use crate::cache::{ - hash, memoized, register_cache, Accelerator, Cache, CacheData, Call, - ImmutableConstraint, MutableConstraint, - }; - + pub use crate::cache::{memoized, register_evictor, Cache, CacheData}; + pub use crate::constraint::{hash, Call, ImmutableConstraint, MutableConstraint}; pub use crate::input::{assert_hashable_or_trackable, Args, Input}; pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces}; - #[cfg(feature = "last_was_hit")] + #[cfg(feature = "testing")] pub use crate::cache::last_was_hit; } diff --git a/src/track.rs b/src/track.rs index e707057..8c7e340 100644 --- a/src/track.rs +++ b/src/track.rs @@ -1,7 +1,8 @@ use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; -use crate::cache::{id, Join}; +use crate::accelerate; +use crate::constraint::Join; /// A trackable type. /// @@ -12,7 +13,11 @@ pub trait Track: Validate + Surfaces { /// Start tracking all accesses to a value. #[inline] fn track(&self) -> Tracked { - Tracked { value: self, constraint: None, id: id() } + Tracked { + value: self, + constraint: None, + id: accelerate::id(), + } } /// Start tracking all accesses and mutations to a value. 
@@ -27,7 +32,7 @@ pub trait Track: Validate + Surfaces { Tracked { value: self, constraint: Some(constraint), - id: id(), + id: accelerate::id(), } } @@ -227,7 +232,7 @@ where Tracked { value: this.value, constraint: this.constraint, - id: id(), + id: accelerate::id(), } } @@ -240,7 +245,7 @@ where Tracked { value: this.value, constraint: this.constraint, - id: id(), + id: accelerate::id(), } } From f1ed1e520516aff440bfb8870f46d2179956472b Mon Sep 17 00:00:00 2001 From: Laurenz Date: Fri, 15 Dec 2023 00:11:23 +0100 Subject: [PATCH 27/28] Minor style changes --- macros/src/memoize.rs | 2 +- macros/src/track.rs | 2 +- src/constraint.rs | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index d6c0b10..09afbce 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -130,7 +130,7 @@ fn process(function: &Function) -> Result { #output, > = ::comemo::internal::Cache::new(|| { ::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age)); - ::std::default::Default::default() + ::core::default::Default::default() }); #(#bounds;)* diff --git a/macros/src/track.rs b/macros/src/track.rs index e0db944..bfc90be 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -213,7 +213,6 @@ fn create_variants(methods: &[Method]) -> TokenStream { enum __ComemoVariant { #(#variants,)* } - } } @@ -291,6 +290,7 @@ fn create( } else { quote! { MutableConstraint } }; + Ok(quote! 
{ impl #impl_params ::comemo::Track for #ty #where_clause {} diff --git a/src/constraint.rs b/src/constraint.rs index bec98aa..6af0732 100644 --- a/src/constraint.rs +++ b/src/constraint.rs @@ -135,9 +135,9 @@ impl MutableConstraint { where F: FnMut(&T), { - for call in &self.0.read().0 { - if call.call.is_mutable() { - f(&call.call); + for entry in &self.0.read().0 { + if entry.call.is_mutable() { + f(&entry.call); } } } @@ -245,8 +245,8 @@ impl Join for ImmutableConstraint { #[inline] fn join(&self, inner: &Self) { let mut this = self.0.write(); - for call in inner.0.read().0.values() { - this.push_inner(Cow::Borrowed(call)); + for entry in inner.0.read().0.values() { + this.push_inner(Cow::Borrowed(entry)); } } @@ -260,8 +260,8 @@ impl Join for MutableConstraint { #[inline] fn join(&self, inner: &Self) { let mut this = self.0.write(); - for call in inner.0.read().0.iter() { - this.push_inner(Cow::Borrowed(call)); + for entry in inner.0.read().0.iter() { + this.push_inner(Cow::Borrowed(entry)); } } From 185ca2279cdfabfcb8183e2d2acd79af474ca337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun=20=28Dherse=29?= Date: Fri, 15 Dec 2023 00:53:33 +0100 Subject: [PATCH 28/28] Fix warning in `--release` --- src/constraint.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/constraint.rs b/src/constraint.rs index 6af0732..8e34878 100644 --- a/src/constraint.rs +++ b/src/constraint.rs @@ -164,9 +164,9 @@ impl EntryMap { #[inline] fn push_inner(&mut self, entry: Cow>) { match self.0.entry(entry.call_hash) { - Entry::Occupied(occupied) => { + Entry::Occupied(_occupied) => { #[cfg(debug_assertions)] - check(occupied.get(), &entry); + check(_occupied.get(), &entry); } Entry::Vacant(vacant) => { vacant.insert(entry.into_owned());