diff --git a/sqlframe/base/column.py b/sqlframe/base/column.py index 73ece8f..ea3e30c 100644 --- a/sqlframe/base/column.py +++ b/sqlframe/base/column.py @@ -9,6 +9,7 @@ import sqlglot from sqlglot import Dialect from sqlglot import expressions as exp +from sqlglot.expressions import paren from sqlglot.helper import flatten, is_iterable from sqlglot.optimizer.normalize_identifiers import normalize_identifiers @@ -63,10 +64,10 @@ def __le__(self, other: ColumnOrLiteral) -> Column: return self.binary_op(exp.LTE, other) def __and__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.And, other) + return self.binary_op(exp.And, other, paren=True) def __or__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Or, other) + return self.binary_op(exp.Or, other, paren=True) def __mod__(self, other: ColumnOrLiteral) -> Column: return self.binary_op(exp.Mod, other, paren=True) diff --git a/tests/unit/standalone/test_column.py b/tests/unit/standalone/test_column.py index 28e4653..bcd388b 100644 --- a/tests/unit/standalone/test_column.py +++ b/tests/unit/standalone/test_column.py @@ -31,13 +31,13 @@ def test_ge(): def test_and(): assert ( (F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold")) - ).sql() == "cola = colb AND colc = cold" + ).sql() == "(cola = colb AND colc = cold)" def test_or(): assert ( (F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold")) - ).sql() == "cola = colb OR colc = cold" + ).sql() == "(cola = colb OR colc = cold)" def test_mod(): @@ -89,7 +89,7 @@ def test_invert(): def test_invert_conjuction(): - assert (~(F.col("cola") | F.col("colb"))).sql() == "NOT (cola OR colb)" + assert (~(F.col("cola") | F.col("colb"))).sql() == "NOT ((cola OR colb))" def test_paren(): diff --git a/tests/unit/standalone/test_functions.py b/tests/unit/standalone/test_functions.py index f6a9125..319e81f 100644 --- a/tests/unit/standalone/test_functions.py +++ b/tests/unit/standalone/test_functions.py @@ -2611,7 +2611,7 @@ def test_sort_array(expression, expected): SF.length(y) - SF.length(x) ), ), - "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE (LENGTH(y) - LENGTH(x)) END)", + "ARRAY_SORT(cola, (x, y) -> CASE WHEN (x IS NULL OR y IS NULL) THEN 0 ELSE (LENGTH(y) - LENGTH(x)) END)", ), ], )