Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel comemo & optimizations #5

Merged
merged 28 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo build
- run: cargo test
- run: cargo build --all-features
- run: cargo test --all-features
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.vscode
.DS_Store
/target
macros/target
Cargo.lock
14 changes: 14 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ license = "MIT OR Apache-2.0"
categories = ["caching"]
keywords = ["incremental", "memoization", "tracked", "constraints"]

[features]
default = []
testing = []

[dependencies]
comemo-macros = { version = "0.3.1", path = "macros" }
once_cell = "1.18"
parking_lot = "0.12"
siphasher = "1"

[dev-dependencies]
serial_test = "2.0.0"

[[test]]
name = "tests"
path = "tests/tests.rs"
required-features = ["testing"]
30 changes: 21 additions & 9 deletions macros/src/memoize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
};

// Preprocess and validate the function.
let function = prepare(&item)?;
let function = prepare(item)?;

// Rewrite the function's body to memoize it.
process(&function)
Expand All @@ -23,7 +23,7 @@ struct Function {
/// An argument to a memoized function.
enum Argument {
Receiver(syn::Token![self]),
Ident(Option<syn::Token![mut]>, syn::Ident),
Ident(Box<syn::Type>, Option<syn::Token![mut]>, syn::Ident),
}

/// Preprocess and validate a function.
Expand Down Expand Up @@ -71,7 +71,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result<Argument> {
bail!(typed.ty, "memoized functions cannot have mutable parameters")
}

Argument::Ident(mutability.clone(), ident.clone())
Argument::Ident(typed.ty.clone(), *mutability, ident.clone())
}
})
}
Expand All @@ -82,7 +82,7 @@ fn process(function: &Function) -> Result<TokenStream> {
let bounds = function.args.iter().map(|arg| {
let val = match arg {
Argument::Receiver(token) => quote! { #token },
Argument::Ident(_, ident) => quote! { #ident },
Argument::Ident(_, _, ident) => quote! { #ident },
};
quote_spanned! { function.item.span() =>
::comemo::internal::assert_hashable_or_trackable(&#val);
Expand All @@ -94,14 +94,20 @@ fn process(function: &Function) -> Result<TokenStream> {
Argument::Receiver(token) => quote! {
::comemo::internal::hash(&#token)
},
Argument::Ident(_, ident) => quote! { #ident },
Argument::Ident(_, _, ident) => quote! { #ident },
});
let arg_tuple = quote! { (#(#args,)*) };

let arg_tys = function.args.iter().map(|arg| match arg {
Argument::Receiver(_) => quote! { () },
Argument::Ident(ty, _, _) => quote! { #ty },
});
let arg_ty_tuple = quote! { (#(#arg_tys,)*) };

// Construct a tuple for all parameters.
let params = function.args.iter().map(|arg| match arg {
Argument::Receiver(_) => quote! { _ },
Argument::Ident(mutability, ident) => quote! { #mutability #ident },
Argument::Ident(_, mutability, ident) => quote! { #mutability #ident },
});
let param_tuple = quote! { (#(#params,)*) };

Expand All @@ -118,14 +124,20 @@ fn process(function: &Function) -> Result<TokenStream> {
ident.mutability = None;
}

let unique = quote! { __ComemoUnique };
wrapped.block = parse_quote! { {
struct #unique;
static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
::core::default::Default::default()
});

#(#bounds;)*
::comemo::internal::memoized(
::core::any::TypeId::of::<#unique>(),
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&__CACHE,
#closure,
)
} };
Expand Down
83 changes: 59 additions & 24 deletions macros/src/track.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,38 @@ pub fn expand(item: &syn::Item) -> Result<TokenStream> {
}

for item in &item.items {
methods.push(prepare_impl_method(&item)?);
methods.push(prepare_impl_method(item)?);
}

let ty = item.self_ty.as_ref().clone();
(ty, &item.generics, None)
}
syn::Item::Trait(item) => {
for param in item.generics.params.iter() {
bail!(param, "tracked traits cannot be generic")
if let Some(first) = item.generics.params.first() {
bail!(first, "tracked traits cannot be generic")
}

for item in &item.items {
methods.push(prepare_trait_method(&item)?);
methods.push(prepare_trait_method(item)?);
}

let name = &item.ident;
let ty = parse_quote! { dyn #name + '__comemo_dynamic };
(ty, &item.generics, Some(name.clone()))
(ty, &item.generics, Some(item.ident.clone()))
}
_ => bail!(item, "`track` can only be applied to impl blocks and traits"),
};

// Produce the necessary items for the type to become trackable.
let variants = create_variants(&methods);
let scope = create(&ty, generics, trait_, &methods)?;

Ok(quote! {
#item
const _: () = { #scope };
const _: () = {
#variants
#scope
};
})
}

Expand Down Expand Up @@ -175,6 +179,43 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result<Method>
})
}

/// Produces the variants for the constraint.
fn create_variants(methods: &[Method]) -> TokenStream {
let variants = methods.iter().map(create_variant);
let is_mutable_variants = methods.iter().map(|m| {
let name = &m.sig.ident;
let mutable = m.mutable;
quote! { __ComemoVariant::#name(..) => #mutable }
});

let is_mutable = (!methods.is_empty())
.then(|| {
quote! {
match &self.0 {
#(#is_mutable_variants),*
}
}
})
.unwrap_or_else(|| quote! { false });

quote! {
#[derive(Clone, PartialEq, Hash)]
pub struct __ComemoCall(__ComemoVariant);

impl ::comemo::internal::Call for __ComemoCall {
fn is_mutable(&self) -> bool {
#is_mutable
}
}

#[derive(Clone, PartialEq, Hash)]
#[allow(non_camel_case_types)]
enum __ComemoVariant {
#(#variants,)*
}
}
}

/// Produce the necessary items for a type to become trackable.
fn create(
ty: &syn::Type,
Expand Down Expand Up @@ -229,26 +270,32 @@ fn create(
};

// Prepare replying.
let immutable = methods.iter().all(|m| !m.mutable);
let replays = methods.iter().map(create_replay);
let replay = methods.iter().any(|m| m.mutable).then(|| {
let replay = (!immutable).then(|| {
quote! {
constraint.replay(|call| match &call.0 { #(#replays,)* });
}
});

// Prepare variants and wrapper methods.
let variants = methods.iter().map(create_variant);
let wrapper_methods = methods
.iter()
.filter(|m| !m.mutable)
.map(|m| create_wrapper(m, false));
let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true));

let constraint = if immutable {
quote! { ImmutableConstraint }
} else {
quote! { MutableConstraint }
};

Ok(quote! {
impl #impl_params ::comemo::Track for #ty #where_clause {}
impl #impl_params ::comemo::Track for #ty #where_clause {}

impl #impl_params ::comemo::Validate for #ty #where_clause {
type Constraint = ::comemo::internal::Constraint<__ComemoCall>;
impl #impl_params ::comemo::Validate for #ty #where_clause {
type Constraint = ::comemo::internal::#constraint<__ComemoCall>;

#[inline]
fn validate(&self, constraint: &Self::Constraint) -> bool {
Expand All @@ -267,15 +314,6 @@ fn create(
}
}

#[derive(Clone, PartialEq, Hash)]
pub struct __ComemoCall(__ComemoVariant);

#[derive(Clone, PartialEq, Hash)]
#[allow(non_camel_case_types)]
enum __ComemoVariant {
#(#variants,)*
}

#[doc(hidden)]
impl #impl_params ::comemo::internal::Surfaces for #ty #where_clause {
type Surface<#t> = __ComemoSurface #type_params_t where Self: #t;
Expand Down Expand Up @@ -323,7 +361,6 @@ fn create(
impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t {
#(#wrapper_methods_mut)*
}

})
}

Expand Down Expand Up @@ -370,10 +407,9 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
let vis = &method.vis;
let sig = &method.sig;
let args = &method.args;
let mutable = method.mutable;
let to_parts = if !tracked_mut {
quote! { to_parts_ref(self.0) }
} else if !mutable {
} else if !method.mutable {
quote! { to_parts_mut_ref(&self.0) }
} else {
quote! { to_parts_mut_mut(&mut self.0) }
Expand All @@ -389,7 +425,6 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
constraint.push(
__ComemoCall(__comemo_variant),
::comemo::internal::hash(&output),
#mutable,
);
}
output
Expand Down
63 changes: 63 additions & 0 deletions src/accelerate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard};

/// The global list of currently alive accelerators.
static ACCELERATORS: RwLock<(usize, Vec<Accelerator>)> = RwLock::new((0, Vec::new()));

/// The current ID of the accelerator.
static ID: AtomicUsize = AtomicUsize::new(0);

/// The type of each individual accelerator.
///
/// Maps from call hashes to return hashes.
type Accelerator = Mutex<HashMap<u128, u128>>;

/// Generate a new accelerator.
pub fn id() -> usize {
// Get the next ID.
ID.fetch_add(1, Ordering::SeqCst)
}

/// Evict the accelerators.
pub fn evict() {
let mut accelerators = ACCELERATORS.write();
let (offset, vec) = &mut *accelerators;

// Update the offset.
*offset = ID.load(Ordering::SeqCst);

// Clear all accelerators while keeping the memory allocated.
vec.iter_mut().for_each(|accelerator| accelerator.lock().clear())
}

/// Get an accelerator by ID.
pub fn get(id: usize) -> Option<MappedRwLockReadGuard<'static, Accelerator>> {
// We always lock the accelerators, as we need to make sure that the
// accelerator is not removed while we are reading it.
let mut accelerators = ACCELERATORS.read();

let mut i = id.checked_sub(accelerators.0)?;
if i >= accelerators.1.len() {
drop(accelerators);
resize(i + 1);
accelerators = ACCELERATORS.read();

// Because we release the lock before resizing the accelerator, we need
// to check again whether the ID is still valid because another thread
// might evicted the cache.
i = id.checked_sub(accelerators.0)?;
}

Some(RwLockReadGuard::map(accelerators, move |(_, vec)| &vec[i]))
}

/// Adjusts the amount of accelerators.
#[cold]
fn resize(len: usize) {
let mut pair = ACCELERATORS.write();
if len > pair.1.len() {
pair.1.resize_with(len, || Mutex::new(HashMap::new()));
}
}
Loading