From efd3ac3fc66acabf3da0c4b2e76cbbfca906806d Mon Sep 17 00:00:00 2001 From: barak1412 Date: Tue, 17 Sep 2024 08:37:15 +0300 Subject: [PATCH] support negative index --- .../src/chunked_array/struct_/mod.rs | 9 +++++---- .../unit/operations/namespaces/test_struct.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/crates/polars-core/src/chunked_array/struct_/mod.rs b/crates/polars-core/src/chunked_array/struct_/mod.rs index 2bce1cdd8f88..6a5fb1883d5f 100644 --- a/crates/polars-core/src/chunked_array/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/struct_/mod.rs @@ -14,7 +14,7 @@ use crate::chunked_array::ChunkedArray; use crate::prelude::sort::arg_sort_multiple::{_get_rows_encoded_arr, _get_rows_encoded_ca}; use crate::prelude::*; use crate::series::Series; -use crate::utils::Container; +use crate::utils::{slice_offsets, Container}; pub type StructChunked = ChunkedArray; @@ -373,12 +373,13 @@ impl StructChunked { .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name)) } pub fn field_by_index(&self, field_idx: i64) -> PolarsResult { - if field_idx < 0 || field_idx >= (self.struct_fields().len() as i64) { + if field_idx >= (self.struct_fields().len() as i64) { polars_bail!( - OutOfBounds: "Given field index {} is Out Of Bound", field_idx + OutOfBounds: "Given field index {} is out of bound", field_idx ) } - Ok(self.field_as_series(field_idx as usize)) + let (field_idx, _) = slice_offsets(field_idx, 0, self.struct_fields().len()); + Ok(self.field_as_series(field_idx)) } pub(crate) fn set_outer_validity(&mut self, validity: Option) { assert_eq!(self.chunks().len(), 1); diff --git a/py-polars/tests/unit/operations/namespaces/test_struct.py b/py-polars/tests/unit/operations/namespaces/test_struct.py index 8551472d6220..626834ed55ea 100644 --- a/py-polars/tests/unit/operations/namespaces/test_struct.py +++ b/py-polars/tests/unit/operations/namespaces/test_struct.py @@ -103,17 +103,17 @@ def test_empty_list_eval_schema_5734() -> None: def test_field_by_index_18732() -> None: - df = pl.DataFrame({"foo": [{"a": 1}, {"a": 2}]}) + df = pl.DataFrame({"foo": [{"a": 1, "b": 2}, {"a": 2, "b": 1}]}) - # lower bound - with pytest.raises(OutOfBoundsError, match=r"Given field index -1 is Out Of Bound"): - df.filter(pl.col.foo.struct[-1] == 1) - - # upper bound - with pytest.raises(OutOfBoundsError, match=r"Given field index 1 is Out Of Bound"): - df.filter(pl.col.foo.struct[1] == 1) + # illegal upper bound + with pytest.raises(OutOfBoundsError, match=r"Given field index 2 is out of bound"): + df.filter(pl.col.foo.struct[2] == 1) # legal - expected_df = pl.DataFrame({"foo": [{"a": 1}]}) + expected_df = pl.DataFrame({"foo": [{"a": 1, "b": 2}]}) result_df = df.filter(pl.col.foo.struct[0] == 1) assert_frame_equal(expected_df, result_df) + + expected_df = pl.DataFrame({"foo": [{"a": 2, "b": 1}]}) + result_df = df.filter(pl.col.foo.struct[-1] == 1) + assert_frame_equal(expected_df, result_df)