Skip to content

Commit

Permalink
update for polars row to names
Browse files Browse the repository at this point in the history
  • Loading branch information
samuel.oranyeli committed Jul 7, 2024
1 parent 83bb7c0 commit 84f9d25
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 90 deletions.
2 changes: 1 addition & 1 deletion janitor/functions/row_to_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _row_to_names_dispatch( # noqa: F811
len_df = len(df_)
arrays = [arr._values for _, arr in df_.items()]
if remove_rows_above and remove_rows:
indexer = np.arange(row_numbers.stop + 1, len_df)
indexer = np.arange(row_numbers.stop, len_df)
elif remove_rows_above:
indexer = np.arange(row_numbers.start, len_df)
elif remove_rows:
Expand Down
145 changes: 76 additions & 69 deletions janitor/polars/row_to_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from functools import singledispatch

from janitor.utils import check, import_message

from .polars_flavor import register_dataframe_method
Expand All @@ -28,8 +30,6 @@ def row_to_names(
"""
Elevates a row, or rows, to be the column names of a DataFrame.
For a LazyFrame, the user should materialize into a DataFrame before using `row_to_names`..
Examples:
Replace column names with the first row.
Expand Down Expand Up @@ -103,8 +103,7 @@ def row_to_names(
Args:
row_numbers: Position of the row(s) containing the variable names.
Note that indexing starts from 0. It can also be a list/slice.
Defaults to 0 (first row).
It can be an integer, list or a slice.
remove_rows: Whether the row(s) should be removed from the DataFrame.
remove_rows_above: Whether the row(s) above the selected row should
be removed from the DataFrame.
Expand All @@ -115,85 +114,93 @@ def row_to_names(
A polars DataFrame.
""" # noqa: E501
return _row_to_names(
row_numbers,
df=df,
row_numbers=row_numbers,
remove_rows=remove_rows,
remove_rows_above=remove_rows_above,
separator=separator,
)


@singledispatch
def _row_to_names(
df: pl.DataFrame,
row_numbers: int | list | slice,
remove_rows: bool,
remove_rows_above: bool,
separator: str,
row_numbers, df, remove_rows, remove_rows_above, separator
) -> pl.DataFrame:
"""
Function to convert rows in the DataFrame to column names.
Base function for row_to_names.
"""
check("separator", separator, [str])
if isinstance(row_numbers, int):
row_numbers = slice(row_numbers, row_numbers + 1)
elif isinstance(row_numbers, slice):
if row_numbers.step is not None:
raise ValueError(
"The step argument for slice is not supported in row_to_names."
)
elif isinstance(row_numbers, list):
for entry in row_numbers:
check("entry in the row_numbers argument", entry, [int])
else:
raise TypeError(
"row_numbers should be either an integer, "
"a slice or a list; "
f"instead got type {type(row_numbers).__name__}"
raise TypeError(
"row_numbers should be either an integer, "
"a slice or a list; "
f"instead got type {type(row_numbers).__name__}"
)


@_row_to_names.register(int) # noqa: F811
def _row_to_names_dispatch( # noqa: F811
row_numbers, df, remove_rows, remove_rows_above, separator
):
expression = pl.col("*").cast(pl.String).gather(row_numbers)
expression = pl.struct(expression)
headers = df.select(expression).to_series(0).to_list()[0]
df = df.rename(mapping=headers)
if remove_rows_above and remove_rows:
return df.slice(row_numbers + 1)
elif remove_rows_above:
return df.slice(row_numbers)
elif remove_rows:
expression = pl.int_range(pl.len()).ne(row_numbers)
return df.filter(expression)
return df


@_row_to_names.register(slice) # noqa: F811
def _row_to_names_dispatch( # noqa: F811
row_numbers, df, remove_rows, remove_rows_above, separator
):
if row_numbers.step is not None:
raise ValueError(
"The step argument for slice is not supported in row_to_names."
)
is_a_slice = isinstance(row_numbers, slice)
if is_a_slice:
expression = pl.all().str.concat(delimiter=separator)
expression = pl.struct(expression)
offset = row_numbers.start
length = row_numbers.stop - row_numbers.start
mapping = df.slice(
offset=offset,
length=length,
headers = df.slice(row_numbers.start, row_numbers.stop - row_numbers.start)
headers = headers.cast(pl.String)
expression = pl.all().str.concat(delimiter=separator)
expression = pl.struct(expression)
headers = headers.select(expression).to_series(0).to_list()[0]
df = df.rename(mapping=headers)
if remove_rows_above and remove_rows:
return df.slice(row_numbers.stop)
elif remove_rows_above:
return df.slice(row_numbers.start)
elif remove_rows:
expression = pl.int_range(pl.len()).is_between(
row_numbers.start, row_numbers.stop, closed="left"
)
mapping = mapping.select(expression)
else:
expression = pl.all().gather(row_numbers)
expression = expression.str.concat(delimiter=separator)
expression = pl.struct(expression)
mapping = df.select(expression)

mapping = mapping.to_series(0)[0]
df = df.rename(mapping=mapping)
if remove_rows_above:
if not is_a_slice:
raise ValueError(
"The remove_rows_above argument is applicable "
"only if the row_numbers argument is an integer "
"or a slice."
)
if remove_rows:
return df.slice(offset=row_numbers.stop)
return df.slice(offset=row_numbers.start)
return df.filter(~expression)
return df

if remove_rows:
if is_a_slice:
df = [
df.slice(offset=0, length=row_numbers.start),
df.slice(offset=row_numbers.stop),
]
return pl.concat(df, rechunk=True)
name = "".join(df.columns)
name = f"{name}_"
df = (
df.with_row_index(name=name)
.filter(pl.col(name=name).is_in(row_numbers).not_())
.select(pl.exclude(name))

@_row_to_names.register(list) # noqa: F811
def _row_to_names_dispatch( # noqa: F811
row_numbers, df, remove_rows, remove_rows_above, separator
):
if remove_rows_above:
raise ValueError(
"The remove_rows_above argument is applicable "
"only if the row_numbers argument is an integer "
"or a slice."
)
return df

for entry in row_numbers:
check("entry in the row_numbers argument", entry, [int])

expression = pl.col("*").gather(row_numbers)
headers = df.select(expression).cast(pl.String)
expression = pl.all().str.concat(delimiter=separator)
expression = pl.struct(expression)
headers = headers.select(expression).to_series(0).to_list()[0]
df = df.rename(mapping=headers)
if remove_rows:
expression = pl.int_range(pl.len()).is_in(row_numbers)
return df.filter(~expression)
return df
16 changes: 16 additions & 0 deletions tests/functions/test_row_to_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,22 @@ def test_row_to_names_delete_above_slice(dataframe):
assert df.iloc[0, 4] == "Basel"


@pytest.mark.functions
def test_row_to_names_delete_above_delete_rows(dataframe):
"""
Test output for remove_rows=True
and remove_rows_above=True
"""
df = dataframe.row_to_names(
slice(2, 4), remove_rows=True, remove_rows_above=True
)
assert df.iloc[0, 0] == 2
assert df.iloc[0, 1] == 2.456234
assert df.iloc[0, 2] == 2
assert df.iloc[0, 3] == "leopard"
assert df.iloc[0, 4] == "Shanghai"


@pytest.mark.functions
def test_row_to_names_delete_above_is_a_list(dataframe):
"Raise if row_numbers is a list"
Expand Down
20 changes: 0 additions & 20 deletions tests/polars/functions/test_row_to_names_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@ def df():
)


def test_separator_type(df):
"""
Raise if separator is not a string
"""
with pytest.raises(TypeError, match="separator should be.+"):
df.row_to_names([1, 2], separator=1)


def test_row_numbers_type(df):
"""
Raise if row_numbers is not an int/slice/list
Expand Down Expand Up @@ -88,8 +80,6 @@ def test_row_to_names_list(df):

def test_row_to_names_delete_this_row(df):
df = df.row_to_names(2, remove_rows=True)
if isinstance(df, pl.LazyFrame):
df = df.collect()
assert df.to_series(0)[0] == 1.234_523_45
assert df.to_series(1)[0] == 1
assert df.to_series(2)[0] == "rabbit"
Expand All @@ -98,8 +88,6 @@ def test_row_to_names_delete_this_row(df):

def test_row_to_names_list_delete_this_row(df):
df = df.row_to_names([2], remove_rows=True)
if isinstance(df, pl.LazyFrame):
df = df.collect()
assert df.to_series(0)[0] == 1.234_523_45
assert df.to_series(1)[0] == 1
assert df.to_series(2)[0] == "rabbit"
Expand All @@ -108,8 +96,6 @@ def test_row_to_names_list_delete_this_row(df):

def test_row_to_names_delete_above(df):
df = df.row_to_names(2, remove_rows_above=True)
if isinstance(df, pl.LazyFrame):
df = df.collect()
assert df.to_series(0)[0] == 3.234_612_5
assert df.to_series(1)[0] == 3
assert df.to_series(2)[0] == "lion"
Expand All @@ -119,8 +105,6 @@ def test_row_to_names_delete_above(df):
def test_row_to_names_delete_above_list(df):
"Test output if row_numbers is a list"
df = df.row_to_names(slice(2, 4), remove_rows_above=True)
if isinstance(df, pl.LazyFrame):
df = df.collect()
assert df.to_series(0)[0] == 3.234_612_5
assert df.to_series(1)[0] == 3
assert df.to_series(2)[0] == "lion"
Expand All @@ -133,8 +117,6 @@ def test_row_to_names_delete_above_delete_rows(df):
and remove_rows_above=True
"""
df = df.row_to_names(slice(2, 4), remove_rows=True, remove_rows_above=True)
if isinstance(df, pl.LazyFrame):
df = df.collect()
assert df.to_series(0)[0] == 2.456234
assert df.to_series(1)[0] == 2
assert df.to_series(2)[0] == "leopard"
Expand All @@ -147,8 +129,6 @@ def test_row_to_names_delete_above_delete_rows_scalar(df):
and remove_rows_above=True
"""
df = df.row_to_names(2, remove_rows=True, remove_rows_above=True)
if isinstance(df, pl.LazyFrame):
df = df.collect()
assert df.to_series(0)[0] == 1.23452345
assert df.to_series(1)[0] == 1
assert df.to_series(2)[0] == "rabbit"
Expand Down

0 comments on commit 84f9d25

Please sign in to comment.