From d70041c802bc85bfcdb1a11be0345d9c5e3a0f77 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Fri, 22 Dec 2023 12:35:09 -0800 Subject: [PATCH 1/3] use num_traits::Float to abstract over f32/f64 --- Cargo.toml | 1 + src/evaluate.rs | 88 +++++++++++++++++++++++++---------------------- src/expression.rs | 42 +++++++++++----------- src/lib.rs | 10 ++++-- src/parse.rs | 31 ++++++++++------- 5 files changed, 93 insertions(+), 79 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 02fd97d..f64ba3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +num-traits = "0.2.17" once_cell = "1.19.0" pest = "2.7.5" pest_derive = "2.7.5" diff --git a/src/evaluate.rs b/src/evaluate.rs index 7961619..db17ed7 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -1,4 +1,4 @@ -use crate::{BoolExpression, RealExpression, StringExpression}; +use crate::{BoolExpression, FloatExt, RealExpression, StringExpression}; #[cfg(feature = "rayon")] use rayon::prelude::{ @@ -8,14 +8,14 @@ use rayon::prelude::{ /// To speed up string comparisons, we use string interning. pub type StringId = u32; -impl BoolExpression { +impl BoolExpression { /// Calculates the `bool`-valued results of the expression component-wise. - pub fn evaluate, S: AsRef<[StringId]>>( + pub fn evaluate, S: AsRef<[StringId]>>( &self, real_bindings: &[R], string_bindings: &[S], mut get_string_literal_id: impl FnMut(&str) -> StringId, - registers: &mut Registers, + registers: &mut Registers, ) -> Vec { validate_bindings(real_bindings, registers.register_length); validate_bindings(string_bindings, registers.register_length); @@ -27,12 +27,12 @@ impl BoolExpression { ) } - fn evaluate_recursive, S: AsRef<[StringId]>>( + fn evaluate_recursive, S: AsRef<[StringId]>>( &self, real_bindings: &[R], string_bindings: &[S], get_string_literal_id: &mut impl FnMut(&str) -> StringId, - registers: &mut Registers, + registers: &mut Registers, ) -> Vec { match self { Self::And(lhs, rhs) => evaluate_binary_logic( @@ -123,22 +123,26 @@ impl BoolExpression { } } -impl RealExpression { - pub fn evaluate_without_vars(&self, registers: &mut Registers) -> Vec { +impl RealExpression { + pub fn evaluate_without_vars(&self, registers: &mut Registers) -> Vec { self.evaluate::<[_; 0]>(&[], registers) } /// Calculates the real-valued results of the expression component-wise. - pub fn evaluate>(&self, bindings: &[R], registers: &mut Registers) -> Vec { + pub fn evaluate>( + &self, + bindings: &[R], + registers: &mut Registers, + ) -> Vec { validate_bindings(bindings, registers.register_length); self.evaluate_recursive(bindings, registers) } - fn evaluate_recursive>( + fn evaluate_recursive>( &self, bindings: &[R], - registers: &mut Registers, - ) -> Vec { + registers: &mut Registers, + ) -> Vec { match self { Self::Add(lhs, rhs) => evaluate_binary_real_op( |lhs, rhs| lhs + rhs, @@ -200,13 +204,13 @@ fn validate_bindings>(input_bindings: &[B], expected_length: us } } -fn evaluate_binary_real_op>( - op: fn(f64, f64) -> f64, - lhs: &RealExpression, - rhs: &RealExpression, +fn evaluate_binary_real_op>( + op: fn(Real, Real) -> Real, + lhs: &RealExpression, + rhs: &RealExpression, bindings: &[R], - registers: &mut Registers, -) -> Vec { + registers: &mut Registers, +) -> Vec { // Before doing recursive evaluation, we check first if we already have // input values in our bindings. This avoids unnecessary copies. let mut lhs_reg = None; @@ -254,12 +258,12 @@ fn evaluate_binary_real_op>( output } -fn evaluate_unary_real_op>( - op: fn(f64) -> f64, - only: &RealExpression, +fn evaluate_unary_real_op>( + op: fn(Real) -> Real, + only: &RealExpression, bindings: &[R], - registers: &mut Registers, -) -> Vec { + registers: &mut Registers, +) -> Vec { // Before doing recursive evaluation, we check first if we already have // input values in our bindings. This avoids unnecessary copies. let mut only_reg = None; @@ -287,12 +291,12 @@ fn evaluate_unary_real_op>( output } -fn evaluate_real_comparison>( - op: fn(f64, f64) -> bool, - lhs: &RealExpression, - rhs: &RealExpression, +fn evaluate_real_comparison>( + op: fn(Real, Real) -> bool, + lhs: &RealExpression, + rhs: &RealExpression, bindings: &[R], - registers: &mut Registers, + registers: &mut Registers, ) -> Vec { // Before doing recursive evaluation, we check first if we already have // input values in our bindings. This avoids unnecessary copies. @@ -341,13 +345,13 @@ fn evaluate_real_comparison>( output } -fn evaluate_string_comparison>( +fn evaluate_string_comparison>( op: fn(StringId, StringId) -> bool, lhs: &StringExpression, rhs: &StringExpression, bindings: &[S], mut get_string_literal_id: impl FnMut(&str) -> StringId, - registers: &mut Registers, + registers: &mut Registers, ) -> Vec { let mut lhs_reg = None; let lhs_values = match lhs { @@ -402,14 +406,14 @@ fn evaluate_string_comparison>( output } -fn evaluate_binary_logic, S: AsRef<[StringId]>>( +fn evaluate_binary_logic, S: AsRef<[StringId]>>( op: fn(bool, bool) -> bool, - lhs: &BoolExpression, - rhs: &BoolExpression, + lhs: &BoolExpression, + rhs: &BoolExpression, real_bindings: &[R], string_bindings: &[S], get_string_literal_id: &mut impl FnMut(&str) -> StringId, - registers: &mut Registers, + registers: &mut Registers, ) -> Vec { let lhs_values = lhs.evaluate_recursive( real_bindings, @@ -451,13 +455,13 @@ fn evaluate_binary_logic, S: AsRef<[StringId]>>( output } -fn evaluate_unary_logic, S: AsRef<[StringId]>>( +fn evaluate_unary_logic, S: AsRef<[StringId]>>( op: fn(bool) -> bool, - only: &BoolExpression, + only: &BoolExpression, real_bindings: &[R], string_bindings: &[S], get_string_literal_id: &mut impl FnMut(&str) -> StringId, - registers: &mut Registers, + registers: &mut Registers, ) -> Vec { let only_values = only.evaluate_recursive( real_bindings, @@ -487,15 +491,15 @@ fn evaluate_unary_logic, S: AsRef<[StringId]>>( /// /// Attempts to minimize allocations by recycling registers after intermediate /// calculations have finished. -pub struct Registers { +pub struct Registers { num_allocations: usize, - real_registers: Vec>, + real_registers: Vec>, bool_registers: Vec>, string_registers: Vec>, register_length: usize, } -impl Registers { +impl Registers { pub fn new(register_length: usize) -> Self { Self { num_allocations: 0, @@ -506,7 +510,7 @@ impl Registers { } } - fn recycle_real(&mut self, mut used: Vec) { + fn recycle_real(&mut self, mut used: Vec) { used.clear(); self.real_registers.push(used); } @@ -521,7 +525,7 @@ impl Registers { self.string_registers.push(used); } - fn allocate_real(&mut self) -> Vec { + fn allocate_real(&mut self) -> Vec { self.real_registers.pop().unwrap_or_else(|| { self.num_allocations += 1; Vec::with_capacity(self.register_length) diff --git a/src/expression.rs b/src/expression.rs index a992ce9..93522fe 100644 --- a/src/expression.rs +++ b/src/expression.rs @@ -1,28 +1,28 @@ /// Top-level parseable calculation. #[derive(Clone, Debug)] -pub enum Expression { - Boolean(BoolExpression), - Real(RealExpression), +pub enum Expression { + Boolean(BoolExpression), + Real(RealExpression), String(StringExpression), } /// A `bool`-valued expression. #[derive(Clone, Debug)] -pub enum BoolExpression { +pub enum BoolExpression { // Binary logic. - And(Box, Box), - Or(Box, Box), + And(Box>, Box>), + Or(Box>, Box>), // Unary logic. - Not(Box), + Not(Box>), // Real comparisons. - Equal(Box, Box), - Greater(Box, Box), - GreaterEqual(Box, Box), - Less(Box, Box), - LessEqual(Box, Box), - NotEqual(Box, Box), + Equal(Box>, Box>), + Greater(Box>, Box>), + GreaterEqual(Box>, Box>), + Less(Box>, Box>), + LessEqual(Box>, Box>), + NotEqual(Box>, Box>), // String comparisons. StrEqual(StringExpression, StringExpression), @@ -31,19 +31,19 @@ pub enum BoolExpression { /// An `f64`-valued expression. #[derive(Clone, Debug)] -pub enum RealExpression { +pub enum RealExpression { // Binary real ops. - Add(Box, Box), - Div(Box, Box), - Mul(Box, Box), - Pow(Box, Box), - Sub(Box, Box), + Add(Box>, Box>), + Div(Box>, Box>), + Mul(Box>, Box>), + Pow(Box>, Box>), + Sub(Box>, Box>), // Unary real ops. - Neg(Box), + Neg(Box>), // Constant. - Literal(f64), + Literal(Real), // Input variable. Binding(BindingId), diff --git a/src/lib.rs b/src/lib.rs index d001373..14c0fb3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,6 +53,10 @@ pub fn empty_binding_map(_var_name: &str) -> BindingId { panic!("Empty binding map") } +pub trait FloatExt: num_traits::Float + std::str::FromStr + Send + Sync {} +impl FloatExt for f32 {} +impl FloatExt for f64 {} + #[cfg(test)] mod tests { use super::*; @@ -84,17 +88,17 @@ mod tests { fn real_op_precedence() { let mut registers = Registers::new(1); - let parsed = Expression::parse("1 * 2 + 3 * 4", empty_binding_map).unwrap(); + let parsed = Expression::::parse("1 * 2 + 3 * 4", empty_binding_map).unwrap(); let real = parsed.unwrap_real(); let output = real.evaluate_without_vars(&mut registers); assert_eq!(&output, &[14.0]); - let parsed = Expression::parse("8 / 4 * 3", empty_binding_map).unwrap(); + let parsed = Expression::::parse("8 / 4 * 3", empty_binding_map).unwrap(); let real = parsed.unwrap_real(); let output = real.evaluate_without_vars(&mut registers); assert_eq!(&output, &[6.0]); - let parsed = Expression::parse("4 ^ 3 ^ 2", empty_binding_map).unwrap(); + let parsed = Expression::::parse("4 ^ 3 ^ 2", empty_binding_map).unwrap(); let real = parsed.unwrap_real(); let output = real.evaluate_without_vars(&mut registers); assert_eq!(&output, &[262144.0]); diff --git a/src/parse.rs b/src/parse.rs index 1357e39..8bdbac3 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,11 +1,13 @@ use crate::expression::{BindingId, BoolExpression, Expression, RealExpression}; use crate::StringExpression; +use num_traits::Float; use once_cell::sync::Lazy; use pest::iterators::Pairs; use pest::pratt_parser::{Assoc, Op, PrattParser}; use pest::Parser; use pest_derive::Parser; use std::collections::HashSet; +use std::str::FromStr; #[derive(Parser)] #[grammar = "grammar.pest"] // relative to project `src` @@ -14,9 +16,9 @@ struct ExpressionParser; // Boxed because error is much larger than Ok variant in most results. pub type ParseError = Box>; -impl Expression { +impl Expression { /// Assume this expression is real-valued. - pub fn unwrap_real(self) -> RealExpression { + pub fn unwrap_real(self) -> RealExpression { match self { Self::Real(r) => r, _ => panic!("Expected Real"), @@ -32,7 +34,7 @@ impl Expression { } /// Assume this expression is boolean-valued. - pub fn unwrap_bool(self) -> BoolExpression { + pub fn unwrap_bool(self) -> BoolExpression { match self { Self::Boolean(b) => b, _ => panic!("Expected Boolean"), @@ -89,7 +91,10 @@ static PRATT_PARSER: Lazy> = Lazy::new(|| { .op(Op::infix(power, Right)) }); -fn parse_recursive(pairs: Pairs, binding_map: &impl Fn(&str) -> BindingId) -> Expression { +fn parse_recursive( + pairs: Pairs, + binding_map: &impl Fn(&str) -> BindingId, +) -> Expression { PRATT_PARSER .map_primary(|pair| match pair.as_rule() { Rule::bool_expr => parse_recursive(pair.into_inner(), binding_map), @@ -97,7 +102,7 @@ fn parse_recursive(pairs: Pairs, binding_map: &impl Fn(&str) -> BindingId) Rule::string_expr => parse_recursive(pair.into_inner(), binding_map), Rule::real_literal => { let literal_str = pair.as_str(); - if let Ok(value) = literal_str.parse::() { + if let Ok(value) = literal_str.parse::() { return Expression::Real(RealExpression::Literal(value)); } panic!("Unexpected literal: {}", literal_str) @@ -210,12 +215,12 @@ mod tests { #[test] fn parse_variable_names() { - let vars = Expression::parse_real_variable_names("v1_dest + x + y + z99").unwrap(); + let vars = Expression::::parse_real_variable_names("v1_dest + x + y + z99").unwrap(); assert!(vars.contains("x"), "{vars:?}"); assert!(vars.contains("y"), "{vars:?}"); assert!(vars.contains("z99"), "{vars:?}"); assert!(vars.contains("v1_dest"), "{vars:?}"); - let vars = Expression::parse_string_variable_names("x == \"W\"").unwrap(); + let vars = Expression::::parse_string_variable_names("x == \"W\"").unwrap(); assert!(vars.contains("x"), "{vars:?}"); } @@ -228,11 +233,11 @@ mod tests { _ => unreachable!(), } } - Expression::parse("x == y", binding_map).unwrap(); - Expression::parse("x != y", binding_map).unwrap(); - Expression::parse("x > y", binding_map).unwrap(); - Expression::parse("x < y", binding_map).unwrap(); - Expression::parse("x <= y", binding_map).unwrap(); - Expression::parse("x >= y", binding_map).unwrap(); + Expression::::parse("x == y", binding_map).unwrap(); + Expression::::parse("x != y", binding_map).unwrap(); + Expression::::parse("x > y", binding_map).unwrap(); + Expression::::parse("x < y", binding_map).unwrap(); + Expression::::parse("x <= y", binding_map).unwrap(); + Expression::::parse("x >= y", binding_map).unwrap(); } } From 9633e303f6b164ad4f3c8c630696f68546d39f58 Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Fri, 22 Dec 2023 12:38:57 -0800 Subject: [PATCH 2/3] change test to use f32 instead of f64 --- src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 14c0fb3..87808f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -202,9 +202,9 @@ mod tests { let real = parsed.unwrap_real(); const LEN: i32 = 10_000_000; - let x: Vec<_> = (0..LEN).map(|i| i as f64).collect(); - let y: Vec<_> = (0..LEN).map(|i| (LEN - i) as f64).collect(); - let z: Vec<_> = (0..LEN).map(|i| ((LEN / 2) - i) as f64).collect(); + let x: Vec<_> = (0..LEN).map(|i| i as f32).collect(); + let y: Vec<_> = (0..LEN).map(|i| (LEN - i) as f32).collect(); + let z: Vec<_> = (0..LEN).map(|i| ((LEN / 2) - i) as f32).collect(); let bindings = &[x, y, z]; let mut registers = Registers::new(LEN as usize); From 6946111df9ef72beb5840f27296b3a28ab1d49ae Mon Sep 17 00:00:00 2001 From: Duncan Fairbanks Date: Fri, 22 Dec 2023 12:52:04 -0800 Subject: [PATCH 3/3] bump to version 0.3.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f64ba3d..0735192 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vector-expr" authors = ["Duncan Fairbanks "] -version = "0.2.0" +version = "0.3.0" edition = "2021" description = "Vectorized expression parser and evaluator" license = "MIT OR Apache-2.0"