Skip to content

Commit

Permalink
feat(python): Make the available concat alignment strategies more generic
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jan 9, 2025
1 parent 323b88c commit 167e18d
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 43 deletions.
4 changes: 4 additions & 0 deletions py-polars/polars/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"diagonal_relaxed",
"horizontal",
"align",
"align_full",
"align_inner",
"align_left",
"align_right",
]
CorrelationMethod: TypeAlias = Literal["pearson", "spearman"]
DbReadEngine: TypeAlias = Literal["adbc", "connectorx"]
Expand Down
102 changes: 70 additions & 32 deletions py-polars/polars/functions/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def concat(
----------
items
DataFrames, LazyFrames, or Series to concatenate.
how : {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align'}
Series only support the `vertical` strategy.
how : {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align', 'align_full', 'align_inner', 'align_left', 'align_right'}
Note that `Series` only support the `vertical` strategy.
* vertical: Applies multiple `vstack` operations.
* vertical_relaxed: Same as `vertical`, but additionally coerces columns to
Expand All @@ -49,10 +49,14 @@ def concat(
their common supertype *if* they are mismatched (eg: Int32 → Int64).
* horizontal: Stacks Series from DataFrames horizontally and fills with `null`
if the lengths don't match.
* align: Combines frames horizontally, auto-determining the common key columns
and aligning rows using the same logic as `align_frames`; this behaviour is
patterned after a full outer join, but does not handle column-name collision.
(If you need more control, you should use a suitable join method instead).
* align, align_full, align_inner, align_left, align_right: Combines frames
horizontally, auto-determining the common key columns and aligning rows using
the same logic as `align_frames` (note that "align" is an alias for
"align_full"). The suffix of the "align" strategy name determines the type of
join used to align the frames, equivalent to the "how" parameter on
`align_frames`. Note that the common join columns are automatically coalesced,
but other column collisions will raise an error (if you need more control over
this you should use a suitable `join` method directly).
rechunk
Make sure that the result data is in contiguous memory.
parallel
Expand Down Expand Up @@ -100,6 +104,9 @@ def concat(
│ 2 ┆ 4 ┆ 6 ┆ 8 ┆ 10 │
└─────┴─────┴─────┴─────┴─────┘
The "diagonal" strategy allows for some frames to have missing columns,
the values for which are filled with `null`:
>>> df_d1 = pl.DataFrame({"a": [1], "b": [3]})
>>> df_d2 = pl.DataFrame({"a": [2], "c": [4]})
>>> pl.concat([df_d1, df_d2], how="diagonal")
Expand All @@ -113,10 +120,12 @@ def concat(
│ 2 ┆ null ┆ 4 │
└─────┴──────┴──────┘
The "align" strategies require at least one common column to align on:
>>> df_a1 = pl.DataFrame({"id": [1, 2], "x": [3, 4]})
>>> df_a2 = pl.DataFrame({"id": [2, 3], "y": [5, 6]})
>>> df_a3 = pl.DataFrame({"id": [1, 3], "z": [7, 8]})
>>> pl.concat([df_a1, df_a2, df_a3], how="align")
>>> pl.concat([df_a1, df_a2, df_a3], how="align") # equivalent to "align_full"
shape: (3, 4)
┌─────┬──────┬──────┬──────┐
│ id ┆ x ┆ y ┆ z │
Expand All @@ -127,6 +136,34 @@ def concat(
│ 2 ┆ 4 ┆ 5 ┆ null │
│ 3 ┆ null ┆ 6 ┆ 8 │
└─────┴──────┴──────┴──────┘
>>> pl.concat([df_a1, df_a2, df_a3], how="align_left")
shape: (2, 4)
┌─────┬─────┬──────┬──────┐
│ id ┆ x ┆ y ┆ z │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪══════╪══════╡
│ 1 ┆ 3 ┆ null ┆ 7 │
│ 2 ┆ 4 ┆ 5 ┆ null │
└─────┴─────┴──────┴──────┘
>>> pl.concat([df_a1, df_a2, df_a3], how="align_right")
shape: (2, 4)
┌─────┬──────┬──────┬─────┐
│ id ┆ x ┆ y ┆ z │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════╪══════╪══════╪═════╡
│ 1 ┆ null ┆ null ┆ 7 │
│ 3 ┆ null ┆ 6 ┆ 8 │
└─────┴──────┴──────┴─────┘
>>> pl.concat([df_a1, df_a2, df_a3], how="align_inner")
shape: (0, 4)
┌─────┬─────┬─────┬─────┐
│ id ┆ x ┆ y ┆ z │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╪═════╡
└─────┴─────┴─────┴─────┘
""" # noqa: W505
# unpack/standardise (handles generator input)
elems = list(items)
Expand All @@ -139,47 +176,48 @@ def concat(
):
return elems[0]

if how == "align":
if how.startswith("align"):
if not isinstance(elems[0], (pl.DataFrame, pl.LazyFrame)):
msg = f"'align' strategy is not supported for {type(elems[0]).__name__!r}"
msg = f"{how!r} strategy is not supported for {type(elems[0]).__name__!r}"
raise TypeError(msg)

# establish common columns, maintaining the order in which they appear
all_columns = list(chain.from_iterable(e.collect_schema() for e in elems))
key = {v: k for k, v in enumerate(ordered_unique(all_columns))}
output_column_order = list(key)
common_cols = sorted(
reduce(
lambda x, y: set(x) & set(y), # type: ignore[arg-type, return-value]
chain(e.collect_schema() for e in elems),
),
key=lambda k: key.get(k, 0),
)
# we require at least one key column for 'align'
# we require at least one key column for 'align' strategies
if not common_cols:
msg = "'align' strategy requires at least one common column"
msg = f"{how!r} strategy requires at least one common column"
raise InvalidOperationError(msg)

# align the frame data using a full outer join with no suffix-resolution
# (so we raise an error in case of column collision, like "horizontal")
lf: LazyFrame = reduce(
lambda x, y: (
x.join(
y,
how="full",
on=common_cols,
suffix="_PL_CONCAT_RIGHT",
maintain_order="right_left",
)
# Coalesce full outer join columns
.with_columns(
F.coalesce([name, f"{name}_PL_CONCAT_RIGHT"])
for name in common_cols
)
.drop([f"{name}_PL_CONCAT_RIGHT" for name in common_cols])
),
[df.lazy() for df in elems],
).sort(by=common_cols)

# align frame data using a join, with no suffix-resolution (will raise
# a DuplicateError in case of column collision, same as "horizontal")
join_method: JoinStrategy = (
"full" if how == "align" else how.removeprefix("align_") # type: ignore[assignment]
)
lf: LazyFrame = (
reduce(
lambda x, y: (
x.join(
y,
on=common_cols,
how=join_method,
maintain_order="right_left",
coalesce=True,
)
),
[df.lazy() for df in elems],
)
.sort(by=common_cols)
.select(*output_column_order)
)
eager = isinstance(elems[0], pl.DataFrame)
return lf.collect() if eager else lf # type: ignore[return-value]

Expand Down
55 changes: 44 additions & 11 deletions py-polars/tests/unit/functions/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,64 @@ def test_concat_align() -> None:
b = pl.DataFrame({"a": ["a", "b", "c"], "c": [5.5, 6.0, 7.5]})
c = pl.DataFrame({"a": ["a", "b", "c", "d", "e"], "d": ["w", "x", "y", "z", None]})

result = pl.concat([a, b, c], how="align")
for align_full in ("align", "align_full"):
result = pl.concat([a, b, c], how=align_full)
expected = pl.DataFrame(
{
"a": ["a", "b", "c", "d", "e", "e"],
"b": [1, 2, None, 4, 5, 6],
"c": [5.5, 6.0, 7.5, None, None, None],
"d": ["w", "x", "y", "z", None, None],
}
)
assert_frame_equal(result, expected)

result = pl.concat([a, b, c], how="align_left")
expected = pl.DataFrame(
{
"a": ["a", "b", "c", "d", "e", "e"],
"b": [1, 2, None, 4, 5, 6],
"c": [5.5, 6.0, 7.5, None, None, None],
"d": ["w", "x", "y", "z", None, None],
"a": ["a", "b", "d", "e", "e"],
"b": [1, 2, 4, 5, 6],
"c": [5.5, 6.0, None, None, None],
"d": ["w", "x", "z", None, None],
}
)
assert_frame_equal(result, expected)

result = pl.concat([a, b, c], how="align_right")
expected = pl.DataFrame(
{
"a": ["a", "b", "c", "d", "e"],
"b": [1, 2, None, None, None],
"c": [5.5, 6.0, 7.5, None, None],
"d": ["w", "x", "y", "z", None],
}
)
assert_frame_equal(result, expected)

def test_concat_align_no_common_cols() -> None:
result = pl.concat([a, b, c], how="align_inner")
expected = pl.DataFrame(
{
"a": ["a", "b"],
"b": [1, 2],
"c": [5.5, 6.0],
"d": ["w", "x"],
}
)
assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"strategy", ["align", "align_full", "align_left", "align_right"]
)
def test_concat_align_no_common_cols(strategy: ConcatMethod) -> None:
df1 = pl.DataFrame({"a": [1, 2], "b": [1, 2]})
df2 = pl.DataFrame({"c": [3, 4], "d": [3, 4]})

with pytest.raises(
InvalidOperationError,
match="'align' strategy requires at least one common column",
match=f"{strategy!r} strategy requires at least one common column",
):
pl.concat((df1, df2), how="align")


data2 = pl.DataFrame({"field3": [3, 4], "field4": ["C", "D"]})
pl.concat((df1, df2), how=strategy)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 167e18d

Please sign in to comment.