Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix TypeIs for types with type params in Unions
  • Loading branch information
kreathon committed May 10, 2024
commit 8bebb702e1144f91430f6873a04c0b0666d0c9dd
31 changes: 25 additions & 6 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5787,6 +5787,7 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
self.lookup_type(expr),
[TypeRange(node.callee.type_is, is_upper_bound=False)],
expr,
ignore_type_params=False,
),
)
elif isinstance(node, ComparisonExpr):
Expand Down Expand Up @@ -7160,11 +7161,17 @@ def conditional_types_with_intersection(
type_ranges: list[TypeRange] | None,
ctx: Context,
default: None = None,
ignore_type_params: bool = True,
) -> tuple[Type | None, Type | None]: ...

@overload
def conditional_types_with_intersection(
self, expr_type: Type, type_ranges: list[TypeRange] | None, ctx: Context, default: Type
self,
expr_type: Type,
type_ranges: list[TypeRange] | None,
ctx: Context,
default: Type,
ignore_type_params: bool = True,
) -> tuple[Type, Type]: ...

def conditional_types_with_intersection(
Expand All @@ -7173,8 +7180,9 @@ def conditional_types_with_intersection(
type_ranges: list[TypeRange] | None,
ctx: Context,
default: Type | None = None,
ignore_type_params: bool = True,
) -> tuple[Type | None, Type | None]:
initial_types = conditional_types(expr_type, type_ranges, default)
initial_types = conditional_types(expr_type, type_ranges, default, ignore_type_params)
# For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)"
# doesn't work: mypyc will decide that 'yes_map' is of type None if we try.
yes_type: Type | None = initial_types[0]
Expand Down Expand Up @@ -7422,18 +7430,27 @@ def visit_type_var(self, t: TypeVarType) -> None:

@overload
def conditional_types(
current_type: Type, proposed_type_ranges: list[TypeRange] | None, default: None = None
current_type: Type,
proposed_type_ranges: list[TypeRange] | None,
default: None = None,
ignore_type_params: bool = True,
) -> tuple[Type | None, Type | None]: ...


@overload
def conditional_types(
current_type: Type, proposed_type_ranges: list[TypeRange] | None, default: Type
current_type: Type,
proposed_type_ranges: list[TypeRange] | None,
default: Type,
ignore_type_params: bool = True,
) -> tuple[Type, Type]: ...


def conditional_types(
current_type: Type, proposed_type_ranges: list[TypeRange] | None, default: Type | None = None
current_type: Type,
proposed_type_ranges: list[TypeRange] | None,
default: Type | None = None,
ignore_type_params: bool = True,
) -> tuple[Type | None, Type | None]:
"""Takes in the current type and a proposed type of an expression.

Expand Down Expand Up @@ -7477,7 +7494,9 @@ def conditional_types(
if not type_range.is_upper_bound
]
)
remaining_type = restrict_subtype_away(current_type, proposed_precise_type)
remaining_type = restrict_subtype_away(
current_type, proposed_precise_type, ignore_type_params=ignore_type_params
)
return proposed_type, remaining_type
else:
# An isinstance check, but we don't understand the type
Expand Down
26 changes: 17 additions & 9 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,37 +1910,45 @@ def try_restrict_literal_union(t: UnionType, s: Type) -> list[Type] | None:
return new_items


def restrict_subtype_away(t: Type, s: Type) -> Type:
"""Return t minus s for runtime type assertions.
def restrict_subtype_away(t: Type, s: Type, ignore_type_params: bool = True) -> Type:
"""Return t minus s for runtime type assertions and TypeIs[].

If we can't determine a precise result, return a supertype of the
ideal result (just t is a valid result).

This is used for type inference of runtime type checks such as
isinstance(). Currently, this just removes elements of a union type.
isinstance() or TypeIs[]. Currently, this just removes elements
of a union type.
"""
p_t = get_proper_type(t)
if isinstance(p_t, UnionType):
new_items = try_restrict_literal_union(p_t, s)
if new_items is None:
new_items = [
restrict_subtype_away(item, s)
restrict_subtype_away(item, s, ignore_type_params=ignore_type_params)
for item in p_t.relevant_items()
if (isinstance(get_proper_type(item), AnyType) or not covers_at_runtime(item, s))
if (
isinstance(get_proper_type(item), AnyType)
or not covers_type(item, s, ignore_type_params)
)
]
return UnionType.make_union(new_items)
elif covers_at_runtime(t, s):
elif covers_type(t, s, ignore_type_params):
return UninhabitedType()
else:
return t


def covers_at_runtime(item: Type, supertype: Type) -> bool:
"""Will isinstance(item, supertype) always return True at runtime?"""
def covers_type(item: Type, supertype: Type, ignore_type_params: bool = True) -> bool:
"""Checks if item is covered by supertype."""
item = get_proper_type(item)
supertype = get_proper_type(supertype)

# Since runtime type checks will ignore type arguments, erase the types.
if not ignore_type_params:
return is_proper_subtype(item, supertype, ignore_promotions=True)

# The following code is used for isinstance(), where ignore the type
# params is important (since this happens at runtime)
supertype = erase_type(supertype)
if is_proper_subtype(
erase_type(item), supertype, ignore_promotions=True, erase_instances=True
Expand Down
11 changes: 11 additions & 0 deletions test-data/unit/check-typeis.test
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ def main(a: object) -> None:
reveal_type(a) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeIsUnionWithTypeParams]
from typing_extensions import TypeIs
from typing import Iterable, List, Union
def is_iterable_int(val: object) -> TypeIs[Iterable[int]]: pass
def main(a: Union[List[int], List[str]]) -> None:
if is_iterable_int(a):
reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]"
else:
reveal_type(a) # N: Revealed type is "builtins.list[builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeIsNonzeroFloat]
from typing_extensions import TypeIs
def is_nonzero(a: object) -> TypeIs[float]: pass
Expand Down