Skip to content

Commit

Permalink
Merge pull request #6 from ForesightMiningSoftwareCorporation/num_traits
Browse files Browse the repository at this point in the history
Use `num_traits::Float` to abstract over f32/f64
  • Loading branch information
bonsairobo authored Jan 3, 2024
2 parents 18a56ab + 6946111 commit 1f3c32b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 83 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "vector-expr"
authors = ["Duncan Fairbanks <duncan.fairbanks@foresightmining.com>"]
version = "0.2.0"
version = "0.3.0"
edition = "2021"
description = "Vectorized expression parser and evaluator"
license = "MIT OR Apache-2.0"
Expand All @@ -11,6 +11,7 @@ readme = "README.md"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
num-traits = "0.2.17"
once_cell = "1.19.0"
pest = "2.7.5"
pest_derive = "2.7.5"
Expand Down
88 changes: 46 additions & 42 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{BoolExpression, RealExpression, StringExpression};
use crate::{BoolExpression, FloatExt, RealExpression, StringExpression};

#[cfg(feature = "rayon")]
use rayon::prelude::{
Expand All @@ -8,14 +8,14 @@ use rayon::prelude::{
/// To speed up string comparisons, we use string interning.
pub type StringId = u32;

impl BoolExpression {
impl<Real: FloatExt> BoolExpression<Real> {
/// Calculates the `bool`-valued results of the expression component-wise.
pub fn evaluate<R: AsRef<[f64]>, S: AsRef<[StringId]>>(
pub fn evaluate<R: AsRef<[Real]>, S: AsRef<[StringId]>>(
&self,
real_bindings: &[R],
string_bindings: &[S],
mut get_string_literal_id: impl FnMut(&str) -> StringId,
registers: &mut Registers,
registers: &mut Registers<Real>,
) -> Vec<bool> {
validate_bindings(real_bindings, registers.register_length);
validate_bindings(string_bindings, registers.register_length);
Expand All @@ -27,12 +27,12 @@ impl BoolExpression {
)
}

fn evaluate_recursive<R: AsRef<[f64]>, S: AsRef<[StringId]>>(
fn evaluate_recursive<R: AsRef<[Real]>, S: AsRef<[StringId]>>(
&self,
real_bindings: &[R],
string_bindings: &[S],
get_string_literal_id: &mut impl FnMut(&str) -> StringId,
registers: &mut Registers,
registers: &mut Registers<Real>,
) -> Vec<bool> {
match self {
Self::And(lhs, rhs) => evaluate_binary_logic(
Expand Down Expand Up @@ -123,22 +123,26 @@ impl BoolExpression {
}
}

impl RealExpression {
pub fn evaluate_without_vars(&self, registers: &mut Registers) -> Vec<f64> {
impl<Real: FloatExt> RealExpression<Real> {
pub fn evaluate_without_vars(&self, registers: &mut Registers<Real>) -> Vec<Real> {
self.evaluate::<[_; 0]>(&[], registers)
}

/// Calculates the real-valued results of the expression component-wise.
pub fn evaluate<R: AsRef<[f64]>>(&self, bindings: &[R], registers: &mut Registers) -> Vec<f64> {
pub fn evaluate<R: AsRef<[Real]>>(
&self,
bindings: &[R],
registers: &mut Registers<Real>,
) -> Vec<Real> {
validate_bindings(bindings, registers.register_length);
self.evaluate_recursive(bindings, registers)
}

fn evaluate_recursive<R: AsRef<[f64]>>(
fn evaluate_recursive<R: AsRef<[Real]>>(
&self,
bindings: &[R],
registers: &mut Registers,
) -> Vec<f64> {
registers: &mut Registers<Real>,
) -> Vec<Real> {
match self {
Self::Add(lhs, rhs) => evaluate_binary_real_op(
|lhs, rhs| lhs + rhs,
Expand Down Expand Up @@ -200,13 +204,13 @@ fn validate_bindings<T, B: AsRef<[T]>>(input_bindings: &[B], expected_length: us
}
}

fn evaluate_binary_real_op<R: AsRef<[f64]>>(
op: fn(f64, f64) -> f64,
lhs: &RealExpression,
rhs: &RealExpression,
fn evaluate_binary_real_op<Real: FloatExt, R: AsRef<[Real]>>(
op: fn(Real, Real) -> Real,
lhs: &RealExpression<Real>,
rhs: &RealExpression<Real>,
bindings: &[R],
registers: &mut Registers,
) -> Vec<f64> {
registers: &mut Registers<Real>,
) -> Vec<Real> {
// Before doing recursive evaluation, we check first if we already have
// input values in our bindings. This avoids unnecessary copies.
let mut lhs_reg = None;
Expand Down Expand Up @@ -254,12 +258,12 @@ fn evaluate_binary_real_op<R: AsRef<[f64]>>(
output
}

fn evaluate_unary_real_op<R: AsRef<[f64]>>(
op: fn(f64) -> f64,
only: &RealExpression,
fn evaluate_unary_real_op<Real: FloatExt, R: AsRef<[Real]>>(
op: fn(Real) -> Real,
only: &RealExpression<Real>,
bindings: &[R],
registers: &mut Registers,
) -> Vec<f64> {
registers: &mut Registers<Real>,
) -> Vec<Real> {
// Before doing recursive evaluation, we check first if we already have
// input values in our bindings. This avoids unnecessary copies.
let mut only_reg = None;
Expand Down Expand Up @@ -287,12 +291,12 @@ fn evaluate_unary_real_op<R: AsRef<[f64]>>(
output
}

fn evaluate_real_comparison<R: AsRef<[f64]>>(
op: fn(f64, f64) -> bool,
lhs: &RealExpression,
rhs: &RealExpression,
fn evaluate_real_comparison<Real: FloatExt, R: AsRef<[Real]>>(
op: fn(Real, Real) -> bool,
lhs: &RealExpression<Real>,
rhs: &RealExpression<Real>,
bindings: &[R],
registers: &mut Registers,
registers: &mut Registers<Real>,
) -> Vec<bool> {
// Before doing recursive evaluation, we check first if we already have
// input values in our bindings. This avoids unnecessary copies.
Expand Down Expand Up @@ -341,13 +345,13 @@ fn evaluate_real_comparison<R: AsRef<[f64]>>(
output
}

fn evaluate_string_comparison<S: AsRef<[StringId]>>(
fn evaluate_string_comparison<Real, S: AsRef<[StringId]>>(
op: fn(StringId, StringId) -> bool,
lhs: &StringExpression,
rhs: &StringExpression,
bindings: &[S],
mut get_string_literal_id: impl FnMut(&str) -> StringId,
registers: &mut Registers,
registers: &mut Registers<Real>,
) -> Vec<bool> {
let mut lhs_reg = None;
let lhs_values = match lhs {
Expand Down Expand Up @@ -402,14 +406,14 @@ fn evaluate_string_comparison<S: AsRef<[StringId]>>(
output
}

fn evaluate_binary_logic<R: AsRef<[f64]>, S: AsRef<[StringId]>>(
fn evaluate_binary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
op: fn(bool, bool) -> bool,
lhs: &BoolExpression,
rhs: &BoolExpression,
lhs: &BoolExpression<Real>,
rhs: &BoolExpression<Real>,
real_bindings: &[R],
string_bindings: &[S],
get_string_literal_id: &mut impl FnMut(&str) -> StringId,
registers: &mut Registers,
registers: &mut Registers<Real>,
) -> Vec<bool> {
let lhs_values = lhs.evaluate_recursive(
real_bindings,
Expand Down Expand Up @@ -451,13 +455,13 @@ fn evaluate_binary_logic<R: AsRef<[f64]>, S: AsRef<[StringId]>>(
output
}

fn evaluate_unary_logic<R: AsRef<[f64]>, S: AsRef<[StringId]>>(
fn evaluate_unary_logic<Real: FloatExt, R: AsRef<[Real]>, S: AsRef<[StringId]>>(
op: fn(bool) -> bool,
only: &BoolExpression,
only: &BoolExpression<Real>,
real_bindings: &[R],
string_bindings: &[S],
get_string_literal_id: &mut impl FnMut(&str) -> StringId,
registers: &mut Registers,
registers: &mut Registers<Real>,
) -> Vec<bool> {
let only_values = only.evaluate_recursive(
real_bindings,
Expand Down Expand Up @@ -487,15 +491,15 @@ fn evaluate_unary_logic<R: AsRef<[f64]>, S: AsRef<[StringId]>>(
///
/// Attempts to minimize allocations by recycling registers after intermediate
/// calculations have finished.
pub struct Registers {
pub struct Registers<Real> {
num_allocations: usize,
real_registers: Vec<Vec<f64>>,
real_registers: Vec<Vec<Real>>,
bool_registers: Vec<Vec<bool>>,
string_registers: Vec<Vec<StringId>>,
register_length: usize,
}

impl Registers {
impl<Real> Registers<Real> {
pub fn new(register_length: usize) -> Self {
Self {
num_allocations: 0,
Expand All @@ -506,7 +510,7 @@ impl Registers {
}
}

fn recycle_real(&mut self, mut used: Vec<f64>) {
fn recycle_real(&mut self, mut used: Vec<Real>) {
used.clear();
self.real_registers.push(used);
}
Expand All @@ -521,7 +525,7 @@ impl Registers {
self.string_registers.push(used);
}

fn allocate_real(&mut self) -> Vec<f64> {
fn allocate_real(&mut self) -> Vec<Real> {
self.real_registers.pop().unwrap_or_else(|| {
self.num_allocations += 1;
Vec::with_capacity(self.register_length)
Expand Down
42 changes: 21 additions & 21 deletions src/expression.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
/// Top-level parseable calculation.
#[derive(Clone, Debug)]
pub enum Expression {
Boolean(BoolExpression),
Real(RealExpression),
pub enum Expression<Real> {
Boolean(BoolExpression<Real>),
Real(RealExpression<Real>),
String(StringExpression),
}

/// A `bool`-valued expression.
#[derive(Clone, Debug)]
pub enum BoolExpression {
pub enum BoolExpression<Real> {
// Binary logic.
And(Box<BoolExpression>, Box<BoolExpression>),
Or(Box<BoolExpression>, Box<BoolExpression>),
And(Box<BoolExpression<Real>>, Box<BoolExpression<Real>>),
Or(Box<BoolExpression<Real>>, Box<BoolExpression<Real>>),

// Unary logic.
Not(Box<BoolExpression>),
Not(Box<BoolExpression<Real>>),

// Real comparisons.
Equal(Box<RealExpression>, Box<RealExpression>),
Greater(Box<RealExpression>, Box<RealExpression>),
GreaterEqual(Box<RealExpression>, Box<RealExpression>),
Less(Box<RealExpression>, Box<RealExpression>),
LessEqual(Box<RealExpression>, Box<RealExpression>),
NotEqual(Box<RealExpression>, Box<RealExpression>),
Equal(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
Greater(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
GreaterEqual(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
Less(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
LessEqual(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
NotEqual(Box<RealExpression<Real>>, Box<RealExpression<Real>>),

// String comparisons.
StrEqual(StringExpression, StringExpression),
Expand All @@ -31,19 +31,19 @@ pub enum BoolExpression {

/// A real-valued expression, generic over the floating-point type `Real`.
#[derive(Clone, Debug)]
pub enum RealExpression {
pub enum RealExpression<Real> {
// Binary real ops.
Add(Box<RealExpression>, Box<RealExpression>),
Div(Box<RealExpression>, Box<RealExpression>),
Mul(Box<RealExpression>, Box<RealExpression>),
Pow(Box<RealExpression>, Box<RealExpression>),
Sub(Box<RealExpression>, Box<RealExpression>),
Add(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
Div(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
Mul(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
Pow(Box<RealExpression<Real>>, Box<RealExpression<Real>>),
Sub(Box<RealExpression<Real>>, Box<RealExpression<Real>>),

// Unary real ops.
Neg(Box<RealExpression>),
Neg(Box<RealExpression<Real>>),

// Constant.
Literal(f64),
Literal(Real),

// Input variable.
Binding(BindingId),
Expand Down
16 changes: 10 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ pub fn empty_binding_map(_var_name: &str) -> BindingId {
panic!("Empty binding map")
}

/// Marker trait for the floating-point scalar types this crate can evaluate with.
///
/// It has no methods of its own; it only bundles the supertrait bounds
/// `num_traits::Float`, `std::str::FromStr`, `Send` and `Sync` so that generic
/// code can write a single `Real: FloatExt` bound.
///
/// NOTE(review): `FromStr` is presumably required for parsing numeric literals
/// and `Send + Sync` for the optional rayon-based evaluation — confirm against
/// the parser and evaluator.
pub trait FloatExt: num_traits::Float + std::str::FromStr + Send + Sync {}
// Implemented for both built-in float widths.
impl FloatExt for f32 {}
impl FloatExt for f64 {}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -84,17 +88,17 @@ mod tests {
fn real_op_precedence() {
let mut registers = Registers::new(1);

let parsed = Expression::parse("1 * 2 + 3 * 4", empty_binding_map).unwrap();
let parsed = Expression::<f32>::parse("1 * 2 + 3 * 4", empty_binding_map).unwrap();
let real = parsed.unwrap_real();
let output = real.evaluate_without_vars(&mut registers);
assert_eq!(&output, &[14.0]);

let parsed = Expression::parse("8 / 4 * 3", empty_binding_map).unwrap();
let parsed = Expression::<f32>::parse("8 / 4 * 3", empty_binding_map).unwrap();
let real = parsed.unwrap_real();
let output = real.evaluate_without_vars(&mut registers);
assert_eq!(&output, &[6.0]);

let parsed = Expression::parse("4 ^ 3 ^ 2", empty_binding_map).unwrap();
let parsed = Expression::<f32>::parse("4 ^ 3 ^ 2", empty_binding_map).unwrap();
let real = parsed.unwrap_real();
let output = real.evaluate_without_vars(&mut registers);
assert_eq!(&output, &[262144.0]);
Expand Down Expand Up @@ -198,9 +202,9 @@ mod tests {
let real = parsed.unwrap_real();

const LEN: i32 = 10_000_000;
let x: Vec<_> = (0..LEN).map(|i| i as f64).collect();
let y: Vec<_> = (0..LEN).map(|i| (LEN - i) as f64).collect();
let z: Vec<_> = (0..LEN).map(|i| ((LEN / 2) - i) as f64).collect();
let x: Vec<_> = (0..LEN).map(|i| i as f32).collect();
let y: Vec<_> = (0..LEN).map(|i| (LEN - i) as f32).collect();
let z: Vec<_> = (0..LEN).map(|i| ((LEN / 2) - i) as f32).collect();
let bindings = &[x, y, z];

let mut registers = Registers::new(LEN as usize);
Expand Down
Loading

0 comments on commit 1f3c32b

Please sign in to comment.