From d4ea91ee4632f7d144c8e72bcd1505ef7fa1da80 Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Sat, 12 Oct 2024 22:33:50 -0400 Subject: [PATCH 1/7] GH203 Split groupby with as_index --- pandas-stubs/core/frame.pyi | 21 ++++++++++++++++++--- pandas-stubs/core/groupby/groupby.pyi | 2 +- tests/test_frame.py | 20 +++++++++++++++++++- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 3f643f57..b61b03ff 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -25,7 +25,10 @@ from pandas import ( ) from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame -from pandas.core.groupby.generic import DataFrameGroupBy +from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, +) from pandas.core.groupby.grouper import Grouper from pandas.core.indexers import BaseIndexer from pandas.core.indexes.base import Index @@ -1052,18 +1055,30 @@ class DataFrame(NDFrame, OpsMixin): errors: IgnoreRaise = ..., ) -> None: ... @overload - def groupby( + def groupby( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] self, by: Scalar, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[False] = ..., sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., ) -> DataFrameGroupBy[Scalar]: ... @overload + def groupby( + self, + by: Scalar, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[True] = True, + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> SeriesGroupBy: ... + @overload def groupby( self, by: DatetimeIndex, diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index 75be9578..5e942306 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -236,7 +236,7 @@ class GroupBy(BaseGroupBy[NDFrameT]): @overload def size(self: GroupBy[Series]) -> Series[int]: ... @overload # return type depends on `as_index` for dataframe groupby - def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ... + def size(self: GroupBy[DataFrame]) -> DataFrame: ... @final def sum( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index 64198952..342ce50c 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1025,6 +1025,24 @@ def test_types_pivot_table() -> None: ) +def test_types_groupby_as_index() -> None: + df = pd.DataFrame({"a": [1, 2, 3]}) + check( + assert_type( + df.groupby("a", as_index=False).size(), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.groupby("a", as_index=True).size(), + "pd.Series[int]", + ), + pd.Series, + ) + + def test_types_groupby() -> None: df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]}) df.index.name = "ind" @@ -1048,7 +1066,7 @@ def test_types_groupby() -> None: df1: pd.DataFrame = df.groupby(by="col1").agg("sum") df2: pd.DataFrame = df.groupby(level="ind").aggregate("sum") - df3: pd.DataFrame = df.groupby(by="col1", sort=False, as_index=True).transform( + df3: pd.Series = df.groupby(by="col1", sort=False, as_index=True).transform( lambda x: x.max() ) df4: pd.DataFrame = df.groupby(by=["col1", "col2"]).count() From f2e6e3841b4db67e3929ee161e688c4dd008512e Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Tue, 15 Oct 2024 21:00:07 -0400 Subject: [PATCH 2/7] Update to the fix --- pandas-stubs/core/frame.pyi | 22 +++++++++++++++++----- pandas-stubs/core/groupby/groupby.pyi | 12 +++++++++--- tests/test_frame.py | 2 +- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index b61b03ff..91421e19 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -12,7 +12,10 @@ from re import Pattern from typing import ( Any, ClassVar, + Generic, Literal, + TypeVar, + Union, overload, ) @@ -77,6 +80,7 @@ from pandas._typing import ( Axis, AxisColumn, AxisIndex, + ByT, CalculationMethod, ColspaceArgType, CompressionOptions, @@ -232,6 +236,14 @@ class _LocIndexerFrame(_LocIndexer): value: Scalar | NAType | NaTType | ArrayLike | Series | list | None, ) -> None: ... +TT = TypeVar("TT", bound=Union[Literal[True], Literal[False]]) + +class DataFrameGroupByGen(DataFrameGroupBy[ByT], Generic[ByT, TT]): + pass + +class SeriesGroupByGen(SeriesGroupBy, Generic[TT, ByT]): + pass + class DataFrame(NDFrame, OpsMixin): __hash__: ClassVar[None] # type: ignore[assignment] @@ -1055,29 +1067,29 @@ class DataFrame(NDFrame, OpsMixin): errors: IgnoreRaise = ..., ) -> None: ... @overload - def groupby( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload self, by: Scalar, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: Literal[False] = ..., + as_index: Literal[True] = True, sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Scalar]: ... + ) -> DataFrameGroupByGen[Scalar, Literal[True]]: ... @overload def groupby( self, by: Scalar, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: Literal[True] = True, + as_index: Literal[False] = ..., sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> SeriesGroupBy: ... + ) -> DataFrameGroupByGen[Scalar, Literal[False]]: ... @overload def groupby( self, diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index 5e942306..beb03d6d 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -18,7 +18,10 @@ from typing import ( import numpy as np from pandas.core.base import SelectionMixin -from pandas.core.frame import DataFrame +from pandas.core.frame import ( + DataFrame, + DataFrameGroupByGen, +) from pandas.core.groupby import ( generic, ops, @@ -53,6 +56,7 @@ from pandas._typing import ( AnyArrayLike, Axis, AxisInt, + ByT, CalculationMethod, Dtype, Frequency, @@ -235,8 +239,10 @@ class GroupBy(BaseGroupBy[NDFrameT]): @final @overload def size(self: GroupBy[Series]) -> Series[int]: ... - @overload # return type depends on `as_index` for dataframe groupby - def size(self: GroupBy[DataFrame]) -> DataFrame: ... + @overload + def size(self: DataFrameGroupByGen[ByT, Literal[True]]) -> Series[int]: ... # type: ignore[misc] + @overload + def size(self: DataFrameGroupByGen[ByT, Literal[False]]) -> DataFrame: ... # type: ignore[misc] @final def sum( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index 342ce50c..dc5adad4 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1066,7 +1066,7 @@ def test_types_groupby() -> None: df1: pd.DataFrame = df.groupby(by="col1").agg("sum") df2: pd.DataFrame = df.groupby(level="ind").aggregate("sum") - df3: pd.Series = df.groupby(by="col1", sort=False, as_index=True).transform( + df3: pd.DataFrame = df.groupby(by="col1", sort=False, as_index=True).transform( lambda x: x.max() ) df4: pd.DataFrame = df.groupby(by=["col1", "col2"]).count() From d0d08a9d6e6b73638c1fc829bf863608c4029697 Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Tue, 15 Oct 2024 21:12:31 -0400 Subject: [PATCH 3/7] Update to the fix --- pandas-stubs/core/frame.pyi | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 91421e19..3d6e222f 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -15,7 +15,6 @@ from typing import ( Generic, Literal, TypeVar, - Union, overload, ) @@ -236,13 +235,10 @@ class _LocIndexerFrame(_LocIndexer): value: Scalar | NAType | NaTType | ArrayLike | Series | list | None, ) -> None: ... -TT = TypeVar("TT", bound=Union[Literal[True], Literal[False]]) +_TT = TypeVar("TT", bound=Literal[True, False]) -class DataFrameGroupByGen(DataFrameGroupBy[ByT], Generic[ByT, TT]): - pass - -class SeriesGroupByGen(SeriesGroupBy, Generic[TT, ByT]): - pass +class DataFrameGroupByGen(DataFrameGroupBy[ByT], Generic[ByT, _TT]): ... +class SeriesGroupByGen(SeriesGroupBy, Generic[_TT, ByT]): ... class DataFrame(NDFrame, OpsMixin): __hash__: ClassVar[None] # type: ignore[assignment] From ae0174009b694568c288f30be0458b3e225beb76 Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Wed, 16 Oct 2024 20:15:57 -0400 Subject: [PATCH 4/7] Experiment for size --- pandas-stubs/core/frame.pyi | 31 +++++++++------------------ pandas-stubs/core/groupby/generic.pyi | 11 ++++++++-- pandas-stubs/core/groupby/groupby.pyi | 12 +---------- 3 files changed, 20 insertions(+), 34 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 3d6e222f..e846bbbb 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -12,9 +12,7 @@ from re import Pattern from typing import ( Any, ClassVar, - Generic, Literal, - TypeVar, overload, ) @@ -27,10 +25,7 @@ from pandas import ( ) from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame -from pandas.core.groupby.generic import ( - DataFrameGroupBy, - SeriesGroupBy, -) +from pandas.core.groupby.generic import DataFrameGroupBy from pandas.core.groupby.grouper import Grouper from pandas.core.indexers import BaseIndexer from pandas.core.indexes.base import Index @@ -79,7 +74,6 @@ from pandas._typing import ( Axis, AxisColumn, AxisIndex, - ByT, CalculationMethod, ColspaceArgType, CompressionOptions, @@ -235,11 +229,6 @@ class _LocIndexerFrame(_LocIndexer): value: Scalar | NAType | NaTType | ArrayLike | Series | list | None, ) -> None: ... -_TT = TypeVar("TT", bound=Literal[True, False]) - -class DataFrameGroupByGen(DataFrameGroupBy[ByT], Generic[ByT, _TT]): ... -class SeriesGroupByGen(SeriesGroupBy, Generic[_TT, ByT]): ... - class DataFrame(NDFrame, OpsMixin): __hash__: ClassVar[None] # type: ignore[assignment] @@ -1073,7 +1062,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupByGen[Scalar, Literal[True]]: ... + ) -> DataFrameGroupBy[Scalar, Literal[True]]: ... @overload def groupby( self, @@ -1085,7 +1074,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupByGen[Scalar, Literal[False]]: ... + ) -> DataFrameGroupBy[Scalar, Literal[False]]: ... @overload def groupby( self, @@ -1097,7 +1086,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timestamp]: ... + ) -> DataFrameGroupBy[Timestamp, bool]: ... @overload def groupby( self, @@ -1109,7 +1098,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timedelta]: ... + ) -> DataFrameGroupBy[Timedelta, bool]: ... @overload def groupby( self, @@ -1121,7 +1110,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Period]: ... + ) -> DataFrameGroupBy[Period, bool]: ... @overload def groupby( self, @@ -1133,7 +1122,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[IntervalT]: ... + ) -> DataFrameGroupBy[IntervalT, bool]: ... @overload def groupby( self, @@ -1145,7 +1134,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[tuple]: ... + ) -> DataFrameGroupBy[tuple, bool]: ... @overload def groupby( self, @@ -1157,7 +1146,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[SeriesByT]: ... + ) -> DataFrameGroupBy[SeriesByT, bool]: ... @overload def groupby( self, @@ -1169,7 +1158,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Any]: ... + ) -> DataFrameGroupBy[Any, bool]: ... def pivot( self, *, diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 5ccc4179..56f8fdb8 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -11,6 +11,7 @@ from typing import ( Generic, Literal, NamedTuple, + TypeVar, final, overload, ) @@ -182,7 +183,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): self, ) -> Iterator[tuple[ByT, Series[S1]]]: ... -class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): +_TT = TypeVar("_TT", bound=Literal[True, False]) + +class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1 @overload # type: ignore[override] def apply( # type: ignore[overload-overlap] @@ -236,7 +239,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): @overload def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] self, key: Iterable[Hashable] | slice - ) -> DataFrameGroupBy[ByT]: ... + ) -> DataFrameGroupBy[ByT, bool]: ... def nunique(self, dropna: bool = ...) -> DataFrame: ... def idxmax( self, @@ -388,3 +391,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): def __iter__( # pyright: ignore[reportIncompatibleMethodOverride] self, ) -> Iterator[tuple[ByT, DataFrame]]: ... + @overload + def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ... + @overload + def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ... diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index beb03d6d..c638dc28 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -18,10 +18,7 @@ from typing import ( import numpy as np from pandas.core.base import SelectionMixin -from pandas.core.frame import ( - DataFrame, - DataFrameGroupByGen, -) +from pandas.core.frame import DataFrame from pandas.core.groupby import ( generic, ops, @@ -56,7 +53,6 @@ from pandas._typing import ( AnyArrayLike, Axis, AxisInt, - ByT, CalculationMethod, Dtype, Frequency, @@ -236,13 +232,7 @@ class GroupBy(BaseGroupBy[NDFrameT]): def sem( self: GroupBy[DataFrame], ddof: int = ..., numeric_only: bool = ... ) -> DataFrame: ... - @final - @overload def size(self: GroupBy[Series]) -> Series[int]: ... - @overload - def size(self: DataFrameGroupByGen[ByT, Literal[True]]) -> Series[int]: ... # type: ignore[misc] - @overload - def size(self: DataFrameGroupByGen[ByT, Literal[False]]) -> DataFrame: ... # type: ignore[misc] @final def sum( self, From 9563f04748ce44d4b180c6dcfb04ed0f9b7a3006 Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Wed, 16 Oct 2024 20:15:57 -0400 Subject: [PATCH 5/7] Experiment for size --- pandas-stubs/core/frame.pyi | 31 +++++++++------------------ pandas-stubs/core/groupby/generic.pyi | 11 ++++++++-- pandas-stubs/core/groupby/groupby.pyi | 12 +---------- tests/test_frame.py | 16 ++++++++++++++ 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 3d6e222f..e846bbbb 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -12,9 +12,7 @@ from re import Pattern from typing import ( Any, ClassVar, - Generic, Literal, - TypeVar, overload, ) @@ -27,10 +25,7 @@ from pandas import ( ) from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame -from pandas.core.groupby.generic import ( - DataFrameGroupBy, - SeriesGroupBy, -) +from pandas.core.groupby.generic import DataFrameGroupBy from pandas.core.groupby.grouper import Grouper from pandas.core.indexers import BaseIndexer from pandas.core.indexes.base import Index @@ -79,7 +74,6 @@ from pandas._typing import ( Axis, AxisColumn, AxisIndex, - ByT, CalculationMethod, ColspaceArgType, CompressionOptions, @@ -235,11 +229,6 @@ class _LocIndexerFrame(_LocIndexer): value: Scalar | NAType | NaTType | ArrayLike | Series | list | None, ) -> None: ... -_TT = TypeVar("TT", bound=Literal[True, False]) - -class DataFrameGroupByGen(DataFrameGroupBy[ByT], Generic[ByT, _TT]): ... -class SeriesGroupByGen(SeriesGroupBy, Generic[_TT, ByT]): ... - class DataFrame(NDFrame, OpsMixin): __hash__: ClassVar[None] # type: ignore[assignment] @@ -1073,7 +1062,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupByGen[Scalar, Literal[True]]: ... + ) -> DataFrameGroupBy[Scalar, Literal[True]]: ... @overload def groupby( self, @@ -1085,7 +1074,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupByGen[Scalar, Literal[False]]: ... + ) -> DataFrameGroupBy[Scalar, Literal[False]]: ... @overload def groupby( self, @@ -1097,7 +1086,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timestamp]: ... + ) -> DataFrameGroupBy[Timestamp, bool]: ... @overload def groupby( self, @@ -1109,7 +1098,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timedelta]: ... + ) -> DataFrameGroupBy[Timedelta, bool]: ... @overload def groupby( self, @@ -1121,7 +1110,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Period]: ... + ) -> DataFrameGroupBy[Period, bool]: ... @overload def groupby( self, @@ -1133,7 +1122,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[IntervalT]: ... + ) -> DataFrameGroupBy[IntervalT, bool]: ... @overload def groupby( self, @@ -1145,7 +1134,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[tuple]: ... + ) -> DataFrameGroupBy[tuple, bool]: ... @overload def groupby( self, @@ -1157,7 +1146,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[SeriesByT]: ... + ) -> DataFrameGroupBy[SeriesByT, bool]: ... @overload def groupby( self, @@ -1169,7 +1158,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Any]: ... + ) -> DataFrameGroupBy[Any, bool]: ... def pivot( self, *, diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 5ccc4179..56f8fdb8 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -11,6 +11,7 @@ from typing import ( Generic, Literal, NamedTuple, + TypeVar, final, overload, ) @@ -182,7 +183,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): self, ) -> Iterator[tuple[ByT, Series[S1]]]: ... -class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): +_TT = TypeVar("_TT", bound=Literal[True, False]) + +class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1 @overload # type: ignore[override] def apply( # type: ignore[overload-overlap] @@ -236,7 +239,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): @overload def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] self, key: Iterable[Hashable] | slice - ) -> DataFrameGroupBy[ByT]: ... + ) -> DataFrameGroupBy[ByT, bool]: ... def nunique(self, dropna: bool = ...) -> DataFrame: ... def idxmax( self, @@ -388,3 +391,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): def __iter__( # pyright: ignore[reportIncompatibleMethodOverride] self, ) -> Iterator[tuple[ByT, DataFrame]]: ... + @overload + def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ... + @overload + def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ... diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index beb03d6d..c638dc28 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -18,10 +18,7 @@ from typing import ( import numpy as np from pandas.core.base import SelectionMixin -from pandas.core.frame import ( - DataFrame, - DataFrameGroupByGen, -) +from pandas.core.frame import DataFrame from pandas.core.groupby import ( generic, ops, @@ -56,7 +53,6 @@ from pandas._typing import ( AnyArrayLike, Axis, AxisInt, - ByT, CalculationMethod, Dtype, Frequency, @@ -236,13 +232,7 @@ class GroupBy(BaseGroupBy[NDFrameT]): def sem( self: GroupBy[DataFrame], ddof: int = ..., numeric_only: bool = ... ) -> DataFrame: ... - @final - @overload def size(self: GroupBy[Series]) -> Series[int]: ... - @overload - def size(self: DataFrameGroupByGen[ByT, Literal[True]]) -> Series[int]: ... # type: ignore[misc] - @overload - def size(self: DataFrameGroupByGen[ByT, Literal[False]]) -> DataFrame: ... # type: ignore[misc] @final def sum( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index dc5adad4..c75ecd12 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1043,6 +1043,22 @@ def test_types_groupby_as_index() -> None: ) +def test_types_groupby_size() -> None: + """Test for GH886.""" + data = [ + {"date": "2023-12-01", "val": 12}, + {"date": "2023-12-02", "val": 2}, + {"date": "2023-12-03", "val": 1}, + {"date": "2023-12-03", "val": 10}, + ] + + df = pd.DataFrame(data) + groupby = df.groupby("date") + size = groupby.size() + frame = size.to_frame() + check(assert_type(frame.reset_index(), pd.DataFrame), pd.DataFrame) + + def test_types_groupby() -> None: df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]}) df.index.name = "ind" From 988372391aa2d55c9dffaf0b35e0b9357af78c11 Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Wed, 16 Oct 2024 20:52:19 -0400 Subject: [PATCH 6/7] GH203 Create new overload for DatetimeIndex --- pandas-stubs/core/frame.pyi | 16 ++++++++++++++-- pandas-stubs/core/groupby/generic.pyi | 5 +++++ tests/test_frame.py | 21 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index e846bbbb..94ff9c32 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1076,17 +1076,29 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Scalar, Literal[False]]: ... @overload + def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload + self, + by: DatetimeIndex, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[True] = True, + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Timestamp, Literal[True]]: ... + @overload def groupby( self, by: DatetimeIndex, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[False] = ..., sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timestamp, bool]: ... + ) -> DataFrameGroupBy[Timestamp, Literal[False]]: ... @overload def groupby( self, diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 56f8fdb8..2a77f060 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -30,6 +30,7 @@ from typing_extensions import ( ) from pandas._libs.lib import NoDefault +from pandas._libs.tslibs.timestamps import Timestamp from pandas._typing import ( S1, AggFuncTypeBase, @@ -395,3 +396,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ... @overload def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ... + @overload + def size(self: DataFrameGroupBy[Timestamp, Literal[True]]) -> Series[int]: ... + @overload + def size(self: DataFrameGroupBy[Timestamp, Literal[False]]) -> DataFrame: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index c75ecd12..f7f10d36 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1043,6 +1043,27 @@ def test_types_groupby_as_index() -> None: ) +def test_types_groupby_as_index_timestamp() -> None: + """Test groupby size with DatetimeIndex.""" + idx = pd.DatetimeIndex(["2023-10-01", "2023-10-02", "2023-10-01"], name="date") + sub_idx = pd.DatetimeIndex(["2023-10-01", "2023-10-02"], name="date") + df = pd.DataFrame({"a": [1, 2, 3]}, index=idx) + check( + assert_type( + df.groupby(sub_idx, as_index=False).size(), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.groupby(sub_idx, as_index=True).size(), + "pd.Series[int]", + ), + pd.Series, + ) + + def test_types_groupby_size() -> None: """Test for GH886.""" data = [ From b0cc407eae186861cd9905c60d21733243d35816 Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Thu, 31 Oct 2024 17:24:35 -0400 Subject: [PATCH 7/7] GH203 Fix lint --- pandas-stubs/core/frame.pyi | 2 +- tests/test_frame.py | 21 --------------------- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 02ad8713..62192478 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1090,7 +1090,7 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Timestamp, Literal[True]]: ... @overload - def groupby( + def groupby( # type: ignore[overload-overlap] self, by: DatetimeIndex, axis: AxisIndex | NoDefault = ..., diff --git a/tests/test_frame.py b/tests/test_frame.py index f7f10d36..c75ecd12 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1043,27 +1043,6 @@ def test_types_groupby_as_index() -> None: ) -def test_types_groupby_as_index_timestamp() -> None: - """Test groupby size with DatetimeIndex.""" - idx = pd.DatetimeIndex(["2023-10-01", "2023-10-02", "2023-10-01"], name="date") - sub_idx = pd.DatetimeIndex(["2023-10-01", "2023-10-02"], name="date") - df = pd.DataFrame({"a": [1, 2, 3]}, index=idx) - check( - assert_type( - df.groupby(sub_idx, as_index=False).size(), - pd.DataFrame, - ), - pd.DataFrame, - ) - check( - assert_type( - df.groupby(sub_idx, as_index=True).size(), - "pd.Series[int]", - ), - pd.Series, - ) - - def test_types_groupby_size() -> None: """Test for GH886.""" data = [