diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index bd6a86fb..f6a51ce2 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -393,6 +393,22 @@ def cast( return Expr(self.expr.cast(to)) + def between(self, low: Any, high: Any, negated: bool = False) -> Expr: + """Returns ``True`` if this expression is between a given range. + + Args: + low: lower bound of the range (inclusive). + high: higher bound of the range (inclusive). + negated: negates whether the expression is between a given range + """ + if not isinstance(low, Expr): + low = Expr.literal(low) + + if not isinstance(high, Expr): + high = Expr.literal(high) + + return Expr(self.expr.between(low.expr, high.expr, negated=negated)) + def rex_type(self) -> RexType: """Return the Rex Type of this expression. diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 8e3c5139..9353f872 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -1024,3 +1024,34 @@ def test_cast(df, python_datatype, name: str, expected): result = df.collect() result = result[0] assert result.column(0) == result.column(1) + + +@pytest.mark.parametrize( + "negated, low, high, expected", + [ + pytest.param(False, 3, 5, {"filtered": [4, 5]}), + pytest.param(False, 4, 5, {"filtered": [4, 5]}), + pytest.param(True, 3, 5, {"filtered": [6]}), + pytest.param(True, 4, 6, []), + ], +) +def test_between(df, negated, low, high, expected): + df = df.filter(column("b").between(low, high, negated=negated)).select( + column("b").alias("filtered") + ) + + actual = df.collect() + + if expected: + actual = actual[0].to_pydict() + assert actual == expected + else: + assert len(actual) == 0 # the rows are empty + + +def test_between_default(df): + df = df.filter(column("b").between(3, 5)).select(column("b").alias("filtered")) + expected = {"filtered": [4, 5]} + + actual = df.collect()[0].to_pydict() + assert actual == expected diff --git a/src/expr.rs b/src/expr.rs index ab16f287..0e1a193f 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -293,6 +293,17 @@ impl PyExpr { expr.into() } + #[pyo3(signature = (low, high, negated=false))] + pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr { + let expr = Expr::Between(Between::new( + Box::new(self.expr.clone()), + negated, + Box::new(low.into()), + Box::new(high.into()), + )); + expr.into() + } + /// A Rex (Row Expression) specifies a single row of data. That specification /// could include user defined functions or types. RexType identifies the row /// as one of the possible valid `RexTypes`.