Skip to content

Commit

Permalink
fix: added missing eqNullSafe column method (#183)
Browse files Browse the repository at this point in the history
* fix: added missing eqNullSafe column method

* fix: added tests and reformatted files
  • Loading branch information
eredzik authored Oct 18, 2024
1 parent 338a9db commit 4fbb688
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions sqlframe/base/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ def isNotNull(self) -> Column:
new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
return Column(new_expression)

def eqNullSafe(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.NullSafeEQ, other)

def cast(
self,
dataType: t.Union[str, DataType, exp.DataType, exp.DataType.Type],
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/engines/test_engine_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from sqlframe.base import types
from sqlframe.base.types import Row
from sqlframe.bigquery import BigQuerySession
from sqlframe.postgres import PostgresSession

Expand All @@ -14,6 +15,23 @@
pytest_plugins = ["tests.integration.fixtures"]


def test_columnt_eq_null_safe(get_session: t.Callable[[], _BaseSession], get_func):
session = get_session()
df1 = session.createDataFrame([Row(id=1, value="foo"), Row(id=2, value=None)])
lit = get_func("lit", session)
values = df1.select(
df1["value"] == "foo",
df1["value"].eqNullSafe("foo"),
df1["value"].eqNullSafe(None),
df1["id"].eqNullSafe(1),
df1["id"] == lit(None),
df1["id"].eqNullSafe(None),
).collect()

assert values[0] == (True, True, False, True, None, False)
assert values[1] == (None, False, True, False, None, False)


def test_column_get_item_array(get_session: t.Callable[[], _BaseSession], get_func):
session = get_session()
lit = get_func("lit", session)
Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ def test_where_clause_single(
compare_frames(df_employee, dfs_employee)


def test_where_clause_eq_nullsafe(
pyspark_employee: PySparkDataFrame,
get_df: t.Callable[[str], _BaseDataFrame],
compare_frames: t.Callable,
):
employee = get_df("employee")
df_employee = pyspark_employee.where(F.col("age").eqNullSafe(F.lit(37)))
dfs_employee = employee.where(SF.col("age") == SF.lit(37))
compare_frames(df_employee, dfs_employee)


def test_where_clause_multiple_and(
pyspark_employee: PySparkDataFrame,
get_df: t.Callable[[str], _BaseDataFrame],
Expand Down

0 comments on commit 4fbb688

Please sign in to comment.