Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d157fa1
wip
MarcoGorelli Aug 26, 2025
b305640
maybe fix the overloads
MarcoGorelli Aug 26, 2025
1e15cb9
wip tests
MarcoGorelli Aug 26, 2025
517c7d4
fixup test
MarcoGorelli Aug 26, 2025
c54fe63
appease mypy
MarcoGorelli Aug 26, 2025
527f00f
wip
MarcoGorelli Aug 26, 2025
e5c8df7
fixup
MarcoGorelli Aug 26, 2025
2ae0638
fixup
MarcoGorelli Aug 26, 2025
bc73767
remove ndarray overloads as they dont really make sense for a nullabl…
MarcoGorelli Aug 26, 2025
2527f5d
fixup
MarcoGorelli Aug 26, 2025
6760c3d
fixup
MarcoGorelli Aug 26, 2025
eef5694
Merge remote-tracking branch 'upstream/main' into natype-arithemtic
MarcoGorelli Aug 26, 2025
a686ae9
fixup
MarcoGorelli Aug 26, 2025
08e604c
mypy fixup
MarcoGorelli Aug 26, 2025
a094e3e
remove `__bool__`, and comparisons (eq/ne/gt/lt/...) with `Series`/`I…
MarcoGorelli Aug 27, 2025
cb1da48
remove redundant annotation
MarcoGorelli Aug 27, 2025
340582c
comment cases which require other fixes
MarcoGorelli Aug 27, 2025
8c1eced
note pyright bug
MarcoGorelli Aug 27, 2025
83683e5
Merge remote-tracking branch 'upstream/main' into natype-arithemtic
MarcoGorelli Sep 5, 2025
ce16ae2
update pyright
MarcoGorelli Sep 5, 2025
c61f750
keep divmod(na, 1), ignore pyright, link to issue
MarcoGorelli Sep 12, 2025
7213016
uncomment some parts
MarcoGorelli Sep 12, 2025
34da99d
Merge remote-tracking branch 'upstream/main' into natype-arithemtic
MarcoGorelli Sep 12, 2025
c49dcec
remove more outdated commented-out tests
MarcoGorelli Sep 12, 2025
e2537a5
linting
MarcoGorelli Sep 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixup
  • Loading branch information
MarcoGorelli committed Aug 26, 2025
commit e5c8df7a73f386fe5ff282d9d1592810ff0cfc8f
37 changes: 22 additions & 15 deletions pandas-stubs/_libs/missing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from pandas import (
Index,
Series,
)
from pandas.core.arrays.boolean import BooleanArray
from typing_extensions import Self

class NAType:
Expand All @@ -25,7 +26,7 @@ class NAType:
@overload
def __add__(self, other: Index, /) -> Index: ... # type: ignore[overload-overlap]
@overload
def __add__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __add__(self, other: npt.NDArray[Any], /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
def __add__(self, other: object, /) -> NAType: ...
@overload
Expand All @@ -35,8 +36,6 @@ class NAType:
@overload
def __radd__(self, other: Index, /) -> Index: ... # type: ignore[overload-overlap]
@overload
def __radd__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
def __radd__(self, other: object, /) -> NAType: ...
@overload
def __sub__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
Expand Down Expand Up @@ -118,7 +117,7 @@ class NAType:
self, other: Series, /
) -> Series: ...
@overload
def __rfloordiv__(self, other: Index, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __rfloordiv__(self, other: Index, /) -> Index: ... # type: ignore[overload-overlap]
@overload
def __rfloordiv__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
Expand Down Expand Up @@ -166,9 +165,9 @@ class NAType:
@overload # type: ignore[override]
def __eq__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, other: Series, /
) -> Series: ...
) -> BooleanArray: ...
@overload
def __eq__(self, other: Index, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __eq__(self, other: Index, /) -> BooleanArray: ... # type: ignore[overload-overlap]
@overload
def __eq__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
Expand All @@ -180,7 +179,7 @@ class NAType:
self, other: Series, /
) -> Series: ...
@overload
def __ne__(self, other: Index, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __ne__(self, other: Index, /) -> BooleanArray: ... # type: ignore[overload-overlap]
@overload
def __ne__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
Expand All @@ -192,7 +191,7 @@ class NAType:
self, other: Series, /
) -> Series: ...
@overload
def __le__(self, other: Index, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __le__(self, other: Index, /) -> BooleanArray: ... # type: ignore[overload-overlap]
@overload
def __le__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
Expand All @@ -202,7 +201,7 @@ class NAType:
self, other: Series, /
) -> Series: ...
@overload
def __lt__(self, other: Index, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __lt__(self, other: Index, /) -> BooleanArray: ... # type: ignore[overload-overlap]
@overload
def __lt__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
Expand All @@ -212,7 +211,7 @@ class NAType:
self, other: Series, /
) -> Series: ...
@overload
def __gt__(self, other: Index, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __gt__(self, other: Index, /) -> BooleanArray: ... # type: ignore[overload-overlap]
@overload
def __gt__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
Expand All @@ -222,7 +221,7 @@ class NAType:
self, other: Series, /
) -> Series: ...
@overload
def __ge__(self, other: Index, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
def __ge__(self, other: Index, /) -> BooleanArray: ... # type: ignore[overload-overlap]
@overload
def __ge__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
Expand Down Expand Up @@ -260,7 +259,9 @@ class NAType:
@overload
def __and__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
def __and__(self, other: object, /) -> Literal[False] | NAType: ...
def __and__(self, other: Literal[False], /) -> Literal[False]: ...
@overload
def __and__(self, other: object, /) -> NAType: ...
@overload
def __rand__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, other: Series, /
Expand All @@ -270,7 +271,9 @@ class NAType:
@overload
def __rand__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
def __rand__(self, other: object, /) -> Literal[False] | NAType: ...
def __rand__(self, other: Literal[False], /) -> Literal[False]: ...
@overload
def __rand__(self, other: object, /) -> NAType: ...
@overload
def __or__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, other: Series, /
Expand All @@ -280,7 +283,9 @@ class NAType:
@overload
def __or__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
def __or__(self, other: object, /) -> Literal[True] | NAType: ...
def __or__(self, other: Literal[True], /) -> Literal[True]: ...
@overload
def __or__(self, other: object, /) -> NAType: ...
@overload
def __ror__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, other: Series, /
Expand All @@ -290,7 +295,9 @@ class NAType:
@overload
def __ror__(self, other: npt.NDArray, /) -> npt.NDArray: ... # type: ignore[overload-overlap]
@overload
def __ror__(self, other: object, /) -> Literal[True] | NAType: ...
def __ror__(self, other: Literal[True], /) -> Literal[True]: ...
@overload
def __ror__(self, other: object, /) -> NAType: ...
@overload
def __xor__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, other: Series, /
Expand Down
102 changes: 42 additions & 60 deletions tests/test_natype.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from typing import Any
from typing import (
Any,
Literal,
)

import numpy as np
import numpy.typing as npt
import pandas as pd
from pandas.api.typing import NAType
from pandas.core.arrays.boolean import BooleanArray
from typing_extensions import assert_type

from tests import check


def test_arithmetic() -> None:
na = pd.NA
s_int = pd.Series([1, 2, 3])
idx_int = pd.Index([1, 2, 3])

s_int: pd.Series[int] = pd.Series([1, 2, 3], dtype="Int64")
idx_int: pd.Index[int] = pd.Index([1, 2, 3], dtype="Int64")

arr_int: npt.NDArray[Any] = np.array([1, 2, 3])
ma_int: npt.NDArray[Any] = np.array([[1, 2, 3], [4, 5, 6]])

Expand All @@ -25,8 +31,7 @@ def test_arithmetic() -> None:
# __radd__
check(assert_type(s_int + na, pd.Series), pd.Series)
# https://github.com/pandas-dev/pandas-stubs/issues/1347
check(assert_type(idx_int + na, pd.Index), pd.Index) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(arr_int + na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(idx_int + na, pd.Index), pd.Index) # type: ignore[assert-type]# pyright: ignore[reportAssertTypeFailure]
check(assert_type(1 + na, NAType), NAType)

# __sub__
Expand All @@ -37,7 +42,8 @@ def test_arithmetic() -> None:

# __rsub__
check(assert_type(s_int - na, pd.Series), pd.Series)
check(assert_type(idx_int - na, pd.Index), pd.Index) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
# https://github.com/pandas-dev/pandas-stubs/issues/1347
check(assert_type(idx_int - na, pd.Index), pd.Index) # type: ignore[assert-type]# pyright: ignore[reportAssertTypeFailure]
check(assert_type(arr_int - na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(1 - na, NAType), NAType)

Expand All @@ -49,12 +55,13 @@ def test_arithmetic() -> None:

# __rmul__
check(assert_type(s_int * na, pd.Series), pd.Series)
check(assert_type(idx_int * na, pd.Index), pd.Index) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
# https://github.com/pandas-dev/pandas-stubs/issues/1347
check(assert_type(idx_int * na, pd.Index), pd.Index) # type: ignore[assert-type]# pyright: ignore[reportAssertTypeFailure]
check(assert_type(arr_int * na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(1 * na, NAType), NAType)

# __matmul__
check(assert_type(na @ ma_int, npt.NDArray), npt.NDArray)
check(assert_type(na @ ma_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na @ 1, NAType), NAType)

# __rmatmul__
Expand All @@ -79,8 +86,9 @@ def test_arithmetic() -> None:
check(assert_type(na // 1, NAType), NAType)

# __rfloordiv__
# TODO: put these back but use nullable series to test it
check(assert_type(s_int // na, pd.Series), pd.Series)
check(assert_type(idx_int // na, npt.NDArray), npt.NDArray)
check(assert_type(idx_int // na, pd.Index), pd.Index)
check(assert_type(arr_int // na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(1 // na, NAType), NAType)

Expand All @@ -92,121 +100,95 @@ def test_arithmetic() -> None:

# __rmod__
check(assert_type(s_int % na, pd.Series), pd.Series)
check(assert_type(idx_int % na, pd.Index), pd.Index) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
# https://github.com/pandas-dev/pandas-stubs/issues/1347
check(assert_type(idx_int % na, pd.Index), pd.Index) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(arr_int % na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(1 % na, NAType), NAType)

# __eq__
check(assert_type(na == s_int, pd.Series), pd.Series)
check(assert_type(na == idx_int, npt.NDArray), npt.NDArray)
check(assert_type(na == idx_int, BooleanArray), BooleanArray)
check(assert_type(na == arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na == 1, NAType), NAType)

# __req__
# check(assert_type(= s_int=na, pd.Series), pd.Series)
# check(assert_type(= idx_int=na, npt.NDArray), npt.NDArray)
# check(assert_type(= arr_int=na, npt.NDArray), npt.NDArray)
# check(assert_type(= 1=na, NAType), NAType)

# __ne__
check(assert_type(na != s_int, pd.Series), pd.Series)
check(assert_type(na != idx_int, npt.NDArray), npt.NDArray)
check(assert_type(na != idx_int, BooleanArray), BooleanArray)
check(assert_type(na != arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na != 1, NAType), NAType)

# __rne__
# check(assert_type(= s_int!na, pd.Series), pd.Series)
# check(assert_type(= idx_int!na, npt.NDArray), npt.NDArray)
# check(assert_type(= arr_int!na, npt.NDArray), npt.NDArray)
# check(assert_type(= 1!na, NAType), NAType)

# __le__
check(assert_type(na <= s_int, pd.Series), pd.Series)
check(assert_type(na <= idx_int, npt.NDArray), npt.NDArray)
check(assert_type(na <= idx_int, BooleanArray), BooleanArray)
check(assert_type(na <= arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na <= 1, NAType), NAType)

# __rle__
# check(assert_type(= s_int<na, pd.Series), pd.Series)
# check(assert_type(= idx_int<na, npt.NDArray), npt.NDArray)
# check(assert_type(= arr_int<na, npt.NDArray), npt.NDArray)
# check(assert_type(= 1<na, NAType), NAType)

# __lt__
check(assert_type(na < s_int, pd.Series), pd.Series)
check(assert_type(na < idx_int, npt.NDArray), npt.NDArray)
check(assert_type(na < idx_int, BooleanArray), BooleanArray)
check(assert_type(na < arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na < 1, NAType), NAType)

# __rlt__
check(assert_type(s_int < na, pd.Series), pd.Series)
check(assert_type(idx_int < na, npt.NDArray), npt.NDArray)
check(assert_type(arr_int < na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(1 < na, NAType), NAType)

# __ge__
check(assert_type(na >= s_int, pd.Series), pd.Series)
check(assert_type(na >= idx_int, npt.NDArray), npt.NDArray)
check(assert_type(na >= idx_int, BooleanArray), BooleanArray)
check(assert_type(na >= arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na >= 1, NAType), NAType)

# __rge__
# check(assert_type(= s_int>na, pd.Series), pd.Series)
# check(assert_type(= idx_int>na, npt.NDArray), npt.NDArray)
# check(assert_type(= arr_int>na, npt.NDArray), npt.NDArray)
# check(assert_type(= 1>na, NAType), NAType)

# __gt__
check(assert_type(na > s_int, pd.Series), pd.Series)
check(assert_type(na > idx_int, npt.NDArray), npt.NDArray)
check(assert_type(na > idx_int, BooleanArray), BooleanArray)
check(assert_type(na > arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na > 1, NAType), NAType)

# __rgt__
check(assert_type(s_int > na, pd.Series), pd.Series)
check(assert_type(idx_int > na, npt.NDArray), npt.NDArray)
check(assert_type(arr_int > na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(1 > na, NAType), NAType)

# __pow__
check(assert_type(na**s_int, pd.Series), pd.Series)
check(assert_type(na**idx_int, pd.Index), pd.Index)
check(assert_type(na**arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
check(assert_type(na**2, NAType), NAType)

# __rpow__
check(assert_type(s_int * na, pd.Series), pd.Series)
check(assert_type(idx_int * na, pd.Index), pd.Index) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(arr_int * na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(2 * na, NAType), NAType)
check(assert_type(s_int**na, pd.Series), pd.Series)
# https://github.com/pandas-dev/pandas-stubs/issues/1347
check(assert_type(idx_int**na, pd.Index), pd.Index) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(arr_int**na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
check(assert_type(2**na, NAType), NAType)

# __and__
check(assert_type(na & s_int, pd.Series), pd.Series)
check(assert_type(na & idx_int, pd.Index), pd.Index)
check(assert_type(na & arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
# check(assert_type(na & True, NAType), NAType)
check(assert_type(na & False, Literal[False]), bool)
check(assert_type(na & True, NAType), NAType)
check(assert_type(na & na, NAType), NAType)

# __rand__
check(assert_type(s_int & na, pd.Series), pd.Series)
check(assert_type(idx_int & na, pd.Index), pd.Index)
check(assert_type(arr_int & na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
# check(ssert_typa & Trueana, NAType), NAType)
check(assert_type(False & na, Literal[False]), bool)
check(assert_type(True & na, NAType), NAType)

# __or__
check(assert_type(na | s_int, pd.Series), pd.Series)
check(assert_type(na | idx_int, pd.Index), pd.Index)
check(assert_type(na | arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine
# check(assert_type(na | True, NAType), NAType)
check(assert_type(na | False, NAType), NAType)
check(assert_type(na | True, Literal[True]), bool)

# __ror__
check(assert_type(s_int | na, pd.Series), pd.Series)
check(assert_type(idx_int | na, pd.Index), pd.Index)
check(assert_type(arr_int | na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]
# check(ssert_typa | Trueana, NAType), NAType)
check(assert_type(False | na, NAType), NAType)
check(assert_type(True | na, Literal[True]), bool)

# __xor__
check(assert_type(na ^ s_int, pd.Series), pd.Series)
check(assert_type(na ^ idx_int, pd.Index), pd.Index)
check(assert_type(na ^ arr_int, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # mypy bug? pyright fine

# rxor
check(assert_type(s_int ^ na, pd.Series), pd.Series)
check(assert_type(idx_int ^ na, pd.Index), pd.Index)
check(assert_type(arr_int ^ na, npt.NDArray), npt.NDArray) # type: ignore[assert-type] # pyright: ignore[reportAssertTypeFailure]