Skip to content

Commit

Permalink
Merge pull request #7 from ForesightMiningSoftwareCorporation/bit_arrays
Browse files Browse the repository at this point in the history
Replace `Vec<bool>` with `BitVec`
  • Loading branch information
bonsairobo authored Jan 3, 2024
2 parents 1f3c32b + 3bd9144 commit 3384571
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 63 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ readme = "README.md"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bitvec = "1.0.1"
num-traits = "0.2.17"
once_cell = "1.19.0"
pest = "2.7.5"
Expand Down
172 changes: 111 additions & 61 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use crate::{BoolExpression, FloatExt, RealExpression, StringExpression};
use bitvec::vec::BitVec;

#[cfg(feature = "rayon")]
use rayon::prelude::{
IndexedParallelIterator, IntoParallelRefIterator, ParallelExtend, ParallelIterator,
use rayon::{
prelude::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
ParallelExtend, ParallelIterator,
},
slice::ParallelSlice,
};

/// To speed up string comparisons, we use string interning.
Expand All @@ -16,7 +21,7 @@ impl<Real: FloatExt> BoolExpression<Real> {
string_bindings: &[S],
mut get_string_literal_id: impl FnMut(&str) -> StringId,
registers: &mut Registers<Real>,
) -> Vec<bool> {
) -> BitVec {
validate_bindings(real_bindings, registers.register_length);
validate_bindings(string_bindings, registers.register_length);
self.evaluate_recursive(
Expand All @@ -33,10 +38,29 @@ impl<Real: FloatExt> BoolExpression<Real> {
string_bindings: &[S],
get_string_literal_id: &mut impl FnMut(&str) -> StringId,
registers: &mut Registers<Real>,
) -> Vec<bool> {
) -> BitVec {
let reg_len = registers.register_length;
match self {
Self::And(lhs, rhs) => evaluate_binary_logic(
|lhs, rhs| lhs && rhs,
|lhs, rhs, out| {
#[cfg(feature = "rayon")]
{
out.resize(reg_len, Default::default());
lhs.as_raw_slice()
.par_iter()
.zip(rhs.as_raw_slice().par_iter())
.zip(out.as_raw_mut_slice().par_iter_mut())
.for_each(|((lhs, rhs), out)| {
*out = lhs & rhs;
})
}
#[cfg(not(feature = "rayon"))]
{
out.resize(reg_len, true);
*out &= lhs;
*out &= rhs;
}
},
lhs.as_ref(),
rhs.as_ref(),
real_bindings,
Expand Down Expand Up @@ -80,7 +104,18 @@ impl<Real: FloatExt> BoolExpression<Real> {
registers,
),
Self::Not(only) => evaluate_unary_logic(
|only| !only,
|only| {
#[cfg(feature = "rayon")]
{
only.as_raw_mut_slice().par_iter_mut().for_each(|i| {
*i = !*i;
});
}
#[cfg(not(feature = "rayon"))]
{
*only = !std::mem::take(only);
}
},
only.as_ref(),
real_bindings,
string_bindings,
Expand All @@ -95,7 +130,25 @@ impl<Real: FloatExt> BoolExpression<Real> {
registers,
),
Self::Or(lhs, rhs) => evaluate_binary_logic(
|lhs, rhs| lhs || rhs,
|lhs, rhs, out| {
#[cfg(feature = "rayon")]
{
out.resize(reg_len, Default::default());
lhs.as_raw_slice()
.par_iter()
.zip(rhs.as_raw_slice().par_iter())
.zip(out.as_raw_mut_slice().par_iter_mut())
.for_each(|((lhs, rhs), out)| {
*out = lhs | rhs;
})
}
#[cfg(not(feature = "rayon"))]
{
out.resize(reg_len, false);
*out |= lhs;
*out |= rhs;
}
},
lhs.as_ref(),
rhs.as_ref(),
real_bindings,
Expand Down Expand Up @@ -297,7 +350,7 @@ fn evaluate_real_comparison<Real: FloatExt, R: AsRef<[Real]>>(
rhs: &RealExpression<Real>,
bindings: &[R],
registers: &mut Registers<Real>,
) -> Vec<bool> {
) -> BitVec {
// Before doing recursive evaluation, we check first if we already have
// input values in our bindings. This avoids unnecessary copies.
let mut lhs_reg = None;
Expand All @@ -319,12 +372,8 @@ fn evaluate_real_comparison<Real: FloatExt, R: AsRef<[Real]>>(

#[cfg(feature = "rayon")]
{
output.par_extend(
lhs_values
.par_iter()
.zip(rhs_values.par_iter())
.map(|(lhs, rhs)| op(*lhs, *rhs)),
);
output.resize(registers.register_length, Default::default());
parallel_comparison(op, lhs_values, rhs_values, &mut output);
}
#[cfg(not(feature = "rayon"))]
{
Expand Down Expand Up @@ -352,7 +401,7 @@ fn evaluate_string_comparison<Real, S: AsRef<[StringId]>>(
bindings: &[S],
mut get_string_literal_id: impl FnMut(&str) -> StringId,
registers: &mut Registers<Real>,
) -> Vec<bool> {
) -> BitVec {
let mut lhs_reg = None;
let lhs_values = match lhs {
StringExpression::Binding(binding) => bindings[*binding].as_ref(),
Expand Down Expand Up @@ -380,12 +429,8 @@ fn evaluate_string_comparison<Real, S: AsRef<[StringId]>>(

#[cfg(feature = "rayon")]
{
output.par_extend(
lhs_values
.par_iter()
.zip(rhs_values.par_iter())
.map(|(lhs, rhs)| op(*lhs, *rhs)),
);
output.resize(registers.register_length, Default::default());
parallel_comparison(op, lhs_values, rhs_values, &mut output);
}
#[cfg(not(feature = "rayon"))]
{
Expand All @@ -406,15 +451,48 @@ fn evaluate_string_comparison<Real, S: AsRef<[StringId]>>(
output
}

/// Evaluates `op` element-wise over `lhs_values` and `rhs_values` in
/// parallel, packing each boolean result into the raw storage words of
/// `output`.
///
/// Result bits are OR-ed into the existing words, so any bit already set
/// in `output` stays set. The callers in this file resize `output` with
/// the default (`false`) fill right before calling this function.
///
/// Element `i` of each `usize::BITS`-sized input chunk lands in bit `i`
/// (counting from the least significant bit) of the matching storage
/// word. NOTE(review): this relies on `BitVec`'s default `usize` / `Lsb0`
/// layout — confirm if the type parameters ever change.
#[cfg(feature = "rayon")]
fn parallel_comparison<T: Copy + Send + Sync>(
    op: fn(T, T) -> bool,
    lhs_values: &[T],
    rhs_values: &[T],
    output: &mut BitVec,
) {
    // One storage word holds exactly this many comparison results.
    const BLOCK_BITS: usize = usize::BITS as usize;

    let words = output.as_raw_mut_slice();

    // Split both inputs into word-sized chunks so that every full chunk
    // maps onto exactly one storage word.
    let lhs_full = lhs_values.par_chunks_exact(BLOCK_BITS);
    let rhs_full = rhs_values.par_chunks_exact(BLOCK_BITS);

    // The inputs left over after the last full chunk are written into the
    // final storage word. This runs sequentially, since it covers fewer
    // than one word's worth of elements.
    if let Some(tail_word) = words.last_mut() {
        let tail_pairs = lhs_full.remainder().iter().zip(rhs_full.remainder());
        for (bit, (&a, &b)) in tail_pairs.enumerate() {
            *tail_word |= usize::from(op(a, b)) << bit;
        }
    }

    // Full chunks are processed in parallel, one storage word per chunk.
    // The zip stops at the shortest iterator, so a trailing partial word
    // is never visited here.
    lhs_full
        .zip(rhs_full)
        .zip(words.par_iter_mut())
        .for_each(|((lhs_chunk, rhs_chunk), word)| {
            let packed = lhs_chunk
                .iter()
                .zip(rhs_chunk)
                .enumerate()
                .fold(0usize, |acc, (bit, (&a, &b))| {
                    acc | (usize::from(op(a, b)) << bit)
                });
            *word |= packed;
        });
}

fn evaluate_binary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
op: fn(bool, bool) -> bool,
op: impl Fn(&BitVec, &BitVec, &mut BitVec),
lhs: &BoolExpression<Real>,
rhs: &BoolExpression<Real>,
real_bindings: &[R],
string_bindings: &[S],
get_string_literal_id: &mut impl FnMut(&str) -> StringId,
registers: &mut Registers<Real>,
) -> Vec<bool> {
) -> BitVec {
let lhs_values = lhs.evaluate_recursive(
real_bindings,
string_bindings,
Expand All @@ -431,59 +509,31 @@ fn evaluate_binary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>
// Allocate this output register as lazily as possible.
let mut output = registers.allocate_bool();

#[cfg(feature = "rayon")]
{
output.par_extend(
lhs_values
.par_iter()
.zip(rhs_values.par_iter())
.map(|(lhs, rhs)| op(*lhs, *rhs)),
);
}
#[cfg(not(feature = "rayon"))]
{
output.extend(
lhs_values
.iter()
.zip(rhs_values.iter())
.map(|(lhs, rhs)| op(*lhs, *rhs)),
);
}
op(&lhs_values, &rhs_values, &mut output);

registers.recycle_bool(lhs_values);
registers.recycle_bool(rhs_values);
output
}

fn evaluate_unary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
op: fn(bool) -> bool,
op: fn(&mut BitVec),
only: &BoolExpression<Real>,
real_bindings: &[R],
string_bindings: &[S],
get_string_literal_id: &mut impl FnMut(&str) -> StringId,
registers: &mut Registers<Real>,
) -> Vec<bool> {
let only_values = only.evaluate_recursive(
) -> BitVec {
let mut only_values = only.evaluate_recursive(
real_bindings,
string_bindings,
get_string_literal_id,
registers,
);

// Allocate this output register as lazily as possible.
let mut output = registers.allocate_bool();
op(&mut only_values);

#[cfg(feature = "rayon")]
{
output.par_extend(only_values.par_iter().map(|only| op(*only)));
}
#[cfg(not(feature = "rayon"))]
{
output.extend(only_values.iter().map(|only| op(*only)));
}

registers.recycle_bool(only_values);
output
only_values
}

/// Scratch space for calculations. Can be reused across evaluations with the
Expand All @@ -494,7 +544,7 @@ fn evaluate_unary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
pub struct Registers<Real> {
num_allocations: usize,
real_registers: Vec<Vec<Real>>,
bool_registers: Vec<Vec<bool>>,
bool_registers: Vec<BitVec>,
string_registers: Vec<Vec<StringId>>,
register_length: usize,
}
Expand All @@ -515,7 +565,7 @@ impl<Real> Registers<Real> {
self.real_registers.push(used);
}

fn recycle_bool(&mut self, mut used: Vec<bool>) {
fn recycle_bool(&mut self, mut used: BitVec) {
used.clear();
self.bool_registers.push(used);
}
Expand All @@ -532,10 +582,10 @@ impl<Real> Registers<Real> {
})
}

fn allocate_bool(&mut self) -> Vec<bool> {
fn allocate_bool(&mut self) -> BitVec {
self.bool_registers.pop().unwrap_or_else(|| {
self.num_allocations += 1;
Vec::with_capacity(self.register_length)
BitVec::with_capacity(self.register_length)
})
}

Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ mod tests {
let bindings = &[bar, baz, foo];
let mut registers = Registers::new(3);
let output = bool.evaluate::<_, [_; 0]>(bindings, &[], |_| unreachable!(), &mut registers);
assert_eq!(&output, &[false, true, false]);
assert_eq!([output[0], output[1], output[2]], [false, true, false]);
assert_eq!(registers.num_allocations(), 3);
}

Expand Down Expand Up @@ -157,7 +157,7 @@ mod tests {
string_literal_id,
&mut registers,
);
assert_eq!(&output, &[false, false, true]);
assert_eq!([output[0], output[1], output[2]], [false, false, true]);
assert_eq!(registers.num_allocations(), 5);
}

Expand Down

0 comments on commit 3384571

Please sign in to comment.