fix: handle all 'how' values in df.join
Thomzoy committed Jan 21, 2025
1 parent cf6d67f commit ec60f7c
Showing 2 changed files with 114 additions and 19 deletions.
sqlframe/base/dataframe.py — 96 changes: 77 additions & 19 deletions
@@ -79,6 +79,23 @@
     "SHUFFLE_REPLICATE_NL",
 }
 
+JOIN_TYPE_MAPPING = {
+    "inner": "inner",
+    "cross": "cross",
+    "outer": "full_outer",
+    "full": "full_outer",
+    "fullouter": "full_outer",
+    "left": "left_outer",
+    "leftouter": "left_outer",
+    "right": "right_outer",
+    "rightouter": "right_outer",
+    "semi": "left_semi",
+    "leftsemi": "left_semi",
+    "left_semi": "left_semi",
+    "anti": "left_anti",
+    "leftanti": "left_anti",
+    "left_anti": "left_anti",
+}
 
 DF = t.TypeVar("DF", bound="BaseDataFrame")

@@ -335,7 +352,8 @@ def _replace_cte_names_with_hashes(self, expression: exp.Select):
         for cte in expression.ctes:
             old_name_id = cte.args["alias"].this
             new_hashed_id = exp.to_identifier(
-                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
+                self._create_hash_from_expression(cte.this),
+                quoted=old_name_id.args["quoted"],
             )
             replacement_mapping[old_name_id] = new_hashed_id
             expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
@@ -390,7 +408,10 @@ def _convert_leaf_to_cte(
         sequence_id = sequence_id or df.sequence_id
         expression = df.expression.copy()
         cte_expression, cte_name = df._create_cte_from_expression(
-            expression=expression, branch_id=self.branch_id, sequence_id=sequence_id, name=name
+            expression=expression,
+            branch_id=self.branch_id,
+            sequence_id=sequence_id,
+            name=name,
         )
         new_expression = df._add_ctes_to_expression(
             exp.Select(), expression.ctes + [cte_expression]
@@ -494,12 +515,16 @@ def _add_ctes_to_expression(self, expression: exp.Select, ctes: t.List[exp.CTE]):
                 )
                 new_cte_alias = self._create_hash_from_expression(cte.this)
                 replaced_cte_names[cte.args["alias"].this] = maybe_parse(
-                    new_cte_alias, dialect=self.session.input_dialect, into=exp.Identifier
+                    new_cte_alias,
+                    dialect=self.session.input_dialect,
+                    into=exp.Identifier,
                 )
                 cte.set(
                     "alias",
                     maybe_parse(
-                        new_cte_alias, dialect=self.session.input_dialect, into=exp.TableAlias
+                        new_cte_alias,
+                        dialect=self.session.input_dialect,
+                        into=exp.TableAlias,
                     ),
                 )
                 existing_ctes.append(cte)
@@ -632,7 +657,10 @@ def _get_expressions(
                 exp.Literal.string(cache_storage_level),
             ]
             expression = exp.Cache(
-                this=cache_table, expression=select_expression, lazy=True, options=options
+                this=cache_table,
+                expression=select_expression,
+                lazy=True,
+                options=options,
             )
 
             # We will drop the "view" if it exists before running the cache table
@@ -783,7 +811,8 @@ def select(self, *cols, **kwargs) -> Self:
             for col in columns
         ]
         return self.copy(
-            expression=self.expression.select(*[x.expression for x in columns], **kwargs), **kwargs
+            expression=self.expression.select(*[x.expression for x in columns], **kwargs),
+            **kwargs,
         )
 
     @operation(Operation.NO_OP)
@@ -875,19 +904,33 @@ def join(
     ) -> Self:
         from sqlframe.base.functions import coalesce
 
-        if on is None:
+        if (on is None) and ("cross" not in how):
             logger.warning("Got no value for on. This appears to change the join to a cross join.")
             how = "cross"
+        if (on is not None) and ("cross" in how):
+            # Not a lot of doc, but Spark handles cross with predicate as an inner join
+            # https://learn.microsoft.com/en-us/dotnet/api/microsoft.spark.sql.dataframe.join
+            logger.warning("Got cross join with an 'on' value. This will result in an inner join.")
+            how = "inner"
 
         other_df = other_df._convert_leaf_to_cte()
         join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes)
+        # Normalizing join type:
+        join_type = JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")
         # We will determine actual "join on" expression later so we don't provide it at first
         join_expression = join_expression.join(
-            join_expression.ctes[-1].alias, join_type=how.replace("_", " ")
+            join_expression.ctes[-1].alias,
+            join_type=join_type,
         )
         self_columns = self._get_outer_select_columns(join_expression)
         other_columns = self._get_outer_select_columns(other_df.expression)
         join_columns = self._ensure_and_normalize_cols(on)
 
+        select_columns = (
+            self_columns
+            if join_type in ["left anti", "left semi"]
+            else self_columns + other_columns
+        )
         # If the two dataframes being joined come from the same branch, we then check if they have any columns that
         # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
         # the two columns since they would end up with the same table name. We do this by checking for the unique
@@ -965,7 +1008,7 @@ def join(
                 if not isinstance(column.expression.this, exp.Star)
                 else column.sql()
             )
-            for column in self_columns + other_columns
+            for column in select_columns
         ]
         select_column_names = [
            column_name
@@ -986,13 +1029,11 @@ def join(
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
-            select_column_names = [
-                column.alias_or_name for column in self_columns + other_columns
-            ]
+            select_column_names = [column.alias_or_name for column in select_columns]
 
        # Update the on expression with the actual join clause to replace the dummy one from before
        else:
-            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
+            select_column_names = [column.alias_or_name for column in select_columns]
            join_clause = None
        join_expression.args["joins"][-1].set("on", join_clause.expression if join_clause else None)
        new_df = self.copy(expression=join_expression)
@@ -1147,7 +1188,10 @@ def dropna(
 
     def _get_explain_plan_rows(self) -> t.List[Row]:
         sql_queries = self.sql(
-            pretty=False, optimize=False, as_list=True, dialect=self.session.execution_dialect
+            pretty=False,
+            optimize=False,
+            as_list=True,
+            dialect=self.session.execution_dialect,
         )
         if len(sql_queries) > 1:
             raise ValueError("Cannot explain a DataFrame with multiple queries")
@@ -1158,7 +1202,9 @@ def _get_explain_plan_rows(self) -> t.List[Row]:
         return results
 
     def explain(
-        self, extended: t.Optional[t.Union[bool, str]] = None, mode: t.Optional[str] = None
+        self,
+        extended: t.Optional[t.Union[bool, str]] = None,
+        mode: t.Optional[str] = None,
     ) -> None:
         """Prints the (logical and physical) plans to the console for debugging purposes.
@@ -1630,7 +1676,10 @@ def first(self) -> t.Optional[Row]:
         return self.head()
 
     def show(
-        self, n: int = 20, truncate: t.Optional[t.Union[bool, int]] = None, vertical: bool = False
+        self,
+        n: int = 20,
+        truncate: t.Optional[t.Union[bool, int]] = None,
+        vertical: bool = False,
     ):
         if vertical:
             raise NotImplementedError("Vertical show is not yet supported")
@@ -1648,7 +1697,10 @@
 
     def printSchema(self, level: t.Optional[int] = None) -> None:
         def print_schema(
-            column_name: str, column_type: exp.DataType, nullable: bool, current_level: int
+            column_name: str,
+            column_type: exp.DataType,
+            nullable: bool,
+            current_level: int,
         ):
             if level and current_level >= level:
                 return
@@ -1659,7 +1711,12 @@ def print_schema(
             )
             if column_type.this in (exp.DataType.Type.STRUCT, exp.DataType.Type.OBJECT):
                 for column_def in column_type.expressions:
-                    print_schema(column_def.name, column_def.args["kind"], True, current_level + 1)
+                    print_schema(
+                        column_def.name,
+                        column_def.args["kind"],
+                        True,
+                        current_level + 1,
+                    )
             if column_type.this == exp.DataType.Type.ARRAY:
                 for data_type in column_type.expressions:
                     print_schema("element", data_type, True, current_level + 1)
@@ -1691,7 +1748,8 @@ def createOrReplaceTempView(self, name: str) -> None:
         df = self.copy()._convert_leaf_to_cte()
         self.session.temp_views[name] = df
         self.session.catalog.add_table(
-            name, [x.alias_or_name for x in self._get_outer_select_columns(df.expression)]
+            name,
+            [x.alias_or_name for x in self._get_outer_select_columns(df.expression)],
         )
 
     def count(self) -> int:
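The heart of the fix is a two-step normalization: look the user-supplied how value up in JOIN_TYPE_MAPPING (falling back to the raw value for anything unrecognized, mirroring .get(how, how)), then replace underscores with spaces to produce the SQL join type. A minimal standalone sketch of that step, using the mapping from this diff — the normalize_join_type helper name is hypothetical and not part of the commit:

# Hypothetical helper isolating the normalization added in this commit.
JOIN_TYPE_MAPPING = {
    "inner": "inner",
    "cross": "cross",
    "outer": "full_outer",
    "full": "full_outer",
    "fullouter": "full_outer",
    "left": "left_outer",
    "leftouter": "left_outer",
    "right": "right_outer",
    "rightouter": "right_outer",
    "semi": "left_semi",
    "leftsemi": "left_semi",
    "left_semi": "left_semi",
    "anti": "left_anti",
    "leftanti": "left_anti",
    "left_anti": "left_anti",
}

def normalize_join_type(how: str) -> str:
    # Canonicalize the alias, then convert to the spaced SQL form,
    # e.g. "fullouter" -> "full_outer" -> "full outer".
    return JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")

assert normalize_join_type("fullouter") == "full outer"
assert normalize_join_type("leftsemi") == "left semi"
assert normalize_join_type("inner") == "inner"

The select_columns change builds on this normalized value: for "left semi" and "left anti" joins only the left-hand columns are projected, matching Spark's semantics where the right side contributes no output columns.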
tests/integration/test_int_dataframe.py — 37 changes: 37 additions & 0 deletions
@@ -435,6 +435,43 @@ def test_join_inner(
     compare_frames(df_joined, dfs_joined, sort=True)
 
 
+@pytest.mark.parametrize(
+    "how",
+    [
+        "inner",
+        "cross",
+        "outer",
+        "full",
+        "fullouter",
+        "full_outer",
+        "left",
+        "leftouter",
+        "left_outer",
+        "right",
+        "rightouter",
+        "right_outer",
+        "semi",
+        "leftsemi",
+        "left_semi",
+        "anti",
+        "leftanti",
+        "left_anti",
+    ],
+)
+def test_join_various_how(
+    pyspark_employee: PySparkDataFrame,
+    pyspark_store: PySparkDataFrame,
+    get_df: t.Callable[[str], BaseDataFrame],
+    compare_frames: t.Callable,
+    how: str,
+):
+    employee = get_df("employee")
+    store = get_df("store")
+    df_joined = pyspark_employee.join(pyspark_store, on=["store_id"], how=how)
+    dfs_joined = employee.join(store, on=["store_id"], how=how)
+    compare_frames(df_joined, dfs_joined, sort=True)
+
+
 def test_join_inner_no_select(
     pyspark_employee: PySparkDataFrame,
     pyspark_store: PySparkDataFrame,
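The parametrized test runs every alias through both PySpark and sqlframe and compares the results. The behavior of the two guard clauses at the top of join can also be illustrated in isolation — a sketch assuming a module-level logger; resolve_join is a hypothetical name, while the warning messages are copied from the diff:

import logging
import typing as t

logger = logging.getLogger(__name__)

def resolve_join(on: t.Optional[t.List[str]], how: str) -> str:
    # No join keys and not already a cross join: degrade to a cross join.
    if (on is None) and ("cross" not in how):
        logger.warning("Got no value for on. This appears to change the join to a cross join.")
        return "cross"
    # Join keys supplied with how="cross": Spark treats this as an inner join.
    if (on is not None) and ("cross" in how):
        logger.warning("Got cross join with an 'on' value. This will result in an inner join.")
        return "inner"
    return how

print(resolve_join(None, "inner"))          # "cross", with a warning
print(resolve_join(["store_id"], "cross"))  # "inner", with a warning
print(resolve_join(["store_id"], "semi"))   # "semi", unchanged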
