From bcda660e17d8e1ae63408bdc60e97c8abbbaddb4 Mon Sep 17 00:00:00 2001 From: Laurenz Date: Wed, 27 Sep 2023 14:17:26 +0200 Subject: [PATCH] Make thread safe --- Cargo.toml | 3 + macros/src/track.rs | 2 +- src/cache.rs | 331 ++++++++------------------------------------ src/constraint.rs | 161 +++++++++++++++++++++ src/input.rs | 4 +- src/lib.rs | 29 +++- src/track.rs | 10 +- tests/tests.rs | 6 +- 8 files changed, 260 insertions(+), 286 deletions(-) create mode 100644 src/constraint.rs diff --git a/Cargo.toml b/Cargo.toml index 2bad481..d52ed6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,6 @@ keywords = ["incremental", "memoization", "tracked", "constraints"] [dependencies] comemo-macros = { version = "0.3.1", path = "macros" } siphasher = "1" +dashmap = "5" +once_cell = "1" +parking_lot = "0.12" diff --git a/macros/src/track.rs b/macros/src/track.rs index 60e80ec..75d5082 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -36,7 +36,7 @@ pub fn expand(item: &syn::Item) -> Result { } let name = &item.ident; - let ty = parse_quote! { dyn #name + '__comemo_dynamic }; + let ty = parse_quote! { dyn #name + Send + Sync + '__comemo_dynamic }; (ty, &item.generics, Some(name.clone())) } _ => bail!(item, "`track` can only be applied to impl blocks and traits"), diff --git a/src/cache.rs b/src/cache.rs index df7e702..ba472c2 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,26 +1,11 @@ use std::any::{Any, TypeId}; -use std::cell::{Cell, RefCell}; -use std::collections::HashMap; -use std::hash::Hash; +use std::sync::atomic::{AtomicUsize, Ordering}; use siphasher::sip128::{Hasher128, SipHasher13}; +use crate::constraint::Join; use crate::input::Input; -thread_local! { - /// The global, dynamic cache shared by all memoized functions. - static CACHE: RefCell = RefCell::new(Cache::default()); - - /// The global ID counter for tracked values. Each tracked value gets a - /// unqiue ID based on which its validations are cached in the accelerator. 
- /// IDs may only be reused upon eviction of the accelerator. - static ID: Cell = const { Cell::new(0) }; - - /// The global, dynamic accelerator shared by all cached values. - static ACCELERATOR: RefCell> - = RefCell::new(HashMap::default()); -} - /// Execute a function or use a cached result for it. pub fn memoized<'c, In, Out, F>( id: TypeId, @@ -30,64 +15,56 @@ pub fn memoized<'c, In, Out, F>( ) -> Out where In: Input + 'c, - Out: Clone + 'static, + Out: Clone + Send + Sync + 'static, F: FnOnce(In::Tracked<'c>) -> Out, { - CACHE.with(|cache| { - // Compute the hash of the input's key part. - let key = { - let mut state = SipHasher13::new(); - input.key(&mut state); - let hash = state.finish128().as_u128(); - (id, hash) - }; - - // Check if there is a cached output. - let mut borrow = cache.borrow_mut(); - if let Some(constrained) = borrow.lookup::(key, &input) { - // Replay the mutations. - input.replay(&constrained.constraint); - - // Add the cached constraints to the outer ones. - input.retrack(constraint).1.join(&constrained.constraint); + // Compute the hash of the input's key part. + let key = { + let mut state = SipHasher13::new(); + input.key(&mut state); + let hash = state.finish128().as_u128(); + (id, hash) + }; + + // Check if there is a cached output. + if let Some(constrained) = crate::CACHE.get(&key).and_then(|entries| { + entries + .try_map(|value| { + value.iter().rev().find_map(|entry| entry.lookup::(&input)) + }) + .ok() + }) { + // Replay the mutations. + input.replay(&constrained.constraint); - let value = constrained.output.clone(); - borrow.last_was_hit = true; - return value; - } + // Add the cached constraints to the outer ones. + input.retrack(constraint).1.join(&constrained.constraint); - // Release the borrow so that nested memoized calls can access the - // cache without panicking. 
- drop(borrow); + let value = constrained.output.clone(); + crate::LAST_WAS_HIT.with(|hit| hit.set(true)); + return value; + } - // Execute the function with the new constraints hooked in. - let (input, outer) = input.retrack(constraint); - let output = func(input); + // Execute the function with the new constraints hooked in. + let (input, outer) = input.retrack(constraint); + let output = func(input); - // Add the new constraints to the outer ones. - outer.join(constraint); + // Add the new constraints to the outer ones. + outer.join(constraint); - // Insert the result into the cache. - borrow = cache.borrow_mut(); - borrow.insert::(key, constraint.take(), output.clone()); - borrow.last_was_hit = false; + // Insert the result into the cache. + crate::LAST_WAS_HIT.with(|cell| cell.set(false)); + crate::CACHE + .entry(key) + .or_default() + .push(CacheEntry::new::(constraint.take(), output.clone())); - output - }) + output } /// Whether the last call was a hit. pub fn last_was_hit() -> bool { - CACHE.with(|cache| cache.borrow().last_was_hit) -} - -/// Get the next ID. -pub fn id() -> usize { - ID.with(|cell| { - let current = cell.get(); - cell.set(current.wrapping_add(1)); - current - }) + crate::LAST_WAS_HIT.with(|cell| cell.get()) } /// Evict the cache. @@ -100,71 +77,25 @@ pub fn id() -> usize { /// Comemo's cache is thread-local, meaning that this only evicts this thread's /// cache. pub fn evict(max_age: usize) { - CACHE.with(|cache| { - let mut cache = cache.borrow_mut(); - cache.map.retain(|_, entries| { - entries.retain_mut(|entry| { - entry.age += 1; - entry.age <= max_age - }); - !entries.is_empty() + crate::CACHE.retain(|_, entries| { + entries.retain_mut(|entry| { + let age = entry.age.fetch_add(1, Ordering::Relaxed); + age < max_age }); + !entries.is_empty() }); - ACCELERATOR.with(|accelerator| accelerator.borrow_mut().clear()); -} - -/// The global cache. -#[derive(Default)] -struct Cache { - /// Maps from function IDs + hashes to memoized results. 
- map: HashMap<(TypeId, u128), Vec>, - /// Whether the last call was a hit. - last_was_hit: bool, -} - -impl Cache { - /// Look for a matching entry in the cache. - fn lookup( - &mut self, - key: (TypeId, u128), - input: &In, - ) -> Option<&Constrained> - where - In: Input, - Out: Clone + 'static, - { - self.map - .get_mut(&key)? - .iter_mut() - .rev() - .find_map(|entry| entry.lookup::(input)) - } - - /// Insert an entry into the cache. - fn insert( - &mut self, - key: (TypeId, u128), - constraint: In::Constraint, - output: Out, - ) where - In: Input, - Out: 'static, - { - self.map - .entry(key) - .or_default() - .push(CacheEntry::new::(constraint, output)); - } + crate::ACCELERATOR.clear(); + crate::ID.store(0, Ordering::SeqCst); } /// A memoized result. -struct CacheEntry { +pub struct CacheEntry { /// The memoized function's constrained output. /// /// This is of type `Constrained`. - constrained: Box, + constrained: Box, /// How many evictions have passed since the entry has been last used. - age: usize, + age: AtomicUsize, } /// A value with a constraint. @@ -180,16 +111,16 @@ impl CacheEntry { fn new(constraint: In::Constraint, output: Out) -> Self where In: Input, - Out: 'static, + Out: Send + Sync + 'static, { Self { constrained: Box::new(Constrained { constraint, output }), - age: 0, + age: AtomicUsize::new(0), } } /// Return the entry's output if it is valid for the given input. - fn lookup(&mut self, input: &In) -> Option<&Constrained> + fn lookup(&self, input: &In) -> Option<&Constrained> where In: Input, Out: Clone + 'static, @@ -198,162 +129,8 @@ impl CacheEntry { self.constrained.downcast_ref().expect("wrong entry type"); input.validate(&constrained.constraint).then(|| { - self.age = 0; + self.age.store(0, Ordering::Relaxed); constrained }) } } - -/// Defines a constraint for a tracked type. -#[derive(Clone)] -pub struct Constraint(RefCell>>); - -/// A call entry. 
-#[derive(Clone)] -struct Call { - args: T, - ret: u128, - both: u128, - mutable: bool, -} - -impl Constraint { - /// Create empty constraints. - pub fn new() -> Self { - Self::default() - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - pub fn push(&self, args: T, ret: u128, mutable: bool) { - let both = hash(&(&args, ret)); - self.push_inner(Call { args, ret, both, mutable }); - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - fn push_inner(&self, call: Call) { - let mut calls = self.0.borrow_mut(); - - if !call.mutable { - for prev in calls.iter().rev() { - if prev.mutable { - break; - } - - #[cfg(debug_assertions)] - if prev.args == call.args { - check(prev.ret, call.ret); - } - - if prev.both == call.both { - return; - } - } - } - - calls.push(call); - } - - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate(&self, mut f: F) -> bool - where - F: FnMut(&T) -> u128, - { - self.0.borrow().iter().all(|entry| f(&entry.args) == entry.ret) - } - - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate_with_id(&self, mut f: F, id: usize) -> bool - where - F: FnMut(&T) -> u128, - { - let calls = self.0.borrow(); - ACCELERATOR.with(|accelerator| { - let mut map = accelerator.borrow_mut(); - calls.iter().all(|entry| { - *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) - == entry.ret - }) - }) - } - - /// Replay all input-output pairs. - #[inline] - pub fn replay(&self, mut f: F) - where - F: FnMut(&T), - { - for entry in self.0.borrow().iter() { - if entry.mutable { - f(&entry.args); - } - } - } -} - -impl Default for Constraint { - fn default() -> Self { - Self(RefCell::new(vec![])) - } -} - -/// Extend an outer constraint by an inner one. -pub trait Join { - /// Join this constraint with the `inner` one. - fn join(&self, inner: &T); - - /// Take out the constraint. 
- fn take(&self) -> Self; -} - -impl Join for Option<&T> { - #[inline] - fn join(&self, inner: &T) { - if let Some(outer) = self { - outer.join(inner); - } - } - - #[inline] - fn take(&self) -> Self { - unimplemented!("cannot call `Join::take` on optional constraint") - } -} - -impl Join for Constraint { - #[inline] - fn join(&self, inner: &Self) { - for call in inner.0.borrow().iter() { - self.push_inner(call.clone()); - } - } - - #[inline] - fn take(&self) -> Self { - Self(RefCell::new(std::mem::take(&mut *self.0.borrow_mut()))) - } -} - -/// Produce a 128-bit hash of a value. -#[inline] -pub fn hash(value: &T) -> u128 { - let mut state = SipHasher13::new(); - value.hash(&mut state); - state.finish128().as_u128() -} - -/// Check for a constraint violation. -#[inline] -#[track_caller] -#[allow(dead_code)] -fn check(left_hash: u128, right_hash: u128) { - if left_hash != right_hash { - panic!( - "comemo: found conflicting constraints. \ - is this tracked function pure?" - ) - } -} diff --git a/src/constraint.rs b/src/constraint.rs new file mode 100644 index 0000000..6854e53 --- /dev/null +++ b/src/constraint.rs @@ -0,0 +1,161 @@ +use std::hash::Hash; + +use parking_lot::{RwLock, RwLockUpgradableReadGuard}; +use siphasher::sip128::{Hasher128, SipHasher13}; + +/// Defines a constraint for a tracked type. +pub struct Constraint(RwLock>>); + +/// A call entry. +#[derive(Clone)] +struct Call { + args: T, + ret: u128, + both: u128, + mutable: bool, +} + +impl Constraint { + /// Create empty constraints. + pub fn new() -> Self { + Self::default() + } + + /// Enter a constraint for a call to an immutable function. + #[inline] + pub fn push(&self, args: T, ret: u128, mutable: bool) { + let both = hash(&(&args, ret)); + self.push_inner(Call { args, ret, both, mutable }); + } + + /// Enter a constraint for a call to an immutable function. 
+ #[inline] + fn push_inner(&self, call: Call) { + let calls = self.0.upgradable_read(); + + if !call.mutable { + for prev in calls.iter().rev() { + if prev.mutable { + break; + } + + #[cfg(debug_assertions)] + if prev.args == call.args { + check(prev.ret, call.ret); + } + + if prev.both == call.both { + return; + } + } + } + + RwLockUpgradableReadGuard::upgrade(calls).push(call); + } + + /// Whether the method satisfies all input-output pairs. + #[inline] + pub fn validate(&self, mut f: F) -> bool + where + F: FnMut(&T) -> u128, + { + self.0.read().iter().all(|entry| f(&entry.args) == entry.ret) + } + + /// Whether the method satisfies all input-output pairs. + #[inline] + pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + where + F: FnMut(&T) -> u128, + { + self.0.read().iter().all(|entry| { + *crate::ACCELERATOR + .entry((id, entry.both)) + .or_insert_with(|| f(&entry.args)) + == entry.ret + }) + } + + /// Replay all input-output pairs. + #[inline] + pub fn replay(&self, mut f: F) + where + F: FnMut(&T), + { + for entry in self.0.read().iter() { + if entry.mutable { + f(&entry.args); + } + } + } +} + +impl Default for Constraint { + fn default() -> Self { + Self(RwLock::new(vec![])) + } +} + +impl Clone for Constraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().clone())) + } +} + +/// Extend an outer constraint by an inner one. +pub trait Join { + /// Join this constraint with the `inner` one. + fn join(&self, inner: &T); + + /// Take out the constraint. 
+ fn take(&self) -> Self; +} + +impl Join for Option<&T> { + #[inline] + fn join(&self, inner: &T) { + if let Some(outer) = self { + outer.join(inner); + } + } + + #[inline] + fn take(&self) -> Self { + unimplemented!("cannot call `Join::take` on optional constraint") + } +} + +impl Join for Constraint { + #[inline] + fn join(&self, inner: &Self) { + for call in inner.0.read().iter() { + self.push_inner(call.clone()); + } + } + + #[inline] + fn take(&self) -> Self { + Self(RwLock::new(std::mem::take(&mut *self.0.write()))) + } +} + +/// Produce a 128-bit hash of a value. +#[inline] +pub fn hash(value: &T) -> u128 { + let mut state = SipHasher13::new(); + value.hash(&mut state); + state.finish128().as_u128() +} + +/// Check for a constraint violation. +#[inline] +#[track_caller] +#[allow(dead_code)] +fn check(left_hash: u128, right_hash: u128) { + if left_hash != right_hash { + panic!( + "comemo: found conflicting constraints. \ + is this tracked function pure?" + ) + } +} diff --git a/src/input.rs b/src/input.rs index d4185ea..7d8f91a 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,6 +1,6 @@ use std::hash::{Hash, Hasher}; -use crate::cache::Join; +use crate::constraint::Join; use crate::track::{Track, Tracked, TrackedMut, Validate}; /// Ensure a type is suitable as input. @@ -13,7 +13,7 @@ pub fn assert_hashable_or_trackable(_: &In) {} /// types containing tuples up to length twelve. pub trait Input { /// The constraints for this input. - type Constraint: Default + Clone + Join + 'static; + type Constraint: Default + Clone + Join + Send + Sync + 'static; /// The input with new constraints hooked in. type Tracked<'r> diff --git a/src/lib.rs b/src/lib.rs index 9c5b77b..bd76535 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,7 +82,24 @@ For the full example see [`examples/calc.rs`][calc]. [calc]: https://github.com/typst/comemo/blob/main/examples/calc.rs */ +/// The global, dynamic cache shared by all memoized functions. 
+static CACHE: Lazy>> = Lazy::new(DashMap::new); + +/// The global, dynamic accelerator shared by all cached values. +static ACCELERATOR: Lazy> = Lazy::new(DashMap::new); + +/// The global ID counter for tracked values. Each tracked value gets a +/// unique ID based on which its validations are cached in the accelerator. +/// IDs may only be reused upon eviction of the accelerator. +static ID: AtomicUsize = AtomicUsize::new(0); + +thread_local! { + /// Whether the last call on this thread was a hit. + static LAST_WAS_HIT: Cell = Cell::new(false); +} + mod cache; +mod constraint; mod input; mod prehashed; mod track; @@ -92,10 +109,20 @@ pub use crate::prehashed::Prehashed; pub use crate::track::{Track, Tracked, TrackedMut, Validate}; pub use comemo_macros::{memoize, track}; +use std::any::TypeId; +use std::cell::Cell; +use std::sync::atomic::AtomicUsize; + +use dashmap::DashMap; +use once_cell::sync::Lazy; + +use self::cache::CacheEntry; + /// These are implementation details. Do not rely on them! #[doc(hidden)] pub mod internal { - pub use crate::cache::{hash, last_was_hit, memoized, Constraint}; + pub use crate::cache::{last_was_hit, memoized}; + pub use crate::constraint::{hash, Constraint}; pub use crate::input::{assert_hashable_or_trackable, Args}; pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces}; } diff --git a/src/track.rs b/src/track.rs index e707057..a54a34c 100644 --- a/src/track.rs +++ b/src/track.rs @@ -1,7 +1,8 @@ use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; +use std::sync::atomic::Ordering; -use crate::cache::{id, Join}; +use crate::constraint::Join; /// A trackable type. /// @@ -50,7 +51,7 @@ pub trait Track: Validate + Surfaces { /// This trait is implemented by the `#[track]` macro alongside [`Track`]. pub trait Validate { /// The constraints for this type. 
- type Constraint: Default + Clone + Join + 'static; + type Constraint: Default + Clone + Join + Send + Sync + 'static; /// Whether this value fulfills the given constraints. /// @@ -316,3 +317,8 @@ where { (tracked.value, tracked.constraint) } + +/// Get the next ID. +fn id() -> usize { + crate::ID.fetch_add(1, Ordering::Relaxed) +} diff --git a/tests/tests.rs b/tests/tests.rs index 9674272..b0ff84d 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -143,15 +143,15 @@ fn test_evict() { #[test] fn test_tracked_trait() { #[memoize] - fn traity(loader: Tracked, path: &Path) -> Vec { + fn traity(loader: Tracked, path: &Path) -> Vec { loader.load(path).unwrap() } - fn wrapper(loader: &dyn Loader, path: &Path) -> Vec { + fn wrapper(loader: &(dyn Loader + Send + Sync), path: &Path) -> Vec { traity(loader.track(), path) } - let loader: &dyn Loader = &StaticLoader; + let loader: &(dyn Loader + Send + Sync) = &StaticLoader; test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(miss: traity(loader.track(), Path::new("bye.rs")), [1, 2, 3]);