diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index ec12ba00..62192478 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1054,29 +1054,53 @@ class DataFrame(NDFrame, OpsMixin): errors: IgnoreRaise = ..., ) -> None: ... @overload - def groupby( + def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload self, by: Scalar, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[True] = True, sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Scalar]: ... + ) -> DataFrameGroupBy[Scalar, Literal[True]]: ... @overload def groupby( + self, + by: Scalar, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[False] = ..., + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Scalar, Literal[False]]: ... + @overload + def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload self, by: DatetimeIndex, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[True] = True, + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Timestamp, Literal[True]]: ... + @overload + def groupby( # type: ignore[overload-overlap] + self, + by: DatetimeIndex, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[False] = ..., sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timestamp]: ... + ) -> DataFrameGroupBy[Timestamp, Literal[False]]: ... @overload def groupby( self, @@ -1088,7 +1112,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timedelta]: ... + ) -> DataFrameGroupBy[Timedelta, bool]: ... @overload def groupby( self, @@ -1100,7 +1124,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Period]: ... + ) -> DataFrameGroupBy[Period, bool]: ... @overload def groupby( self, @@ -1112,7 +1136,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[IntervalT]: ... + ) -> DataFrameGroupBy[IntervalT, bool]: ... @overload def groupby( self, @@ -1124,7 +1148,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[tuple]: ... + ) -> DataFrameGroupBy[tuple, bool]: ... @overload def groupby( self, @@ -1136,7 +1160,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[SeriesByT]: ... + ) -> DataFrameGroupBy[SeriesByT, bool]: ... @overload def groupby( self, @@ -1148,7 +1172,7 @@ class DataFrame(NDFrame, OpsMixin): group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Any]: ... + ) -> DataFrameGroupBy[Any, bool]: ... def pivot( self, *, diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 3d43e75e..8f2b5588 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -11,6 +11,7 @@ from typing import ( Generic, Literal, NamedTuple, + TypeVar, final, overload, ) @@ -29,6 +30,7 @@ from typing_extensions import ( ) from pandas._libs.lib import NoDefault +from pandas._libs.tslibs.timestamps import Timestamp from pandas._typing import ( S1, AggFuncTypeBase, @@ -182,7 +184,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): self, ) -> Iterator[tuple[ByT, Series[S1]]]: ... -class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): +_TT = TypeVar("_TT", bound=Literal[True, False]) + +class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1 @overload # type: ignore[override] def apply( # type: ignore[overload-overlap] @@ -236,7 +240,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): @overload def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride, reportOverlappingOverload] self, key: Iterable[Hashable] | slice - ) -> DataFrameGroupBy[ByT]: ... + ) -> DataFrameGroupBy[ByT, bool]: ... def nunique(self, dropna: bool = ...) -> DataFrame: ... def idxmax( self, @@ -388,3 +392,11 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT]): def __iter__( # pyright: ignore[reportIncompatibleMethodOverride] self, ) -> Iterator[tuple[ByT, DataFrame]]: ... + @overload + def size(self: DataFrameGroupBy[ByT, Literal[True]]) -> Series[int]: ... + @overload + def size(self: DataFrameGroupBy[ByT, Literal[False]]) -> DataFrame: ... + @overload + def size(self: DataFrameGroupBy[Timestamp, Literal[True]]) -> Series[int]: ... + @overload + def size(self: DataFrameGroupBy[Timestamp, Literal[False]]) -> DataFrame: ... diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index a009e50c..66e6f825 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -232,11 +232,7 @@ class GroupBy(BaseGroupBy[NDFrameT]): def sem( self: GroupBy[DataFrame], ddof: int = ..., numeric_only: bool = ... ) -> DataFrame: ... - @final - @overload def size(self: GroupBy[Series]) -> Series[int]: ... - @overload # return type depends on `as_index` for dataframe groupby - def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ... @final def sum( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index 64198952..c75ecd12 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1025,6 +1025,40 @@ def test_types_pivot_table() -> None: ) +def test_types_groupby_as_index() -> None: + df = pd.DataFrame({"a": [1, 2, 3]}) + check( + assert_type( + df.groupby("a", as_index=False).size(), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.groupby("a", as_index=True).size(), + "pd.Series[int]", + ), + pd.Series, + ) + + +def test_types_groupby_size() -> None: + """Test for GH886.""" + data = [ + {"date": "2023-12-01", "val": 12}, + {"date": "2023-12-02", "val": 2}, + {"date": "2023-12-03", "val": 1}, + {"date": "2023-12-03", "val": 10}, + ] + + df = pd.DataFrame(data) + groupby = df.groupby("date") + size = groupby.size() + frame = size.to_frame() + check(assert_type(frame.reset_index(), pd.DataFrame), pd.DataFrame) + + def test_types_groupby() -> None: df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]}) df.index.name = "ind"