diff --git a/pandas-stubs/core/reshape/pivot.pyi b/pandas-stubs/core/reshape/pivot.pyi index ab327bc8a..4d71977c6 100644 --- a/pandas-stubs/core/reshape/pivot.pyi +++ b/pandas-stubs/core/reshape/pivot.pyi @@ -1,18 +1,32 @@ from typing import ( Callable, + Hashable, Literal, Sequence, + TypeVar, + Union, + overload, ) +import numpy as np +import pandas as pd from pandas.core.frame import DataFrame from pandas.core.groupby.grouper import Grouper from pandas.core.series import Series +from typing_extensions import TypeAlias from pandas._typing import ( + AnyArrayLike, + ArrayLike, + HashableT, IndexLabel, Scalar, ) +_ExtendedAnyArrayLike: TypeAlias = Union[AnyArrayLike, ArrayLike] + +_HashableT2 = TypeVar("_HashableT2", bound=Hashable) + def pivot_table( data: DataFrame, values: str | None = ..., @@ -32,13 +46,28 @@ def pivot( columns: IndexLabel = ..., values: IndexLabel = ..., ) -> DataFrame: ... +@overload +def crosstab( + index: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike], + columns: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike], + values: list | _ExtendedAnyArrayLike, + rownames: list[HashableT] | None = ..., + colnames: list[_HashableT2] | None = ..., + *, + aggfunc: str | np.ufunc | Callable[[Series], float], + margins: bool = ..., + margins_name: str = ..., + dropna: bool = ..., + normalize: bool | Literal[0, 1, "all", "index", "columns"] = ..., +) -> DataFrame: ... +@overload def crosstab( - index: Sequence | Series, - columns: Sequence | Series, - values: Sequence | None = ..., - rownames: Sequence | None = ..., - colnames: Sequence | None = ..., - aggfunc: Callable | None = ..., + index: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike], + columns: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike], + values: None = ..., + rownames: list[HashableT] | None = ..., + colnames: list[_HashableT2] | None = ..., + aggfunc: None = ..., margins: bool = ..., margins_name: str = ..., dropna: bool = ..., diff --git a/tests/test_interval.py b/tests/test_interval.py index 5c23f3e00..38a0dbf12 100644 --- a/tests/test_interval.py +++ b/tests/test_interval.py @@ -33,7 +33,6 @@ def test_max_intervals() -> None: i2 = pd.Interval( pd.Timestamp("2000-01-01T12:00:00"), pd.Timestamp("2000-01-02"), closed="both" ) - print(max(i1.left, i2.left)) def test_interval_length() -> None: diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 742da67cb..07d2ac665 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -1103,3 +1103,120 @@ def test_merge_asof() -> None: ), pd.DataFrame, ) + + +def test_crosstab_args() -> None: + a = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2] + b: list = [4, 5, 6, 3, 4, 3, 5, 6, 5, 5] + c = [1, 3, 2, 3, 1, 2, 3, 1, 3, 2] + check(assert_type(pd.crosstab(a, b), pd.DataFrame), pd.DataFrame) + check(assert_type(pd.crosstab(a, [b, c]), pd.DataFrame), pd.DataFrame) + check( + assert_type(pd.crosstab(np.array(a), np.array(b)), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(pd.crosstab(np.array(a), [np.array(b), np.array(c)]), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(pd.crosstab(pd.Series(a), pd.Series(b)), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(pd.crosstab(pd.Index(a), pd.Index(b)), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(pd.crosstab(pd.Categorical(a), pd.Categorical(b)), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + pd.crosstab(pd.Series(a), [pd.Series(b), pd.Series(c)]), pd.DataFrame + ), + pd.DataFrame, + ) + values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + check( + assert_type(pd.crosstab(a, b, values=values, aggfunc=np.sum), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + pd.crosstab(a, b, values=pd.Index(values), aggfunc=np.sum), pd.DataFrame + ), + pd.DataFrame, + ) + with pytest.warns(FutureWarning): + check( + assert_type( + pd.crosstab(a, b, values=pd.Categorical(values), aggfunc=np.sum), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + pd.crosstab(a, b, values=np.array(values), aggfunc=np.sum), pd.DataFrame + ), + pd.DataFrame, + ) + + check( + assert_type( + pd.crosstab(a, b, values=pd.Series(values), aggfunc=np.sum), pd.DataFrame + ), + pd.DataFrame, + ) + check( + assert_type(pd.crosstab(a, b, values=values, aggfunc=np.mean), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(pd.crosstab(a, b, values=values, aggfunc="mean"), pd.DataFrame), + pd.DataFrame, + ) + + def m(x: pd.Series) -> float: + return x.sum() / len(x) + + check( + assert_type(pd.crosstab(a, b, values=values, aggfunc=m), pd.DataFrame), + pd.DataFrame, + ) + + def m2(x: pd.Series) -> int: + return int(x.sum()) + + check( + assert_type(pd.crosstab(a, b, values=values, aggfunc=m2), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + pd.crosstab(a, b, margins=True, margins_name="something"), pd.DataFrame + ), + pd.DataFrame, + ) + check( + assert_type(pd.crosstab(a, b, margins=True, dropna=True), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type(pd.crosstab(a, b, colnames=["a"], rownames=["b"]), pd.DataFrame), + pd.DataFrame, + ) + rownames: list[tuple] = [("b", 1)] + colnames: list[tuple] = [("a",)] + check( + assert_type( + pd.crosstab(a, b, colnames=colnames, rownames=rownames), + pd.DataFrame, + ), + pd.DataFrame, + ) + check(assert_type(pd.crosstab(a, b, normalize=0), pd.DataFrame), pd.DataFrame) + check(assert_type(pd.crosstab(a, b, normalize=1), pd.DataFrame), pd.DataFrame) + check(assert_type(pd.crosstab(a, b, normalize="all"), pd.DataFrame), pd.DataFrame) + check(assert_type(pd.crosstab(a, b, normalize="index"), pd.DataFrame), pd.DataFrame) + check( + assert_type(pd.crosstab(a, b, normalize="columns"), pd.DataFrame), pd.DataFrame + )