From d2ad7272552f35be0439efeedf73668ff39af1cc Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sat, 28 Oct 2023 15:30:01 +0200 Subject: [PATCH] fix(python): Fix `get_index`/iteration for `Array` types (#12047) --- py-polars/polars/series/series.py | 2 +- py-polars/polars/testing/asserts/series.py | 8 -------- py-polars/src/series/mod.rs | 23 ++++++++++++---------- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 32db495ab0d1..2f8ba274d11d 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -953,7 +953,7 @@ def __contains__(self, item: Any) -> bool: return self.implode().list.contains(item).item() def __iter__(self) -> Generator[Any, None, None]: - if self.dtype == List: + if self.dtype in (List, Array): # TODO: either make a change and return py-native list data here, or find # a faster way to return nested/List series; sequential 'get_index' calls # make this path a lot slower (~10x) than it needs to be. diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index 25252a9ac3ea..11bbfe8512f3 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -151,14 +151,6 @@ def _assert_series_values_equal( if right.dtype == Categorical: right = right.cast(Utf8) - # Handle arrays - # TODO: Remove this check when equality for Arrays is implemented - # https://github.com/pola-rs/polars/issues/12012 - if left.dtype == Array: - left = left.cast(List(left.dtype.inner)) # type: ignore[union-attr] - if right.dtype == Array: - right = right.cast(List(right.dtype.inner)) # type: ignore[union-attr] - # Determine unequal elements try: unequal = left.ne_missing(right) diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index 3c6c7d497f96..6ef7724b7ead 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -165,17 +165,20 @@ impl PySeries { Err(e) => return Err(PyPolarsErr::from(e).into()), }; - if let AnyValue::List(s) = av { - let pyseries = PySeries::new(s); - let out = POLARS - .getattr(py, "wrap_s") - .unwrap() - .call1(py, (pyseries,)) - .unwrap(); - return Ok(out.into_py(py)); - } + let out = match av { + AnyValue::List(s) | AnyValue::Array(s, _) => { + let pyseries = PySeries::new(s); + let out = POLARS + .getattr(py, "wrap_s") + .unwrap() + .call1(py, (pyseries,)) + .unwrap(); + out.into_py(py) + }, + _ => Wrap(av).into_py(py), + }; - Ok(Wrap(av).into_py(py)) + Ok(out) } /// Get index but allow negative indices