Skip to content

Commit

Permalink
fix(python): Fix get_index/iteration for Array types (#12047)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Oct 28, 2023
1 parent 3b6aa8f commit d2ad727
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def __contains__(self, item: Any) -> bool:
return self.implode().list.contains(item).item()

def __iter__(self) -> Generator[Any, None, None]:
if self.dtype == List:
if self.dtype in (List, Array):
# TODO: either make a change and return py-native list data here, or find
# a faster way to return nested/List series; sequential 'get_index' calls
# make this path a lot slower (~10x) than it needs to be.
Expand Down
8 changes: 0 additions & 8 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,6 @@ def _assert_series_values_equal(
if right.dtype == Categorical:
right = right.cast(Utf8)

# Handle arrays
# TODO: Remove this check when equality for Arrays is implemented
# https://github.com/pola-rs/polars/issues/12012
if left.dtype == Array:
left = left.cast(List(left.dtype.inner)) # type: ignore[union-attr]
if right.dtype == Array:
right = right.cast(List(right.dtype.inner)) # type: ignore[union-attr]

# Determine unequal elements
try:
unequal = left.ne_missing(right)
Expand Down
23 changes: 13 additions & 10 deletions py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,17 +165,20 @@ impl PySeries {
Err(e) => return Err(PyPolarsErr::from(e).into()),
};

if let AnyValue::List(s) = av {
let pyseries = PySeries::new(s);
let out = POLARS
.getattr(py, "wrap_s")
.unwrap()
.call1(py, (pyseries,))
.unwrap();
return Ok(out.into_py(py));
}
let out = match av {
AnyValue::List(s) | AnyValue::Array(s, _) => {
let pyseries = PySeries::new(s);
let out = POLARS
.getattr(py, "wrap_s")
.unwrap()
.call1(py, (pyseries,))
.unwrap();
out.into_py(py)
},
_ => Wrap(av).into_py(py),
};

Ok(Wrap(av).into_py(py))
Ok(out)
}

/// Get index but allow negative indices
Expand Down

0 comments on commit d2ad727

Please sign in to comment.