
Commit

fix: handle all join types
Thomzoy committed Jan 22, 2025
1 parent a10315b commit e7e527a
Showing 2 changed files with 71 additions and 10 deletions.
44 changes: 34 additions & 10 deletions sqlframe/base/dataframe.py
@@ -80,6 +80,23 @@
     "SHUFFLE_REPLICATE_NL",
 }
 
+JOIN_TYPE_MAPPING = {
+    "inner": "inner",
+    "cross": "cross",
+    "outer": "full_outer",
+    "full": "full_outer",
+    "fullouter": "full_outer",
+    "left": "left_outer",
+    "leftouter": "left_outer",
+    "right": "right_outer",
+    "rightouter": "right_outer",
+    "semi": "left_semi",
+    "leftsemi": "left_semi",
+    "left_semi": "left_semi",
+    "anti": "left_anti",
+    "leftanti": "left_anti",
+    "left_anti": "left_anti",
+}
 
 DF = t.TypeVar("DF", bound="BaseDataFrame")
 
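The mapping collapses every alias PySpark accepts for the "how" argument onto one canonical join type, so the rest of join() only has to reason about the canonical names. A small standalone illustration (not part of the commit; normalize_join_type is a hypothetical helper, whereas the commit performs the same lookup inline inside join), assuming the JOIN_TYPE_MAPPING dict above is in scope:

def normalize_join_type(how: str) -> str:
    # Values not in the mapping (e.g. "cross") fall through unchanged via .get(how, how);
    # underscores then become spaces to match SQL join syntax ("left semi", "full outer").
    return JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")

assert normalize_join_type("leftsemi") == "left semi"
assert normalize_join_type("outer") == "full outer"
assert normalize_join_type("cross") == "cross"
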
@@ -948,24 +965,33 @@ def join(
     ) -> Self:
         from sqlframe.base.functions import coalesce
 
-        if on is None:
+        if (on is None) and ("cross" not in how):
             logger.warning("Got no value for on. This appears to change the join to a cross join.")
             how = "cross"
+        if (on is not None) and ("cross" in how):
+            # Not a lot of doc, but Spark handles cross with a predicate as an inner join
+            # https://learn.microsoft.com/en-us/dotnet/api/microsoft.spark.sql.dataframe.join
+            logger.warning("Got cross join with an 'on' value. This will result in an inner join.")
+            how = "inner"
 
         other_df = other_df._convert_leaf_to_cte()
         join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes)
         # We will determine actual "join on" expression later so we don't provide it at first
-        join_expression = join_expression.join(
-            join_expression.ctes[-1].alias, join_type=how.replace("_", " ")
-        )
+        join_type = JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")
+        join_expression = join_expression.join(join_expression.ctes[-1].alias, join_type=join_type)
         self_columns = self._get_outer_select_columns(join_expression)
         other_columns = self._get_outer_select_columns(other_df.expression)
         join_columns = self._ensure_and_normalize_cols(on)
         self._handle_self_join(other_df, join_columns)
 
         # Determines the join clause and select columns to be used based on what type of columns were provided for
         # the join. The columns returned change based on how the on expression is provided.
-        if how != "cross":
+        if join_type != "cross":
+            select_columns = (
+                self_columns
+                if join_type in ["left_anti", "left_semi"]
+                else self_columns + other_columns
+            )
             if isinstance(join_columns[0].expression, exp.Column):
                 """
                 Unique characteristics of join on column names only:
@@ -996,7 +1022,7 @@ def join(
                         if not isinstance(column.expression.this, exp.Star)
                         else column.sql()
                     )
-                    for column in self_columns + other_columns
+                    for column in select_columns
                 ]
                 select_column_names = [
                     column_name
@@ -1014,13 +1040,11 @@
                 * The left join dataframe columns go first and right come after. No sort preference is given to join columns
                 """
                 join_clause = self._normalize_join_clause(join_columns, join_expression)
-                select_column_names = [
-                    column.alias_or_name for column in self_columns + other_columns
-                ]
+                select_column_names = [column.alias_or_name for column in select_columns]
 
         # Update the on expression with the actual join clause to replace the dummy one from before
         else:
-            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
+            select_column_names = [column.alias_or_name for column in select_columns]
             join_clause = None
         join_expression.args["joins"][-1].set("on", join_clause.expression if join_clause else None)
         new_df = self.copy(expression=join_expression)
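Taken together, the join() changes above make three behaviors explicit: left semi and left anti joins only project the left DataFrame's columns, a cross join given a predicate is downgraded to an inner join with a warning, and a non-cross join given no predicate is upgraded to a cross join, also with a warning. A hedged usage sketch, where employee and store stand in for any two sqlframe DataFrames sharing a store_id column (neither is defined here):

# Left semi / left anti joins keep only the left side's columns, matching Spark.
semi = employee.join(store, on=["store_id"], how="leftsemi")

# A predicate combined with how="cross" logs a warning and runs as an inner join.
demoted = employee.join(store, on=["store_id"], how="cross")

# No predicate with a non-cross how logs a warning and runs as a cross join.
promoted = employee.join(store, how="left")
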
37 changes: 37 additions & 0 deletions tests/integration/test_int_dataframe.py
@@ -435,6 +435,43 @@ def test_join_inner(
     compare_frames(df_joined, dfs_joined, sort=True)
 
 
+@pytest.mark.parametrize(
+    "how",
+    [
+        "inner",
+        "cross",
+        "outer",
+        "full",
+        "fullouter",
+        "full_outer",
+        "left",
+        "leftouter",
+        "left_outer",
+        "right",
+        "rightouter",
+        "right_outer",
+        "semi",
+        "leftsemi",
+        "left_semi",
+        "anti",
+        "leftanti",
+        "left_anti",
+    ],
+)
+def test_join_various_how(
+    pyspark_employee: PySparkDataFrame,
+    pyspark_store: PySparkDataFrame,
+    get_df: t.Callable[[str], BaseDataFrame],
+    compare_frames: t.Callable,
+    how: str,
+):
+    employee = get_df("employee")
+    store = get_df("store")
+    df_joined = pyspark_employee.join(pyspark_store, on=["store_id"], how=how)
+    dfs_joined = employee.join(store, on=["store_id"], how=how)
+    compare_frames(df_joined, dfs_joined, sort=True)
+
+
 def test_join_inner_no_select(
     pyspark_employee: PySparkDataFrame,
     pyspark_store: PySparkDataFrame,
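Assuming pytest and the project's integration-test fixtures are available, the new cases can be run on their own, either from the shell (pytest tests/integration/test_int_dataframe.py -k test_join_various_how) or programmatically:

import pytest

# Select only the new parametrized join test; an exit code of 0 means every alias passed.
exit_code = pytest.main(["tests/integration/test_int_dataframe.py", "-k", "test_join_various_how"])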

