From c081fef23d92b63af112d4bfd2e4b575d7ffda6e Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Thu, 9 May 2024 11:03:34 -0400 Subject: [PATCH] .to_flat() returns df with ArrowDtype Fixes #66 --- src/nested_pandas/series/accessor.py | 4 +++- tests/nested_pandas/nestedframe/test_nestedframe.py | 3 ++- tests/nested_pandas/series/test_accessor.py | 6 +++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/nested_pandas/series/accessor.py b/src/nested_pandas/series/accessor.py index 021fbd7..8a075a7 100644 --- a/src/nested_pandas/series/accessor.py +++ b/src/nested_pandas/series/accessor.py @@ -91,13 +91,15 @@ def to_flat(self, fields: list[str] | None = None) -> pd.DataFrame: index = None for field in fields: list_array = cast(pa.ListArray, struct_array.field(field)) + flat_array = list_array.flatten() if index is None: index = self.get_flat_index() flat_series[field] = pd.Series( - list_array.flatten(), + flat_array, index=pd.Series(index, name=self._series.index.name), name=field, copy=False, + dtype=pd.ArrowDtype(flat_array.type), ) return pd.DataFrame(flat_series) diff --git a/tests/nested_pandas/nestedframe/test_nestedframe.py b/tests/nested_pandas/nestedframe/test_nestedframe.py index cff96e4..4bfd710 100644 --- a/tests/nested_pandas/nestedframe/test_nestedframe.py +++ b/tests/nested_pandas/nestedframe/test_nestedframe.py @@ -76,7 +76,8 @@ def test_add_nested_with_flat_df(): base = base.add_nested(nested, "nested") assert "nested" in base.columns - assert base.nested.nest.to_flat().equals(nested) + # to_flat() gives pd.ArrowDtype, so we skip dtype check here + assert_frame_equal(base.nested.nest.to_flat(), nested, check_dtype=False) def test_add_nested_with_flat_df_and_mismatched_index(): diff --git a/tests/nested_pandas/series/test_accessor.py b/tests/nested_pandas/series/test_accessor.py index 807657b..9a02aa2 100644 --- a/tests/nested_pandas/series/test_accessor.py +++ b/tests/nested_pandas/series/test_accessor.py @@ -101,12 +101,14 @@ def test_to_flat(): index=[0, 0, 0, 1, 1, 1], name="a", copy=False, + dtype=pd.ArrowDtype(pa.float64()), ), "b": pd.Series( data=[-4.0, -5.0, -6.0, -3.0, -4.0, -5.0], index=[0, 0, 0, 1, 1, 1], name="b", copy=False, + dtype=pd.ArrowDtype(pa.float64()), ), }, index=pd.Index([0, 0, 0, 1, 1, 1], name="idx"), @@ -140,6 +142,7 @@ def test_to_flat_with_fields(): index=[0, 0, 0, 1, 1, 1], name="a", copy=False, + dtype=pd.ArrowDtype(pa.float64()), ), }, ) @@ -527,7 +530,7 @@ def test_to_flat_dropna(): """ flat = pd.DataFrame( - data={"c": [0.0, 2, 4, 1, np.NaN, 3, 1, 4, 1], "d": [5, 4, 7, 5, 3, 1, 9, 3, 4]}, + data={"c": [0, 2, 4, 1, np.NaN, 3, 1, 4, 1], "d": [5, 4, 7, 5, 3, 1, 9, 3, 4]}, index=[0, 0, 0, 1, 1, 1, 2, 2, 2], ) nested = pack_flat(flat, name="nested") @@ -542,4 +545,5 @@ def test_to_flat_dropna(): data={"c": [0.0, 2, 4, 1, 3, 1, 4, 1], "d": [5, 4, 7, 5, 1, 9, 3, 4]}, index=[0, 0, 0, 1, 1, 2, 2, 2], ), + check_dtype=False, # filtered's Series are pd.ArrowDtype )