From 84f9d25d31c3cb65e0a24145f7c3d99c5187189d Mon Sep 17 00:00:00 2001 From: "samuel.oranyeli" Date: Sun, 7 Jul 2024 16:04:05 +1000 Subject: [PATCH] update for polars row to names --- janitor/functions/row_to_names.py | 2 +- janitor/polars/row_to_names.py | 145 +++++++++--------- tests/functions/test_row_to_names.py | 16 ++ .../functions/test_row_to_names_polars.py | 20 --- 4 files changed, 93 insertions(+), 90 deletions(-) diff --git a/janitor/functions/row_to_names.py b/janitor/functions/row_to_names.py index 1ddd95383..ee97d8531 100644 --- a/janitor/functions/row_to_names.py +++ b/janitor/functions/row_to_names.py @@ -165,7 +165,7 @@ def _row_to_names_dispatch( # noqa: F811 len_df = len(df_) arrays = [arr._values for _, arr in df_.items()] if remove_rows_above and remove_rows: - indexer = np.arange(row_numbers.stop + 1, len_df) + indexer = np.arange(row_numbers.stop, len_df) elif remove_rows_above: indexer = np.arange(row_numbers.start, len_df) elif remove_rows: diff --git a/janitor/polars/row_to_names.py b/janitor/polars/row_to_names.py index b507018a2..a0b16b62d 100644 --- a/janitor/polars/row_to_names.py +++ b/janitor/polars/row_to_names.py @@ -2,6 +2,8 @@ from __future__ import annotations +from functools import singledispatch + from janitor.utils import check, import_message from .polars_flavor import register_dataframe_method @@ -28,8 +30,6 @@ def row_to_names( """ Elevates a row, or rows, to be the column names of a DataFrame. - For a LazyFrame, the user should materialize into a DataFrame before using `row_to_names`.. - Examples: Replace column names with the first row. @@ -103,8 +103,7 @@ def row_to_names( Args: row_numbers: Position of the row(s) containing the variable names. - Note that indexing starts from 0. It can also be a list/slice. - Defaults to 0 (first row). + It can be an integer, list or a slice. remove_rows: Whether the row(s) should be removed from the DataFrame. remove_rows_above: Whether the row(s) above the selected row should be removed from the DataFrame. @@ -115,85 +114,93 @@ def row_to_names( A polars DataFrame. """ # noqa: E501 return _row_to_names( + row_numbers, df=df, - row_numbers=row_numbers, remove_rows=remove_rows, remove_rows_above=remove_rows_above, separator=separator, ) +@singledispatch def _row_to_names( - df: pl.DataFrame, - row_numbers: int | list | slice, - remove_rows: bool, - remove_rows_above: bool, - separator: str, + row_numbers, df, remove_rows, remove_rows_above, separator ) -> pl.DataFrame: """ - Function to convert rows in the DataFrame to column names. + Base function for row_to_names. """ - check("separator", separator, [str]) - if isinstance(row_numbers, int): - row_numbers = slice(row_numbers, row_numbers + 1) - elif isinstance(row_numbers, slice): - if row_numbers.step is not None: - raise ValueError( - "The step argument for slice is not supported in row_to_names." - ) - elif isinstance(row_numbers, list): - for entry in row_numbers: - check("entry in the row_numbers argument", entry, [int]) - else: - raise TypeError( - "row_numbers should be either an integer, " - "a slice or a list; " - f"instead got type {type(row_numbers).__name__}" + raise TypeError( + "row_numbers should be either an integer, " + "a slice or a list; " + f"instead got type {type(row_numbers).__name__}" + ) + + +@_row_to_names.register(int) # noqa: F811 +def _row_to_names_dispatch( # noqa: F811 + row_numbers, df, remove_rows, remove_rows_above, separator +): + expression = pl.col("*").cast(pl.String).gather(row_numbers) + expression = pl.struct(expression) + headers = df.select(expression).to_series(0).to_list()[0] + df = df.rename(mapping=headers) + if remove_rows_above and remove_rows: + return df.slice(row_numbers + 1) + elif remove_rows_above: + return df.slice(row_numbers) + elif remove_rows: + expression = pl.int_range(pl.len()).ne(row_numbers) + return df.filter(expression) + return df + + +@_row_to_names.register(slice) # noqa: F811 +def _row_to_names_dispatch( # noqa: F811 + row_numbers, df, remove_rows, remove_rows_above, separator +): + if row_numbers.step is not None: + raise ValueError( + "The step argument for slice is not supported in row_to_names." ) - is_a_slice = isinstance(row_numbers, slice) - if is_a_slice: - expression = pl.all().str.concat(delimiter=separator) - expression = pl.struct(expression) - offset = row_numbers.start - length = row_numbers.stop - row_numbers.start - mapping = df.slice( - offset=offset, - length=length, + headers = df.slice(row_numbers.start, row_numbers.stop - row_numbers.start) + headers = headers.cast(pl.String) + expression = pl.all().str.concat(delimiter=separator) + expression = pl.struct(expression) + headers = headers.select(expression).to_series(0).to_list()[0] + df = df.rename(mapping=headers) + if remove_rows_above and remove_rows: + return df.slice(row_numbers.stop) + elif remove_rows_above: + return df.slice(row_numbers.start) + elif remove_rows: + expression = pl.int_range(pl.len()).is_between( + row_numbers.start, row_numbers.stop, closed="left" ) - mapping = mapping.select(expression) - else: - expression = pl.all().gather(row_numbers) - expression = expression.str.concat(delimiter=separator) - expression = pl.struct(expression) - mapping = df.select(expression) - - mapping = mapping.to_series(0)[0] - df = df.rename(mapping=mapping) - if remove_rows_above: - if not is_a_slice: - raise ValueError( - "The remove_rows_above argument is applicable " - "only if the row_numbers argument is an integer " - "or a slice." - ) - if remove_rows: - return df.slice(offset=row_numbers.stop) - return df.slice(offset=row_numbers.start) + return df.filter(~expression) + return df - if remove_rows: - if is_a_slice: - df = [ - df.slice(offset=0, length=row_numbers.start), - df.slice(offset=row_numbers.stop), - ] - return pl.concat(df, rechunk=True) - name = "".join(df.columns) - name = f"{name}_" - df = ( - df.with_row_index(name=name) - .filter(pl.col(name=name).is_in(row_numbers).not_()) - .select(pl.exclude(name)) + +@_row_to_names.register(list) # noqa: F811 +def _row_to_names_dispatch( # noqa: F811 + row_numbers, df, remove_rows, remove_rows_above, separator +): + if remove_rows_above: + raise ValueError( + "The remove_rows_above argument is applicable " + "only if the row_numbers argument is an integer " + "or a slice." ) - return df + for entry in row_numbers: + check("entry in the row_numbers argument", entry, [int]) + + expression = pl.col("*").gather(row_numbers) + headers = df.select(expression).cast(pl.String) + expression = pl.all().str.concat(delimiter=separator) + expression = pl.struct(expression) + headers = headers.select(expression).to_series(0).to_list()[0] + df = df.rename(mapping=headers) + if remove_rows: + expression = pl.int_range(pl.len()).is_in(row_numbers) + return df.filter(~expression) return df diff --git a/tests/functions/test_row_to_names.py b/tests/functions/test_row_to_names.py index 5295b44d9..758afe44d 100644 --- a/tests/functions/test_row_to_names.py +++ b/tests/functions/test_row_to_names.py @@ -74,6 +74,22 @@ def test_row_to_names_delete_above_slice(dataframe): assert df.iloc[0, 4] == "Basel" +@pytest.mark.functions +def test_row_to_names_delete_above_delete_rows(dataframe): + """ + Test output for remove_rows=True + and remove_rows_above=True + """ + df = dataframe.row_to_names( + slice(2, 4), remove_rows=True, remove_rows_above=True + ) + assert df.iloc[0, 0] == 2 + assert df.iloc[0, 1] == 2.456234 + assert df.iloc[0, 2] == 2 + assert df.iloc[0, 3] == "leopard" + assert df.iloc[0, 4] == "Shanghai" + + @pytest.mark.functions def test_row_to_names_delete_above_is_a_list(dataframe): "Raise if row_numbers is a list" diff --git a/tests/polars/functions/test_row_to_names_polars.py b/tests/polars/functions/test_row_to_names_polars.py index 372e2e09a..1c81660e0 100644 --- a/tests/polars/functions/test_row_to_names_polars.py +++ b/tests/polars/functions/test_row_to_names_polars.py @@ -17,14 +17,6 @@ def df(): ) -def test_separator_type(df): - """ - Raise if separator is not a string - """ - with pytest.raises(TypeError, match="separator should be.+"): - df.row_to_names([1, 2], separator=1) - - def test_row_numbers_type(df): """ Raise if row_numbers is not an int/slice/list @@ -88,8 +80,6 @@ def test_row_to_names_list(df): def test_row_to_names_delete_this_row(df): df = df.row_to_names(2, remove_rows=True) - if isinstance(df, pl.LazyFrame): - df = df.collect() assert df.to_series(0)[0] == 1.234_523_45 assert df.to_series(1)[0] == 1 assert df.to_series(2)[0] == "rabbit" @@ -98,8 +88,6 @@ def test_row_to_names_delete_this_row(df): def test_row_to_names_list_delete_this_row(df): df = df.row_to_names([2], remove_rows=True) - if isinstance(df, pl.LazyFrame): - df = df.collect() assert df.to_series(0)[0] == 1.234_523_45 assert df.to_series(1)[0] == 1 assert df.to_series(2)[0] == "rabbit" @@ -108,8 +96,6 @@ def test_row_to_names_list_delete_this_row(df): def test_row_to_names_delete_above(df): df = df.row_to_names(2, remove_rows_above=True) - if isinstance(df, pl.LazyFrame): - df = df.collect() assert df.to_series(0)[0] == 3.234_612_5 assert df.to_series(1)[0] == 3 assert df.to_series(2)[0] == "lion" @@ -119,8 +105,6 @@ def test_row_to_names_delete_above(df): def test_row_to_names_delete_above_list(df): "Test output if row_numbers is a list" df = df.row_to_names(slice(2, 4), remove_rows_above=True) - if isinstance(df, pl.LazyFrame): - df = df.collect() assert df.to_series(0)[0] == 3.234_612_5 assert df.to_series(1)[0] == 3 assert df.to_series(2)[0] == "lion" @@ -133,8 +117,6 @@ def test_row_to_names_delete_above_delete_rows(df): and remove_rows_above=True """ df = df.row_to_names(slice(2, 4), remove_rows=True, remove_rows_above=True) - if isinstance(df, pl.LazyFrame): - df = df.collect() assert df.to_series(0)[0] == 2.456234 assert df.to_series(1)[0] == 2 assert df.to_series(2)[0] == "leopard" @@ -147,8 +129,6 @@ def test_row_to_names_delete_above_delete_rows_scalar(df): and remove_rows_above=True """ df = df.row_to_names(2, remove_rows=True, remove_rows_above=True) - if isinstance(df, pl.LazyFrame): - df = df.collect() assert df.to_series(0)[0] == 1.23452345 assert df.to_series(1)[0] == 1 assert df.to_series(2)[0] == "rabbit"