diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 8fffef9c1e8a..a7f45a94f3c3 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -527,55 +527,87 @@ impl SQLExprVisitor<'_> { op: &BinaryOperator, right: &SQLExpr, ) -> PolarsResult { - let left = self.visit_expr(left)?; - let mut right = self.visit_expr(right)?; - right = self.convert_temporal_strings(&left, &right); + let lhs = self.visit_expr(left)?; + let mut rhs = self.visit_expr(right)?; + rhs = self.convert_temporal_strings(&lhs, &rhs); Ok(match op { - SQLBinaryOperator::And => left.and(right), - SQLBinaryOperator::Divide => left / right, - SQLBinaryOperator::DuckIntegerDivide => left.floor_div(right).cast(DataType::Int64), - SQLBinaryOperator::Eq => left.eq(right), - SQLBinaryOperator::Gt => left.gt(right), - SQLBinaryOperator::GtEq => left.gt_eq(right), - SQLBinaryOperator::Lt => left.lt(right), - SQLBinaryOperator::LtEq => left.lt_eq(right), - SQLBinaryOperator::Minus => left - right, - SQLBinaryOperator::Modulo => left % right, - SQLBinaryOperator::Multiply => left * right, - SQLBinaryOperator::NotEq => left.eq(right).not(), - SQLBinaryOperator::Or => left.or(right), - SQLBinaryOperator::Plus => left + right, - SQLBinaryOperator::Spaceship => left.eq_missing(right), + SQLBinaryOperator::And => lhs.and(rhs), + SQLBinaryOperator::Divide => lhs / rhs, + SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), + SQLBinaryOperator::Eq => lhs.eq(rhs), + SQLBinaryOperator::Gt => lhs.gt(rhs), + SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), + SQLBinaryOperator::Lt => lhs.lt(rhs), + SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), + SQLBinaryOperator::Minus => lhs - rhs, + SQLBinaryOperator::Modulo => lhs % rhs, + SQLBinaryOperator::Multiply => lhs * rhs, + SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), + SQLBinaryOperator::Or => lhs.or(rhs), + SQLBinaryOperator::Plus => lhs + rhs, + SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), SQLBinaryOperator::StringConcat => { - left.cast(DataType::String) + right.cast(DataType::String) + lhs.cast(DataType::String) + rhs.cast(DataType::String) }, - SQLBinaryOperator::Xor => left.xor(right), + SQLBinaryOperator::Xor => lhs.xor(rhs), + SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), // ---- // Regular expression operators // ---- - SQLBinaryOperator::PGRegexMatch => match right { - Expr::Literal(LiteralValue::String(_)) => left.str().contains(right, true), - _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", right), + // "a ~ b" + SQLBinaryOperator::PGRegexMatch => match rhs { + Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true), + _ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs), }, - SQLBinaryOperator::PGRegexNotMatch => match right { - Expr::Literal(LiteralValue::String(_)) => left.str().contains(right, true).not(), - _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", right), + // "a !~ b" + SQLBinaryOperator::PGRegexNotMatch => match rhs { + Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true).not(), + _ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs), }, - SQLBinaryOperator::PGRegexIMatch => match right { + // "a ~* b" + SQLBinaryOperator::PGRegexIMatch => match rhs { Expr::Literal(LiteralValue::String(pat)) => { - left.str().contains(lit(format!("(?i){}", pat)), true) + lhs.str().contains(lit(format!("(?i){}", pat)), true) }, - _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", right), + _ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs), }, - SQLBinaryOperator::PGRegexNotIMatch => match right { + // "a !~* b" + SQLBinaryOperator::PGRegexNotIMatch => match rhs { Expr::Literal(LiteralValue::String(pat)) => { - left.str().contains(lit(format!("(?i){}", pat)), true).not() + lhs.str().contains(lit(format!("(?i){}", pat)), true).not() }, _ => { - polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", right) + polars_bail!(SQLSyntax: "invalid pattern for '!~*' operator: {:?}", rhs) }, }, + // ---- + // LIKE/ILIKE operators + // ---- + SQLBinaryOperator::PGLikeMatch + | SQLBinaryOperator::PGNotLikeMatch + | SQLBinaryOperator::PGILikeMatch + | SQLBinaryOperator::PGNotILikeMatch => { + let expr = if matches!( + op, + SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch + ) { + SQLExpr::Like { + negated: matches!(op, SQLBinaryOperator::PGNotLikeMatch), + expr: Box::new(left.clone()), + pattern: Box::new(right.clone()), + escape_char: None, + } + } else { + SQLExpr::ILike { + negated: matches!(op, SQLBinaryOperator::PGNotILikeMatch), + expr: Box::new(left.clone()), + pattern: Box::new(right.clone()), + escape_char: None, + } + }; + self.visit_expr(&expr)? + }, other => { polars_bail!(SQLInterface: "operator {:?} is not currently supported", other) }, diff --git a/py-polars/tests/unit/sql/test_operators.py b/py-polars/tests/unit/sql/test_operators.py index b0354292d85f..668ead0bc087 100644 --- a/py-polars/tests/unit/sql/test_operators.py +++ b/py-polars/tests/unit/sql/test_operators.py @@ -110,6 +110,25 @@ def test_is_between(foods_ipc_path: Path) -> None: assert not any((22 <= cal <= 30) for cal in out["calories"]) +def test_starts_with() -> None: + lf = pl.LazyFrame( + { + "x": ["aaa", "bbb", "a"], + "y": ["abc", "b", "aa"], + }, + ) + assert lf.sql("SELECT x ^@ 'a' AS x_starts_with_a FROM self").collect().rows() == [ + (True,), + (False,), + (True,), + ] + assert lf.sql("SELECT x ^@ y AS x_starts_with_y FROM self").collect().rows() == [ + (False,), + (True,), + (False,), + ] + + @pytest.mark.parametrize("match_float", [False, True]) def test_unary_ops_8890(match_float: bool) -> None: with pl.SQLContext( diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py index 4d6e6c598986..1669a4fc7266 100644 --- a/py-polars/tests/unit/sql/test_strings.py +++ b/py-polars/tests/unit/sql/test_strings.py @@ -216,15 +216,15 @@ def test_string_lengths() -> None: ("_0%_", "LIKE", [2, 4]), ("%0", "LIKE", [2]), ("0%", "LIKE", [2]), - ("__0%", "LIKE", [2, 3]), - ("%*%", "ILIKE", [3]), - ("____", "LIKE", [4]), - ("a%C", "LIKE", []), - ("a%C", "ILIKE", [0, 1, 3]), - ("%C?", "ILIKE", [4]), - ("a0c?", "LIKE", [4]), - ("000", "LIKE", [2]), - ("00", "LIKE", []), + ("__0%", "~~", [2, 3]), + ("%*%", "~~*", [3]), + ("____", "~~", [4]), + ("a%C", "~~", []), + ("a%C", "~~*", [0, 1, 3]), + ("%C?", "~~*", [4]), + ("a0c?", "~~", [4]), + ("000", "~~", [2]), + ("00", "~~", []), ], ) def test_string_like(pattern: str, like: str, expected: list[int]) -> None: @@ -235,9 +235,9 @@ def test_string_like(pattern: str, like: str, expected: list[int]) -> None: } ) with pl.SQLContext(df=df) as ctx: - for not_ in ("", "NOT "): + for not_ in ("", ("NOT " if like.endswith("LIKE") else "!")): out = ctx.execute( - f"""SELECT idx FROM df WHERE txt {not_}{like} '{pattern}'""" + f"SELECT idx FROM df WHERE txt {not_}{like} '{pattern}'" ).collect() res = out["idx"].to_list()