diff --git a/Cargo.toml b/Cargo.toml index 0735192..d2718d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bitvec = "1.0.1" num-traits = "0.2.17" once_cell = "1.19.0" pest = "2.7.5" diff --git a/src/evaluate.rs b/src/evaluate.rs index db17ed7..2e8eb4c 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -1,8 +1,13 @@ use crate::{BoolExpression, FloatExt, RealExpression, StringExpression}; +use bitvec::vec::BitVec; #[cfg(feature = "rayon")] -use rayon::prelude::{ - IndexedParallelIterator, IntoParallelRefIterator, ParallelExtend, ParallelIterator, +use rayon::{ + prelude::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, + ParallelExtend, ParallelIterator, + }, + slice::ParallelSlice, }; /// To speed up string comparisons, we use string interning. @@ -16,7 +21,7 @@ impl BoolExpression { string_bindings: &[S], mut get_string_literal_id: impl FnMut(&str) -> StringId, registers: &mut Registers, - ) -> Vec { + ) -> BitVec { validate_bindings(real_bindings, registers.register_length); validate_bindings(string_bindings, registers.register_length); self.evaluate_recursive( @@ -33,10 +38,29 @@ impl BoolExpression { string_bindings: &[S], get_string_literal_id: &mut impl FnMut(&str) -> StringId, registers: &mut Registers, - ) -> Vec { + ) -> BitVec { + let reg_len = registers.register_length; match self { Self::And(lhs, rhs) => evaluate_binary_logic( - |lhs, rhs| lhs && rhs, + |lhs, rhs, out| { + #[cfg(feature = "rayon")] + { + out.resize(reg_len, Default::default()); + lhs.as_raw_slice() + .par_iter() + .zip(rhs.as_raw_slice().par_iter()) + .zip(out.as_raw_mut_slice().par_iter_mut()) + .for_each(|((lhs, rhs), out)| { + *out = lhs & rhs; + }) + } + #[cfg(not(feature = "rayon"))] + { + out.resize(reg_len, true); + *out &= lhs; + *out &= rhs; + } + }, lhs.as_ref(), rhs.as_ref(), real_bindings, @@ -80,7 +104,18 @@ impl BoolExpression { registers, ), Self::Not(only) => evaluate_unary_logic( - |only| !only, + |only| { + #[cfg(feature = "rayon")] + { + only.as_raw_mut_slice().par_iter_mut().for_each(|i| { + *i = !*i; + }); + } + #[cfg(not(feature = "rayon"))] + { + *only = !std::mem::take(only); + } + }, only.as_ref(), real_bindings, string_bindings, @@ -95,7 +130,25 @@ impl BoolExpression { registers, ), Self::Or(lhs, rhs) => evaluate_binary_logic( - |lhs, rhs| lhs || rhs, + |lhs, rhs, out| { + #[cfg(feature = "rayon")] + { + out.resize(reg_len, Default::default()); + lhs.as_raw_slice() + .par_iter() + .zip(rhs.as_raw_slice().par_iter()) + .zip(out.as_raw_mut_slice().par_iter_mut()) + .for_each(|((lhs, rhs), out)| { + *out = lhs | rhs; + }) + } + #[cfg(not(feature = "rayon"))] + { + out.resize(reg_len, false); + *out |= lhs; + *out |= rhs; + } + }, lhs.as_ref(), rhs.as_ref(), real_bindings, @@ -297,7 +350,7 @@ fn evaluate_real_comparison>( rhs: &RealExpression, bindings: &[R], registers: &mut Registers, -) -> Vec { +) -> BitVec { // Before doing recursive evaluation, we check first if we already have // input values in our bindings. This avoids unnecessary copies. let mut lhs_reg = None; @@ -319,12 +372,8 @@ fn evaluate_real_comparison>( #[cfg(feature = "rayon")] { - output.par_extend( - lhs_values - .par_iter() - .zip(rhs_values.par_iter()) - .map(|(lhs, rhs)| op(*lhs, *rhs)), - ); + output.resize(registers.register_length, Default::default()); + parallel_comparison(op, lhs_values, rhs_values, &mut output); } #[cfg(not(feature = "rayon"))] { @@ -352,7 +401,7 @@ fn evaluate_string_comparison>( bindings: &[S], mut get_string_literal_id: impl FnMut(&str) -> StringId, registers: &mut Registers, -) -> Vec { +) -> BitVec { let mut lhs_reg = None; let lhs_values = match lhs { StringExpression::Binding(binding) => bindings[*binding].as_ref(), @@ -380,12 +429,8 @@ fn evaluate_string_comparison>( #[cfg(feature = "rayon")] { - output.par_extend( - lhs_values - .par_iter() - .zip(rhs_values.par_iter()) - .map(|(lhs, rhs)| op(*lhs, *rhs)), - ); + output.resize(registers.register_length, Default::default()); + parallel_comparison(op, lhs_values, rhs_values, &mut output); } #[cfg(not(feature = "rayon"))] { @@ -406,15 +451,48 @@ fn evaluate_string_comparison>( output } +#[cfg(feature = "rayon")] +fn parallel_comparison( + op: fn(T, T) -> bool, + lhs_values: &[T], + rhs_values: &[T], + output: &mut BitVec, +) { + // Some nasty chunked iteration to make sure chunks of input line up + // with the bit storage integers. + let bits_per_block = usize::BITS as usize; + let bit_blocks = output.as_raw_mut_slice(); + let lhs_chunks = lhs_values.par_chunks_exact(bits_per_block); + let rhs_chunks = rhs_values.par_chunks_exact(bits_per_block); + if let Some(rem_block) = bit_blocks.last_mut() { + lhs_chunks + .remainder() + .iter() + .zip(rhs_chunks.remainder()) + .enumerate() + .for_each(|(i, (&lhs, &rhs))| { + *rem_block |= usize::from(op(lhs, rhs)) << i; + }); + } + lhs_chunks + .zip(rhs_chunks) + .zip(bit_blocks.par_iter_mut()) + .for_each(|((lhs_chunk, rhs_chunk), out_block)| { + for (i, (&lhs, &rhs)) in lhs_chunk.iter().zip(rhs_chunk).enumerate() { + *out_block |= usize::from(op(lhs, rhs)) << i; + } + }); +} + fn evaluate_binary_logic, S: AsRef<[StringId]>>( - op: fn(bool, bool) -> bool, + op: impl Fn(&BitVec, &BitVec, &mut BitVec), lhs: &BoolExpression, rhs: &BoolExpression, real_bindings: &[R], string_bindings: &[S], get_string_literal_id: &mut impl FnMut(&str) -> StringId, registers: &mut Registers, -) -> Vec { +) -> BitVec { let lhs_values = lhs.evaluate_recursive( real_bindings, string_bindings, @@ -431,24 +509,7 @@ fn evaluate_binary_logic, S: AsRef<[StringId]>> // Allocate this output register as lazily as possible. let mut output = registers.allocate_bool(); - #[cfg(feature = "rayon")] - { - output.par_extend( - lhs_values - .par_iter() - .zip(rhs_values.par_iter()) - .map(|(lhs, rhs)| op(*lhs, *rhs)), - ); - } - #[cfg(not(feature = "rayon"))] - { - output.extend( - lhs_values - .iter() - .zip(rhs_values.iter()) - .map(|(lhs, rhs)| op(*lhs, *rhs)), - ); - } + op(&lhs_values, &rhs_values, &mut output); registers.recycle_bool(lhs_values); registers.recycle_bool(rhs_values); @@ -456,34 +517,23 @@ fn evaluate_binary_logic, S: AsRef<[StringId]>> } fn evaluate_unary_logic, S: AsRef<[StringId]>>( - op: fn(bool) -> bool, + op: fn(&mut BitVec), only: &BoolExpression, real_bindings: &[R], string_bindings: &[S], get_string_literal_id: &mut impl FnMut(&str) -> StringId, registers: &mut Registers, -) -> Vec { - let only_values = only.evaluate_recursive( +) -> BitVec { + let mut only_values = only.evaluate_recursive( real_bindings, string_bindings, get_string_literal_id, registers, ); - // Allocate this output register as lazily as possible. - let mut output = registers.allocate_bool(); + op(&mut only_values); - #[cfg(feature = "rayon")] - { - output.par_extend(only_values.par_iter().map(|only| op(*only))); - } - #[cfg(not(feature = "rayon"))] - { - output.extend(only_values.iter().map(|only| op(*only))); - } - - registers.recycle_bool(only_values); - output + only_values } /// Scratch space for calculations. Can be reused across evaluations with the @@ -494,7 +544,7 @@ fn evaluate_unary_logic, S: AsRef<[StringId]>>( pub struct Registers { num_allocations: usize, real_registers: Vec>, - bool_registers: Vec>, + bool_registers: Vec, string_registers: Vec>, register_length: usize, } @@ -515,7 +565,7 @@ impl Registers { self.real_registers.push(used); } - fn recycle_bool(&mut self, mut used: Vec) { + fn recycle_bool(&mut self, mut used: BitVec) { used.clear(); self.bool_registers.push(used); } @@ -532,10 +582,10 @@ impl Registers { }) } - fn allocate_bool(&mut self) -> Vec { + fn allocate_bool(&mut self) -> BitVec { self.bool_registers.pop().unwrap_or_else(|| { self.num_allocations += 1; - Vec::with_capacity(self.register_length) + BitVec::with_capacity(self.register_length) }) } diff --git a/src/lib.rs b/src/lib.rs index 87808f5..9add32b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,7 +123,7 @@ mod tests { let bindings = &[bar, baz, foo]; let mut registers = Registers::new(3); let output = bool.evaluate::<_, [_; 0]>(bindings, &[], |_| unreachable!(), &mut registers); - assert_eq!(&output, &[false, true, false]); + assert_eq!([output[0], output[1], output[2]], [false, true, false]); assert_eq!(registers.num_allocations(), 3); } @@ -157,7 +157,7 @@ mod tests { string_literal_id, &mut registers, ); - assert_eq!(&output, &[false, false, true]); + assert_eq!([output[0], output[1], output[2]], [false, false, true]); assert_eq!(registers.num_allocations(), 5); }