diff --git a/Cargo.toml b/Cargo.toml index 02fd97d..796d0bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,18 +1,13 @@ -[package] -name = "vector-expr" -authors = ["Duncan Fairbanks "] -version = "0.2.0" -edition = "2021" -description = "Vectorized expression parser and evaluator" -license = "MIT OR Apache-2.0" -repository = "https://github.com/ForesightMiningSoftwareCorporation/vector_expr" -readme = "README.md" +[workspace] +members = ["crates/vector-expr", "crates/vector-expr-f64"] +resolver = "2" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[patch.crates-io] +vector-expr = { path = "crates/vector-expr" } +vector-expr-f64 = { path = "crates/vector-expr-f64" } -[dependencies] +[workspace.dependencies] once_cell = "1.19.0" pest = "2.7.5" pest_derive = "2.7.5" - -rayon = { version = "1", optional = true } +rayon = "1" diff --git a/README.md b/README.md index c437dee..79efc03 100644 --- a/README.md +++ b/README.md @@ -12,27 +12,6 @@ parallelism via the `rayon` feature). ## Example -```rust -use vector_expr::*; - -fn binding_map(var_name: &str) -> BindingId { - match var_name { - "bar" => 0, - "baz" => 1, - "foo" => 2, - _ => unreachable!(), - } -} -let parsed = Expression::parse("2 * (foo + bar) * baz", &binding_map).unwrap(); -let real = parsed.unwrap_real(); - -let bar = [1.0, 2.0, 3.0]; -let baz = [4.0, 5.0, 6.0]; -let foo = [7.0, 8.0, 9.0]; -let bindings: &[&[f64]] = &[&bar, &baz, &foo]; -let mut registers = Registers::new(3); -let output = real.evaluate(bindings, &mut registers); -assert_eq!(&output, &[64.0, 100.0, 144.0]); -``` +See unit tests in `src/lib.rs`. License: MIT OR Apache-2.0 diff --git a/crates/vector-expr-f64/Cargo.toml b/crates/vector-expr-f64/Cargo.toml new file mode 100644 index 0000000..3151976 --- /dev/null +++ b/crates/vector-expr-f64/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "vector-expr-f64" +authors = ["Duncan Fairbanks "] +version = "0.2.0" +edition = "2021" +description = "Vectorized expression parser and evaluator" +license = "MIT OR Apache-2.0" +repository = "https://github.com/ForesightMiningSoftwareCorporation/vector_expr" +readme = "../../README.md" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +name = "vector_expr_f64" +path = "../../src/lib.rs" + +[features] +default = ["f64"] +f64 = [] + +[dependencies] +once_cell = { workspace = true } +pest = { workspace = true } +pest_derive = { workspace = true } +rayon = { workspace = true, optional = true } diff --git a/crates/vector-expr/Cargo.toml b/crates/vector-expr/Cargo.toml new file mode 100644 index 0000000..684b963 --- /dev/null +++ b/crates/vector-expr/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "vector-expr" +authors = ["Duncan Fairbanks "] +version = "0.2.0" +edition = "2021" +description = "Vectorized expression parser and evaluator" +license = "MIT OR Apache-2.0" +repository = "https://github.com/ForesightMiningSoftwareCorporation/vector_expr" +readme = "../../README.md" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +name = "vector_expr" +path = "../../src/lib.rs" + +[features] +default = ["f32"] +f32 = [] + +[dependencies] +once_cell = { workspace = true } +pest = { workspace = true } +pest_derive = { workspace = true } +rayon = { workspace = true, optional = true } + diff --git a/src/evaluate.rs b/src/evaluate.rs index 7961619..656377c 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -1,3 +1,4 @@ +use crate::real::Real; use crate::{BoolExpression, RealExpression, StringExpression}; #[cfg(feature = "rayon")] @@ -10,7 +11,7 @@ pub type StringId = u32; impl BoolExpression { /// Calculates the `bool`-valued results of the expression component-wise. - pub fn evaluate, S: AsRef<[StringId]>>( + pub fn evaluate, S: AsRef<[StringId]>>( &self, real_bindings: &[R], string_bindings: &[S], @@ -27,7 +28,7 @@ impl BoolExpression { ) } - fn evaluate_recursive, S: AsRef<[StringId]>>( + fn evaluate_recursive, S: AsRef<[StringId]>>( &self, real_bindings: &[R], string_bindings: &[S], @@ -124,21 +125,25 @@ impl BoolExpression { } impl RealExpression { - pub fn evaluate_without_vars(&self, registers: &mut Registers) -> Vec { + pub fn evaluate_without_vars(&self, registers: &mut Registers) -> Vec { self.evaluate::<[_; 0]>(&[], registers) } /// Calculates the real-valued results of the expression component-wise. - pub fn evaluate>(&self, bindings: &[R], registers: &mut Registers) -> Vec { + pub fn evaluate>( + &self, + bindings: &[R], + registers: &mut Registers, + ) -> Vec { validate_bindings(bindings, registers.register_length); self.evaluate_recursive(bindings, registers) } - fn evaluate_recursive>( + fn evaluate_recursive>( &self, bindings: &[R], registers: &mut Registers, - ) -> Vec { + ) -> Vec { match self { Self::Add(lhs, rhs) => evaluate_binary_real_op( |lhs, rhs| lhs + rhs, @@ -200,13 +205,13 @@ fn validate_bindings>(input_bindings: &[B], expected_length: us } } -fn evaluate_binary_real_op>( - op: fn(f64, f64) -> f64, +fn evaluate_binary_real_op>( + op: fn(Real, Real) -> Real, lhs: &RealExpression, rhs: &RealExpression, bindings: &[R], registers: &mut Registers, -) -> Vec { +) -> Vec { // Before doing recursive evaluation, we check first if we already have // input values in our bindings. This avoids unnecessary copies. let mut lhs_reg = None; @@ -254,12 +259,12 @@ fn evaluate_binary_real_op>( output } -fn evaluate_unary_real_op>( - op: fn(f64) -> f64, +fn evaluate_unary_real_op>( + op: fn(Real) -> Real, only: &RealExpression, bindings: &[R], registers: &mut Registers, -) -> Vec { +) -> Vec { // Before doing recursive evaluation, we check first if we already have // input values in our bindings. This avoids unnecessary copies. let mut only_reg = None; @@ -287,8 +292,8 @@ fn evaluate_unary_real_op>( output } -fn evaluate_real_comparison>( - op: fn(f64, f64) -> bool, +fn evaluate_real_comparison>( + op: fn(Real, Real) -> bool, lhs: &RealExpression, rhs: &RealExpression, bindings: &[R], @@ -402,7 +407,7 @@ fn evaluate_string_comparison>( output } -fn evaluate_binary_logic, S: AsRef<[StringId]>>( +fn evaluate_binary_logic, S: AsRef<[StringId]>>( op: fn(bool, bool) -> bool, lhs: &BoolExpression, rhs: &BoolExpression, @@ -451,7 +456,7 @@ fn evaluate_binary_logic, S: AsRef<[StringId]>>( output } -fn evaluate_unary_logic, S: AsRef<[StringId]>>( +fn evaluate_unary_logic, S: AsRef<[StringId]>>( op: fn(bool) -> bool, only: &BoolExpression, real_bindings: &[R], @@ -489,7 +494,7 @@ fn evaluate_unary_logic, S: AsRef<[StringId]>>( /// calculations have finished. pub struct Registers { num_allocations: usize, - real_registers: Vec>, + real_registers: Vec>, bool_registers: Vec>, string_registers: Vec>, register_length: usize, @@ -506,7 +511,7 @@ impl Registers { } } - fn recycle_real(&mut self, mut used: Vec) { + fn recycle_real(&mut self, mut used: Vec) { used.clear(); self.real_registers.push(used); } @@ -521,7 +526,7 @@ impl Registers { self.string_registers.push(used); } - fn allocate_real(&mut self) -> Vec { + fn allocate_real(&mut self) -> Vec { self.real_registers.pop().unwrap_or_else(|| { self.num_allocations += 1; Vec::with_capacity(self.register_length) diff --git a/src/expression.rs b/src/expression.rs index a992ce9..be71145 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -1,3 +1,5 @@ +use crate::real::Real; + /// Top-level parseable calculation. #[derive(Clone, Debug)] pub enum Expression { @@ -29,7 +31,7 @@ pub enum BoolExpression { StrNotEqual(StringExpression, StringExpression), } -/// An `f64`-valued expression. +/// A `Real`-valued expression. #[derive(Clone, Debug)] pub enum RealExpression { // Binary real ops. @@ -43,7 +45,7 @@ pub enum RealExpression { Neg(Box), // Constant. - Literal(f64), + Literal(Real), // Input variable. Binding(BindingId), @@ -55,5 +57,6 @@ pub enum StringExpression { Binding(BindingId), } -/// Index into the `&[&[f64]]` bindings passed to expression evaluation. +/// Index into the `&[&[Real]]` or `&[&[StringId]]` bindings passed to +/// expression evaluation. pub type BindingId = usize; diff --git a/src/lib.rs b/src/lib.rs index d001373..14bbca0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,37 +1,14 @@ -//! Vectorized math expression parser/evaluator. -//! -//! # Why? -//! -//! Performance. Evaluation of math expressions involving many variables can -//! incur significant overhead from traversing the expression tree or performing -//! variable lookups. We amortize that cost by performing intermediate -//! operations on _vectors_ of input data at a time (with optional data -//! parallelism via the `rayon` feature). -//! -//! # Example -//! -//! ```rust -//! use vector_expr::*; -//! -//! fn binding_map(var_name: &str) -> BindingId { -//! match var_name { -//! "bar" => 0, -//! "baz" => 1, -//! "foo" => 2, -//! _ => unreachable!(), -//! } -//! } -//! let parsed = Expression::parse("2 * (foo + bar) * baz", binding_map).unwrap(); -//! let real = parsed.unwrap_real(); -//! -//! let bar = [1.0, 2.0, 3.0]; -//! let baz = [4.0, 5.0, 6.0]; -//! let foo = [7.0, 8.0, 9.0]; -//! let bindings: &[&[f64]] = &[&bar, &baz, &foo]; -//! let mut registers = Registers::new(3); -//! let output = real.evaluate(bindings, &mut registers); -//! assert_eq!(&output, &[64.0, 100.0, 144.0]); -//! ``` +#![doc = include_str!("../README.md")] + +mod real { + /// The scalar type used throughout this crate. + #[cfg(feature = "f64")] + pub type Real = f64; + + /// The scalar type used throughout this crate. + #[cfg(feature = "f32")] + pub type Real = f32; +} mod evaluate; mod expression; @@ -55,6 +32,7 @@ pub fn empty_binding_map(_var_name: &str) -> BindingId { #[cfg(test)] mod tests { + use super::real::Real; use super::*; #[test] @@ -198,9 +176,9 @@ mod tests { let real = parsed.unwrap_real(); const LEN: i32 = 10_000_000; - let x: Vec<_> = (0..LEN).map(|i| i as f64).collect(); - let y: Vec<_> = (0..LEN).map(|i| (LEN - i) as f64).collect(); - let z: Vec<_> = (0..LEN).map(|i| ((LEN / 2) - i) as f64).collect(); + let x: Vec<_> = (0..LEN).map(|i| i as Real).collect(); + let y: Vec<_> = (0..LEN).map(|i| (LEN - i) as Real).collect(); + let z: Vec<_> = (0..LEN).map(|i| ((LEN / 2) - i) as Real).collect(); let bindings = &[x, y, z]; let mut registers = Registers::new(LEN as usize); diff --git a/src/parse.rs b/src/parse.rs index 1357e39..debd73a 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,4 +1,5 @@ use crate::expression::{BindingId, BoolExpression, Expression, RealExpression}; +use crate::real::Real; use crate::StringExpression; use once_cell::sync::Lazy; use pest::iterators::Pairs; @@ -8,7 +9,7 @@ use pest_derive::Parser; use std::collections::HashSet; #[derive(Parser)] -#[grammar = "grammar.pest"] // relative to project `src` +#[grammar = "../../src/grammar.pest"] // relative to workspace `src` struct ExpressionParser; // Boxed because error is much larger than Ok variant in most results. @@ -97,7 +98,7 @@ fn parse_recursive(pairs: Pairs, binding_map: &impl Fn(&str) -> BindingId) Rule::string_expr => parse_recursive(pair.into_inner(), binding_map), Rule::real_literal => { let literal_str = pair.as_str(); - if let Ok(value) = literal_str.parse::() { + if let Ok(value) = literal_str.parse::() { return Expression::Real(RealExpression::Literal(value)); } panic!("Unexpected literal: {}", literal_str)