Skip to content

Commit ba50426

Browse files
bashtage and Kevin Sheppard authored
ENH: Improve typing of crosstab (#376)
* ENH: Improve typing of crosstab
* TYP: Add second HashableT
* ENH: Add overload to increase specificity
* ENH: Verify Index and Categorical for crosstab
* MAINT: Catch warning
* CLN: Remove print statements
* TYP: Expand supported types to include ExtensionArray

Co-authored-by: Kevin Sheppard <[email protected]>
1 parent 69a6b85 commit ba50426

File tree

3 files changed

+152
-7
lines changed

3 files changed

+152
-7
lines changed

pandas-stubs/core/reshape/pivot.pyi

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
11
from typing import (
22
Callable,
3+
Hashable,
34
Literal,
45
Sequence,
6+
TypeVar,
7+
Union,
8+
overload,
59
)
610

11+
import numpy as np
12+
import pandas as pd
713
from pandas.core.frame import DataFrame
814
from pandas.core.groupby.grouper import Grouper
915
from pandas.core.series import Series
16+
from typing_extensions import TypeAlias
1017

1118
from pandas._typing import (
19+
AnyArrayLike,
20+
ArrayLike,
21+
HashableT,
1222
IndexLabel,
1323
Scalar,
1424
)
1525

26+
_ExtendedAnyArrayLike: TypeAlias = Union[AnyArrayLike, ArrayLike]
27+
28+
_HashableT2 = TypeVar("_HashableT2", bound=Hashable)
29+
1630
def pivot_table(
1731
data: DataFrame,
1832
values: str | None = ...,
@@ -32,13 +46,28 @@ def pivot(
3246
columns: IndexLabel = ...,
3347
values: IndexLabel = ...,
3448
) -> DataFrame: ...
49+
@overload
50+
def crosstab(
51+
index: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
52+
columns: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
53+
values: list | _ExtendedAnyArrayLike,
54+
rownames: list[HashableT] | None = ...,
55+
colnames: list[_HashableT2] | None = ...,
56+
*,
57+
aggfunc: str | np.ufunc | Callable[[Series], float],
58+
margins: bool = ...,
59+
margins_name: str = ...,
60+
dropna: bool = ...,
61+
normalize: bool | Literal[0, 1, "all", "index", "columns"] = ...,
62+
) -> DataFrame: ...
63+
@overload
3564
def crosstab(
36-
index: Sequence | Series,
37-
columns: Sequence | Series,
38-
values: Sequence | None = ...,
39-
rownames: Sequence | None = ...,
40-
colnames: Sequence | None = ...,
41-
aggfunc: Callable | None = ...,
65+
index: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
66+
columns: list | _ExtendedAnyArrayLike | list[Sequence | _ExtendedAnyArrayLike],
67+
values: None = ...,
68+
rownames: list[HashableT] | None = ...,
69+
colnames: list[_HashableT2] | None = ...,
70+
aggfunc: None = ...,
4271
margins: bool = ...,
4372
margins_name: str = ...,
4473
dropna: bool = ...,

tests/test_interval.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def test_max_intervals() -> None:
3333
i2 = pd.Interval(
3434
pd.Timestamp("2000-01-01T12:00:00"), pd.Timestamp("2000-01-02"), closed="both"
3535
)
36-
print(max(i1.left, i2.left))
3736

3837

3938
def test_interval_length() -> None:

tests/test_pandas.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,3 +1103,120 @@ def test_merge_asof() -> None:
11031103
),
11041104
pd.DataFrame,
11051105
)
1106+
1107+
1108+
def test_crosstab_args() -> None:
1109+
a = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2]
1110+
b: list = [4, 5, 6, 3, 4, 3, 5, 6, 5, 5]
1111+
c = [1, 3, 2, 3, 1, 2, 3, 1, 3, 2]
1112+
check(assert_type(pd.crosstab(a, b), pd.DataFrame), pd.DataFrame)
1113+
check(assert_type(pd.crosstab(a, [b, c]), pd.DataFrame), pd.DataFrame)
1114+
check(
1115+
assert_type(pd.crosstab(np.array(a), np.array(b)), pd.DataFrame), pd.DataFrame
1116+
)
1117+
check(
1118+
assert_type(pd.crosstab(np.array(a), [np.array(b), np.array(c)]), pd.DataFrame),
1119+
pd.DataFrame,
1120+
)
1121+
check(
1122+
assert_type(pd.crosstab(pd.Series(a), pd.Series(b)), pd.DataFrame), pd.DataFrame
1123+
)
1124+
check(
1125+
assert_type(pd.crosstab(pd.Index(a), pd.Index(b)), pd.DataFrame), pd.DataFrame
1126+
)
1127+
check(
1128+
assert_type(pd.crosstab(pd.Categorical(a), pd.Categorical(b)), pd.DataFrame),
1129+
pd.DataFrame,
1130+
)
1131+
check(
1132+
assert_type(
1133+
pd.crosstab(pd.Series(a), [pd.Series(b), pd.Series(c)]), pd.DataFrame
1134+
),
1135+
pd.DataFrame,
1136+
)
1137+
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
1138+
check(
1139+
assert_type(pd.crosstab(a, b, values=values, aggfunc=np.sum), pd.DataFrame),
1140+
pd.DataFrame,
1141+
)
1142+
check(
1143+
assert_type(
1144+
pd.crosstab(a, b, values=pd.Index(values), aggfunc=np.sum), pd.DataFrame
1145+
),
1146+
pd.DataFrame,
1147+
)
1148+
with pytest.warns(FutureWarning):
1149+
check(
1150+
assert_type(
1151+
pd.crosstab(a, b, values=pd.Categorical(values), aggfunc=np.sum),
1152+
pd.DataFrame,
1153+
),
1154+
pd.DataFrame,
1155+
)
1156+
check(
1157+
assert_type(
1158+
pd.crosstab(a, b, values=np.array(values), aggfunc=np.sum), pd.DataFrame
1159+
),
1160+
pd.DataFrame,
1161+
)
1162+
1163+
check(
1164+
assert_type(
1165+
pd.crosstab(a, b, values=pd.Series(values), aggfunc=np.sum), pd.DataFrame
1166+
),
1167+
pd.DataFrame,
1168+
)
1169+
check(
1170+
assert_type(pd.crosstab(a, b, values=values, aggfunc=np.mean), pd.DataFrame),
1171+
pd.DataFrame,
1172+
)
1173+
check(
1174+
assert_type(pd.crosstab(a, b, values=values, aggfunc="mean"), pd.DataFrame),
1175+
pd.DataFrame,
1176+
)
1177+
1178+
def m(x: pd.Series) -> float:
1179+
return x.sum() / len(x)
1180+
1181+
check(
1182+
assert_type(pd.crosstab(a, b, values=values, aggfunc=m), pd.DataFrame),
1183+
pd.DataFrame,
1184+
)
1185+
1186+
def m2(x: pd.Series) -> int:
1187+
return int(x.sum())
1188+
1189+
check(
1190+
assert_type(pd.crosstab(a, b, values=values, aggfunc=m2), pd.DataFrame),
1191+
pd.DataFrame,
1192+
)
1193+
check(
1194+
assert_type(
1195+
pd.crosstab(a, b, margins=True, margins_name="something"), pd.DataFrame
1196+
),
1197+
pd.DataFrame,
1198+
)
1199+
check(
1200+
assert_type(pd.crosstab(a, b, margins=True, dropna=True), pd.DataFrame),
1201+
pd.DataFrame,
1202+
)
1203+
check(
1204+
assert_type(pd.crosstab(a, b, colnames=["a"], rownames=["b"]), pd.DataFrame),
1205+
pd.DataFrame,
1206+
)
1207+
rownames: list[tuple] = [("b", 1)]
1208+
colnames: list[tuple] = [("a",)]
1209+
check(
1210+
assert_type(
1211+
pd.crosstab(a, b, colnames=colnames, rownames=rownames),
1212+
pd.DataFrame,
1213+
),
1214+
pd.DataFrame,
1215+
)
1216+
check(assert_type(pd.crosstab(a, b, normalize=0), pd.DataFrame), pd.DataFrame)
1217+
check(assert_type(pd.crosstab(a, b, normalize=1), pd.DataFrame), pd.DataFrame)
1218+
check(assert_type(pd.crosstab(a, b, normalize="all"), pd.DataFrame), pd.DataFrame)
1219+
check(assert_type(pd.crosstab(a, b, normalize="index"), pd.DataFrame), pd.DataFrame)
1220+
check(
1221+
assert_type(pd.crosstab(a, b, normalize="columns"), pd.DataFrame), pd.DataFrame
1222+
)

0 commit comments

Comments
 (0)