Skip to content

Commit

Permalink
Add syntactic sugar to filter
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Oct 12, 2023
1 parent 468dd7d commit 1cb983f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
60 changes: 55 additions & 5 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3899,7 +3899,9 @@ def insert_at_idx(self, index: int, series: Series) -> Self:

def filter(
self,
predicate: (Expr | str | Series | list[bool] | np.ndarray[Any, Any] | bool),
*predicates: (Expr | str | Series | list[bool] | np.ndarray[Any, Any] | bool)
| None,
**constraints: dict[str, Any],
) -> DataFrame:
"""
Filter the rows in the DataFrame based on a predicate expression.
Expand All @@ -3908,8 +3910,10 @@ def filter(
Parameters
----------
predicate
predicates
Expression(s) that evaluate to a boolean Series; multiple predicates are combined with ``&``.
constraints
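Column filters; use ``name=value`` to keep only rows where column ``name`` equals ``value``. Each constraint behaves like ``pl.col(name) == value`` and is combined with the other filter conditions using ``&``.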
Examples
--------
Expand Down Expand Up @@ -3959,12 +3963,58 @@ def filter(
│ 3 ┆ 8 ┆ c │
└─────┴─────┴─────┘
Provide multiple filters using alternative syntax:
>>> df.filter(pl.col("foo") == 1, pl.col("ham") == "a")
shape: (1, 3)
┌─────┬─────┬─────┐
│ foo ┆ bar ┆ ham │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str │
╞═════╪═════╪═════╡
│ 1 ┆ 6 ┆ a │
└─────┴─────┴─────┘
Filter by column values using keyword arguments:
>>> df.filter(foo=1, ham="a")
shape: (1, 3)
┌─────┬─────┬─────┐
│ foo ┆ bar ┆ ham │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str │
╞═════╪═════╪═════╡
│ 1 ┆ 6 ┆ a │
└─────┴─────┴─────┘
"""
if _check_for_numpy(predicate) and isinstance(predicate, np.ndarray):
predicate = pl.Series(predicate)
has_predicates = len(predicates) > 0
has_constraints = len(constraints) > 0
if has_predicates:
predicates = [ # type: ignore[assignment]
pl.Series(predicate)
if _check_for_numpy(predicate) and isinstance(predicate, np.ndarray)
else predicate
for predicate in predicates
]
predicate_iter = iter(predicates)
predicates = next(predicate_iter) # type: ignore[arg-type]
for pred in predicate_iter:
predicates = predicates & pred # type: ignore[assignment, operator]
if has_constraints:
# basic column filters provided
constraint_iter = iter(constraints.items())
key, value = next(constraint_iter)
constraints = col(key) == value # type: ignore[assignment]
for key, value in constraint_iter:
constraints = constraints & (col(key) == value) # type: ignore[assignment]

if has_predicates and has_constraints:
predicates = predicates & constraints # type: ignore[operator]
elif has_constraints:
predicates = constraints # type: ignore[assignment]

return (
self.lazy().filter(predicate).collect(_eager=True) # type: ignore[arg-type]
self.lazy().filter(predicates).collect(_eager=True) # type: ignore[arg-type]
)

@overload
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2815,6 +2815,27 @@ def test_filter_sequence() -> None:
assert df.filter(np.array([True, False, True]))["a"].to_list() == [1, 3]


def test_filter_multiple_predicates() -> None:
    """Verify `DataFrame.filter` with several predicates, keyword constraints, or both."""
    frame = pl.DataFrame(
        {
            "a": [1, 1, 1, 2, 2],
            "b": [1, 1, 2, 2, 2],
            "c": [1, 1, 2, 3, 4],
        }
    )

    # Positional predicates only: rows where a == 1 and b <= 2.
    by_predicates = frame.filter(pl.col("a") == 1, pl.col("b") <= 2)
    assert_frame_equal(
        by_predicates,
        pl.DataFrame({"a": [1, 1, 1], "b": [1, 1, 2], "c": [1, 1, 2]}),
    )

    # Both remaining cases narrow the frame down to the same single row.
    single_row = {"a": [1], "b": [2], "c": [2]}

    # Keyword constraints only.
    by_constraints = frame.filter(a=1, b=2)  # type: ignore[arg-type]
    assert_frame_equal(by_constraints, pl.DataFrame(single_row))

    # Positional predicates together with keyword constraints.
    combined = frame.filter(pl.col("a") == 1, pl.col("b") <= 2, a=1, b=2)  # type: ignore[arg-type]
    assert_frame_equal(combined, pl.DataFrame(single_row))


def test_indexing_set() -> None:
df = pl.DataFrame({"bool": [True, True], "str": ["N/A", "N/A"], "nr": [1, 2]})

Expand Down

0 comments on commit 1cb983f

Please sign in to comment.