Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing a callable with type vars in self types #18401

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,6 @@ def solve_constraints(
# Constraints inferred from unions require special handling in polymorphic inference.
constraints = skip_reverse_union_constraints(constraints)

# Collect a list of constraints for each type variable.
cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars}
for con in constraints:
if con.type_var in vars + extra_vars:
cmap[con.type_var].append(con)

if allow_polymorphic:
if constraints:
solutions, free_vars = solve_with_dependent(
vars + extra_vars, constraints, vars, originals
Expand All @@ -88,6 +81,12 @@ def solve_constraints(
solutions = {}
free_vars = []
else:
# Collect a list of constraints for each type variable.
cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars}
for con in constraints:
if con.type_var in vars + extra_vars:
cmap[con.type_var].append(con)

solutions = {}
free_vars = []
for tv, cs in cmap.items():
Expand Down
27 changes: 23 additions & 4 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,9 @@ class B(A): pass
if func.variables and supported_self_type(
self_param_type, allow_callable=allow_callable, allow_instances=not ignore_instances
):
from mypy.constraints import SUPERTYPE_OF, infer_constraints
from mypy.infer import infer_type_arguments
from mypy.solve import solve_constraints

if original_type is None:
# TODO: type check method override (see #7861).
Expand All @@ -364,9 +366,9 @@ class B(A): pass
self_vars = [tv for tv in func.variables if tv.id in self_ids]

# Solve for these type arguments using the actual class or instance type.
typeargs = infer_type_arguments(
self_vars, self_param_type, original_type, is_supertype=True
)
constraints = infer_constraints(self_param_type, original_type, SUPERTYPE_OF)
typeargs, free_vars = solve_constraints(self_vars, constraints, allow_polymorphic=True)

if (
is_classmethod
and any(isinstance(get_proper_type(t), UninhabitedType) for t in typeargs)
Expand All @@ -376,12 +378,29 @@ class B(A): pass
typeargs = infer_type_arguments(
self_vars, self_param_type, TypeType(original_type), is_supertype=True
)
free_vars = []

# Update the method signature with the solutions found.
# Technically, some constraints might be unsolvable, make them Never.
to_apply = [t if t is not None else UninhabitedType() for t in typeargs]

# Try to push in any type vars where the self type was the only location. e.g.:
# [T] () -> (T) -> T should return () -> [T] (T) -> T
outer_tvs = set()
for arg in func.arg_types[1:]:
outer_tvs |= set(get_all_type_vars(arg)) & set(free_vars)

inner_tvs = [v for v in free_vars if v not in outer_tvs]
result_type = get_proper_type(func.ret_type)
if isinstance(result_type, CallableType):
func = func.copy_modified(
ret_type=result_type.copy_modified(
variables=list(result_type.variables) + inner_tvs
)
)

func = expand_type(func, {tv.id: arg for tv, arg in zip(self_vars, to_apply)})
variables = [v for v in func.variables if v not in self_vars]
variables = [v for v in func.variables if v not in self_vars or v in outer_tvs]
else:
variables = func.variables

Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -2214,3 +2214,18 @@ class Test2:

reveal_type(Test2().method) # N: Revealed type is "def (foo: builtins.int, *, bar: builtins.str) -> builtins.bytes"
[builtins fixtures/tuple.pyi]

[case testCallableWithTypeVarInSelfType]
from typing import Generic, TypeVar, Callable

T = TypeVar("T")
V = TypeVar("V")

class X(Generic[T]):
def f(self: X[Callable[[V], None]]) -> Callable[[V], V]:
def inner_f(v: V) -> V:
return v

return inner_f

reveal_type(X[Callable[[T], None]]().f()) # N: Revealed type is "def [V] (V`4) -> V`4"
Loading