From 43bf944a0cb4768ab18c3a05c1d5bb793dde1b83 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 26 Jul 2024 15:15:03 +0400 Subject: [PATCH] fix(python): Raise suitable error when invalid column passed to `get_column_index` (#17868) --- py-polars/polars/dataframe/frame.py | 2 ++ py-polars/src/dataframe/general.rs | 7 +++++-- py-polars/tests/unit/dataframe/test_df.py | 11 ++++++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index cf23828c420b..20d003b9f79e 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4754,6 +4754,8 @@ def get_column_index(self, name: str) -> int: ... ) >>> df.get_column_index("ham") 2 + >>> df.get_column_index("sandwich") # doctest: +SKIP + ColumnNotFoundError: sandwich """ return self._df.get_column_index(name) diff --git a/py-polars/src/dataframe/general.rs b/py-polars/src/dataframe/general.rs index 9998e6283ae8..99507bf7c880 100644 --- a/py-polars/src/dataframe/general.rs +++ b/py-polars/src/dataframe/general.rs @@ -226,8 +226,11 @@ impl PyDataFrame { } } - pub fn get_column_index(&self, name: &str) -> Option { - self.df.get_column_index(name) + pub fn get_column_index(&self, name: &str) -> PyResult { + Ok(self + .df + .try_get_column_index(name) + .map_err(PyPolarsErr::from)?) } pub fn get_column(&self, name: &str) -> PyResult { diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 5d3d33905ad1..2ed5a8a39b32 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -2832,7 +2832,6 @@ def test_from_records_u64_12329() -> None: def test_negative_slice_12642() -> None: df = pl.DataFrame({"x": range(5)}) - assert_frame_equal(df.slice(-2, 1), df.tail(2).head(1)) @@ -2841,3 +2840,13 @@ def test_iter_columns() -> None: iter_columns = df.iter_columns() assert_series_equal(next(iter_columns), pl.Series("a", [1, 1, 2])) assert_series_equal(next(iter_columns), pl.Series("b", [4, 5, 6])) + + +def test_get_column_index() -> None: + df = pl.DataFrame({"actual": [1001], "expected": [1000]}) + + assert df.get_column_index("actual") == 0 + assert df.get_column_index("expected") == 1 + + with pytest.raises(ColumnNotFoundError, match="missing"): + df.get_column_index("missing")