Skip to content

Commit

Permalink
Stored u64s directly in order_heap_data
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Mar 19, 2024
1 parent 9972854 commit 288548b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 123 deletions.
69 changes: 32 additions & 37 deletions src/batsat/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use {
self, lbool, CRef, ClauseAllocator, ClauseRef, DeletePred, LSet, Lit, OccLists,
OccListsData, VMap, Var,
},
crate::heap::{Comparator, Heap, HeapData, MemoComparator},
crate::heap::{CachedKeyComparator, Heap, HeapData},
crate::interface::SolverInterface,
crate::theory::Theory,
std::{cmp, fmt, mem},
Expand Down Expand Up @@ -67,7 +67,7 @@ struct VarState {
/// A heuristic measurement of the activity of a variable.
activity: VMap<f32>,
/// A priority queue of variables ordered with respect to the variable activity.
order_heap_data: HeapData<Var, f32>,
order_heap_data: HeapData<Var, VarOrderKey>,
/// Current assignment for each variable.
ass: VMap<lbool>,
/// Stores reason and level for each variable.
Expand Down Expand Up @@ -1124,7 +1124,7 @@ impl SolverV {
self.vars.value_lit(x)
}

fn order_heap(&mut self) -> Heap<Var, f32, VarOrder> {
fn order_heap(&mut self) -> Heap<Var, VarOrder> {
self.vars.order_heap()
}

Expand Down Expand Up @@ -1178,7 +1178,7 @@ impl SolverV {
&mut self.opts.random_seed,
self.vars.order_heap_data.len() as i32,
) as usize;
next = self.vars.order_heap_data[idx_tmp];
next = self.vars.order_heap_data[idx_tmp].var();
if self.value(next) == lbool::UNDEF && self.decision[next] {
self.rnd_decisions += 1;
}
Expand Down Expand Up @@ -2165,8 +2165,8 @@ impl VarState {
for (_, x) in self.activity.iter_mut() {
*x *= scale;
}
for (_, x) in self.order_heap_data.heap_mut().iter_mut() {
*x *= scale
for x in self.order_heap_data.heap_mut().iter_mut() {
x.scale_activity(scale)
}
self.var_inc *= scale;
}
Expand All @@ -2189,7 +2189,7 @@ impl VarState {
self.trail.push(p);
}

fn order_heap(&mut self) -> Heap<Var, f32, VarOrder> {
fn order_heap(&mut self) -> Heap<Var, VarOrder> {
self.order_heap_data.promote(VarOrder {
activity: &self.activity,
})
Expand Down Expand Up @@ -2410,47 +2410,42 @@ impl PartialEq for Watcher {
}
impl Eq for Watcher {}

impl<'a> VarOrder<'a> {
fn check_activity(&self, var: Var) -> f32 {
if var == Var::UNDEF {
0.0
} else {
self.activity[var]
}
}
}

const COMP_MASK: u64 = (u32::MAX as u64) << u32::BITS;
impl<'a> Comparator<(Var, f32)> for VarOrder<'a> {
type Comp = u64;
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct VarOrderKey(u64);

impl VarOrderKey {
#[inline]
fn max_value(&self) -> (Var, f32) {
(Var::UNDEF, 0.0)
fn new(var: Var, activity: f32) -> Self {
VarOrderKey((!(activity.to_bits() as u64) << u32::BITS) | (var.idx() as u64))
}

#[inline]
fn to_cmp_form(&self, v: &(Var, f32)) -> u64 {
debug_assert_eq!(self.check_activity(v.0), v.1);
let x = ((v.1.to_bits() as u64) << u32::BITS) | (v.0.idx() as u64);
x ^ COMP_MASK
fn var(self) -> Var {
Var::unsafe_from_idx(self.0 as u32)
}

#[inline]
fn from_cmp_form(&self, c: Self::Comp) -> (Var, f32) {
let c = c ^ COMP_MASK;
let v = Var::unsafe_from_idx((c & (u32::MAX as u64)) as u32);
let a = f32::from_bits((c >> u32::BITS) as u32);
(v, a)
fn activity(self) -> f32 {
f32::from_bits(!((self.0 >> u32::BITS) as u32))
}
}

impl<'a> MemoComparator<Var, f32> for VarOrder<'a> {
fn value(&self, k: Var) -> f32 {
self.activity[k]
fn scale_activity(&mut self, scale: f32) {
*self = VarOrderKey::new(self.var(), self.activity() * scale)
}
}
impl<'a> CachedKeyComparator<Var> for VarOrder<'a> {
type Key = VarOrderKey;

fn cache_key(&self, t: Var) -> Self::Key {
VarOrderKey::new(t, self.activity[t])
}

fn max_key(&self) -> Self::Key {
VarOrderKey::new(Var::UNDEF, 0.0)
}

fn un_cache_key(&self, k: Self::Key) -> Var {
k.var()
}
}
impl<'a> DeletePred<Watcher> for WatcherDeleted<'a> {
#[inline]
fn deleted(&self, w: &Watcher) -> bool {
Expand Down
124 changes: 38 additions & 86 deletions src/batsat/src/heap.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::intmap::{AsIndex, IntMap};
use std::fmt::Debug;
use std::{cmp, mem, ops};
use std::{mem, ops};

/// Quaternary Heap
#[derive(Debug, Clone)]
pub struct HeapData<K: AsIndex, V> {
heap: Box<[(K, V)]>,
heap: Box<[V]>,
next_slot: usize,
indices: IntMap<K, i32>,
}
Expand Down Expand Up @@ -34,12 +34,12 @@ impl<K: AsIndex, V> HeapData<K, V> {
self.indices.has(k) && self.indices[k] >= 0
}

pub fn promote<Comp: Comparator<(K, V)>>(&mut self, comp: Comp) -> Heap<K, V, Comp> {
pub fn promote<Comp: CachedKeyComparator<K, Key = V>>(&mut self, comp: Comp) -> Heap<K, Comp> {
Heap { data: self, comp }
}

/// Raw mutable access to all the elements of the heap
pub(crate) fn heap_mut(&mut self) -> &mut [(K, V)] {
pub(crate) fn heap_mut(&mut self) -> &mut [V] {
if self.next_slot == 0 {
&mut []
} else {
Expand All @@ -49,87 +49,41 @@ impl<K: AsIndex, V> HeapData<K, V> {
}

impl<K: AsIndex, V> ops::Index<usize> for HeapData<K, V> {
type Output = K;
type Output = V;
fn index(&self, index: usize) -> &Self::Output {
&self.heap[index].0
&self.heap[index]
}
}

pub trait Comparator<T: ?Sized> {
type Comp: Ord;
pub trait CachedKeyComparator<T> {
type Key: Ord + Copy;

fn max_value(&self) -> T;
fn to_cmp_form(&self, v: &T) -> Self::Comp;
fn from_cmp_form(&self, c: Self::Comp) -> T;
fn cache_key(&self, t: T) -> Self::Key;

fn cmp(&self, lhs: &T, rhs: &T) -> cmp::Ordering {
self.to_cmp_form(&lhs).cmp(&self.to_cmp_form(&rhs))
}
fn max(&self, lhs: T, rhs: T) -> T
where
T: Sized,
{
if self.ge(&rhs, &lhs) {
rhs
} else {
lhs
}
}
fn min(&self, lhs: T, rhs: T) -> T
where
T: Sized,
{
if self.le(&lhs, &rhs) {
lhs
} else {
rhs
}
}
fn le(&self, lhs: &T, rhs: &T) -> bool {
match self.cmp(lhs, rhs) {
cmp::Ordering::Less | cmp::Ordering::Equal => true,
_ => false,
}
}
fn lt(&self, lhs: &T, rhs: &T) -> bool {
match self.cmp(lhs, rhs) {
cmp::Ordering::Less => true,
_ => false,
}
}
#[inline]
fn gt(&self, lhs: &T, rhs: &T) -> bool {
self.lt(rhs, lhs)
}
#[inline]
fn ge(&self, lhs: &T, rhs: &T) -> bool {
self.le(rhs, lhs)
}
}
fn max_key(&self) -> Self::Key;

pub trait MemoComparator<K, V>: Comparator<(K, V)> {
fn value(&self, k: K) -> V;
fn un_cache_key(&self, k: Self::Key) -> T;
}

#[derive(Debug)]
pub struct Heap<'a, K: AsIndex + 'a, V: 'a, Comp> {
data: &'a mut HeapData<K, V>,
pub struct Heap<'a, K: AsIndex + 'a, Comp: CachedKeyComparator<K>> {
data: &'a mut HeapData<K, Comp::Key>,
comp: Comp,
}

impl<'a, K: AsIndex + 'a, V: 'a, Comp> ops::Deref for Heap<'a, K, V, Comp> {
type Target = HeapData<K, V>;
impl<'a, K: AsIndex + 'a, Comp: CachedKeyComparator<K>> ops::Deref for Heap<'a, K, Comp> {
type Target = HeapData<K, Comp::Key>;
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl<'a, K: AsIndex + 'a, V: 'a, Comp> ops::DerefMut for Heap<'a, K, V, Comp> {
impl<'a, K: AsIndex + 'a, Comp: CachedKeyComparator<K>> ops::DerefMut for Heap<'a, K, Comp> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.data
}
}

impl<'a, K: AsIndex + 'a, V: Copy + 'a, Comp: MemoComparator<K, V>> Heap<'a, K, V, Comp> {
impl<'a, K: AsIndex + 'a, Comp: CachedKeyComparator<K>> Heap<'a, K, Comp> {
// ensure size is always a multiple of 4
#[cold]
#[inline(never)]
Expand All @@ -138,56 +92,55 @@ impl<'a, K: AsIndex + 'a, V: Copy + 'a, Comp: MemoComparator<K, V>> Heap<'a, K,
if self.next_slot == 0 {
self.next_slot = ROOT as usize;
// Enough space for the root and 4 children
self.heap = vec![self.comp.max_value(); 8].into_boxed_slice();
self.heap = vec![self.comp.max_key(); 8].into_boxed_slice();
} else {
let new_size = self.next_slot << 2;
let mut heap = mem::replace(&mut self.heap, Box::new([])).into_vec();
heap.resize(new_size, self.comp.max_value());
heap.resize(new_size, self.comp.max_key());
self.heap = heap.into_boxed_slice();
}
}

#[inline]
fn heap_push(&mut self, k: K, v: V) -> u32 {
fn heap_push(&mut self, k: Comp::Key) -> u32 {
if self.next_slot >= self.heap.len() {
self.heap_reserve();
assert!(self.next_slot < self.heap.len());
}
let slot = self.next_slot;
self.heap[slot] = (k, v);
self.heap[slot] = k;
self.next_slot += 1;
slot as u32
}

fn percolate_up(&mut self, mut i: u32) {
let x = self.heap[i as usize];
let xc = self.comp.to_cmp_form(&x);
let mut p = parent_index(i);

while i != ROOT && xc < self.comp.to_cmp_form(&self.heap[p as usize]) {
while i != ROOT && x < self.heap[p as usize] {
self.heap[i as usize] = self.heap[p as usize];
let tmp = self.heap[p as usize];
self.indices[tmp.0] = i as i32;
self.data.indices[self.comp.un_cache_key(tmp)] = i as i32;
i = p;
p = parent_index(p);
}
self.heap[i as usize] = x;
self.indices[x.0] = i as i32;
self.data.indices[self.comp.un_cache_key(x)] = i as i32;
}

fn percolate_down(&mut self, mut i: u32) {
let x = self.comp.to_cmp_form(&self.heap[i as usize]);
let x = self.heap[i as usize];
let len = (self.next_slot + 3) & (usize::MAX - 3); // round up to nearest multiple of 4
// since the heap is padded with maximum values we can pretend that these are part of the
// heap but never swap with them
let heap = &mut self.data.heap[..len];
loop {
let min = |x: (u32, Comp::Comp), y: (u32, Comp::Comp)| if x.1 < y.1 { x } else { y };
let min = |x: (u32, Comp::Key), y: (u32, Comp::Key)| if x.1 < y.1 { x } else { y };
let left_index = left_index(i);
let Some(arr) = heap.get(left_index as usize..left_index as usize + 4) else {
break;
};
let bundle = |x| (left_index + x, self.comp.to_cmp_form(&arr[x as usize]));
let bundle = |x| (left_index + x, arr[x as usize]);
let b0 = bundle(0);
let b1 = bundle(1);
let b2 = bundle(2);
Expand All @@ -198,27 +151,25 @@ impl<'a, K: AsIndex + 'a, V: Copy + 'a, Comp: MemoComparator<K, V>> Heap<'a, K,
if min > x {
break;
}
let min = self.comp.from_cmp_form(min);
heap[i as usize] = min;
self.data.indices[min.0] = i as i32;
self.data.indices[self.comp.un_cache_key(min)] = i as i32;
i = child;
}
let x = self.comp.from_cmp_form(x);
heap[i as usize] = x;
self.data.indices[x.0] = i as i32;
self.data.indices[self.comp.un_cache_key(x)] = i as i32;
}

pub fn decrease(&mut self, k: K) {
debug_assert!(self.in_heap(k));
let k_index = self.indices[k];
self.heap[k_index as usize].1 = self.comp.value(k);
self.heap[k_index as usize] = self.comp.cache_key(k);
self.percolate_up(k_index as u32);
}

pub fn insert(&mut self, k: K) {
self.indices.reserve(k, -1);
debug_assert!(!self.in_heap(k));
let k_index = self.heap_push(k, self.comp.value(k));
let k_index = self.heap_push(self.comp.cache_key(k));
self.indices[k] = k_index as i32;
self.percolate_up(k_index);
}
Expand All @@ -228,17 +179,18 @@ impl<'a, K: AsIndex + 'a, V: Copy + 'a, Comp: MemoComparator<K, V>> Heap<'a, K,
let x = self.heap[ROOT as usize];
let last = self.next_slot - 1;
self.next_slot = last;
self.indices[x.0] = -1;
let x_var = self.comp.un_cache_key(x);
self.indices[x_var] = -1;
if self.is_empty() {
self.heap[last] = self.comp.max_value();
return x.0;
self.heap[last] = self.comp.max_key();
return x_var;
}
let lastval = self.heap[last];
self.heap[last] = self.comp.max_value();
self.heap[last] = self.comp.max_key();
self.heap[ROOT as usize] = lastval;
self.indices[lastval.0] = ROOT as i32;
self.data.indices[self.comp.un_cache_key(lastval)] = ROOT as i32;
self.percolate_down(ROOT);
x.0
self.comp.un_cache_key(x)
}
}

Expand Down

0 comments on commit 288548b

Please sign in to comment.