Skip to content

Commit 712411e

Browse files
committed
Get simple int -> int function to have its constraints solved
1 parent 0b2e337 commit 712411e

File tree

3 files changed

+133
-7
lines changed

3 files changed

+133
-7
lines changed

typelanguage/constraintgen.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ def __init__(self, subtype, supertype):
1010
self.subtype = subtype
1111
self.supertype = supertype
1212

13+
def __str__(self):
14+
return '%s <: %s' % (self.subtype, self.supertype)
15+
16+
def substitute(self, substitution):
17+
return Constraint(subtype = self.subtype.substitute(substitution),
18+
supertype = self.supertype.substitute(substitution))
19+
1320
class ConstrainedType(object):
1421
def __init__(self, type=None, constraints=None):
1522
self.type = type
@@ -21,10 +28,15 @@ def __init__(self, env=None, constraints=None, return_type=None):
2128
self.constraints = constraints or []
2229
self.return_type = return_type
2330

31+
def substitute(self, substitution):
32+
return ConstrainedEnv(env = dict([(key, ty.substitute(substitution)) for key, ty in self.env.items()]),
33+
constraints = [constraint.substitute(substitution) for constraint in self.constraints],
34+
return_type = None if self.return_type is None else self.return_type.substitute(substitution))
35+
2436
def pretty(self):
2537
return ("Env:\n\t%(bindings)s\n\nConstraints:\n\t%(constraints)s\n\nResult:\n\t%(result)s" %
2638
dict(bindings = '\n\t'.join(['%s: %s' % (var, ty) for var, ty in self.env.items()]),
27-
constraints = '\n\t'.join(['%s <: %s' % (c.subtype, c.supertype) for c in self.constraints]),
39+
constraints = '\n\t'.join([str(c) for c in self.constraints]),
2840
result = self.return_type))
2941

3042
def constraints(env, pyast):
@@ -113,7 +125,7 @@ def constraints_expr(env, expr):
113125
raise Exception('Variable not found in environment: %s' % expr.id)
114126

115127
elif isinstance(expr, ast.Num):
116-
return ConstrainedType(type=types.AtomicType('int'))
128+
return ConstrainedType(type=types.AtomicType('num'))
117129

118130
elif isinstance(expr, ast.BinOp):
119131
left = constraints_expr(env, expr.left)

typelanguage/constraintsolve.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import sys
2+
import ast
3+
import logging
4+
import copy
5+
6+
from typelanguage import constraintgen
7+
from typelanguage.types import *
8+
9+
logger = logging.getLogger(__name__)
10+
11+
class Refutation(object):
12+
def __init__(self, reason):
13+
self.reason = reason
14+
15+
def __bool__(self):
16+
return False
17+
18+
def __str__(self):
19+
return 'Refutation(reason="%s")' % self.reason
20+
21+
def reconcile(constraint):
22+
'''
23+
Returns an assignment of type variable names to
24+
types that makes this constraint satisfiable, or a Refutation
25+
'''
26+
27+
if isinstance(constraint.subtype, AtomicType):
28+
if isinstance(constraint.supertype, AtomicType):
29+
if constraint.subtype.name == constraint.supertype.name:
30+
return {}
31+
else:
32+
return Refutation('Cannot reconcile different atomic types: %s' % constraint)
33+
elif isinstance(constraint.supertype, TypeVariable):
34+
return {constraint.supertype.name: contraint.subtype}
35+
else:
36+
return Refutation('Cannot reconcile atomic type with non-atomic type: %s' % constraint)
37+
38+
elif isinstance(constraint.supertype, AtomicType):
39+
if isinstance(constraint.subtype, AtomicType):
40+
if constraint.subtype.name == constraint.supertype.name:
41+
return {}
42+
else:
43+
return Refutation('Cannot reconcile different atomic types: %s' % constraint)
44+
elif isinstance(constraint.subtype, TypeVariable):
45+
return {constraint.subtype.name: constraint.supertype}
46+
else:
47+
return Refutation('Cannot reconcile non-atomic type with atomic type: %s' % constraint)
48+
else:
49+
raise NotImplementedError('Reconciliation of %s' % constraint)
50+
51+
def solve(constraints):
52+
remaining_constraints = copy.copy(constraints)
53+
substitution = {}
54+
55+
while len(remaining_constraints) > 0:
56+
constraint = remaining_constraints.pop()
57+
additional_substitution = reconcile(constraint)
58+
59+
logger.info('reconcile(%s) ==> %s', constraint, additional_substitution)
60+
61+
62+
if isinstance(additional_substitution, Refutation):
63+
return additional_substitution
64+
else:
65+
substitution.update(additional_substitution)
66+
67+
remaining_constraints = [c.substitute(additional_substitution) for c in remaining_constraints]
68+
69+
return substitution
70+
71+
if __name__ == '__main__':
72+
logging.basicConfig()
73+
logging.getLogger('').setLevel(logging.DEBUG)
74+
75+
with open(sys.argv[1]) as fh:
76+
proggy = ast.parse(fh.read())
77+
78+
cs = constraintgen.constraints({}, proggy)
79+
print cs.pretty()
80+
81+
substitution = solve(cs.constraints)
82+
print substitution
83+
84+
print cs.substitute(substitution).pretty()

typelanguage/types.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,29 @@
11
from collections import namedtuple, defaultdict
22

33
class Type(object): pass
4-
class AtomicType(Type, namedtuple('AtomicType', ['name'])): pass
5-
class TypeVariable(Type, namedtuple('TypeVariable', ['name'])): pass
4+
5+
class AtomicType(Type):
6+
def __init__(self, name):
7+
self.name = name
8+
9+
def substitute(self, substitution):
10+
return self
11+
12+
def __str__(self):
13+
return self.name
14+
15+
class TypeVariable(Type):
16+
def __init__(self, name):
17+
self.name = name
18+
19+
def substitute(self, substitution):
20+
if self.name in substitution:
21+
return substitution[self.name]
22+
else:
23+
return self
24+
25+
def __str__(self):
26+
return '?%s' % self.name
627

728
class FunctionType(Type):
829
def __init__(self, arg_types, return_type, vararg_type=None, kwonly_arg_types=None, kwarg_type=None):
@@ -12,6 +33,13 @@ def __init__(self, arg_types, return_type, vararg_type=None, kwonly_arg_types=No
1233
self.kwarg_type = kwarg_type
1334
self.kwonly_arg_types = kwonly_arg_types
1435

36+
def substitute(self, substitution):
37+
return FunctionType(arg_types = [ty.substitute(substitution) for ty in self.arg_types],
38+
return_type = self.return_type.substitute(substitution),
39+
vararg_type = None if self.vararg_type is None else self.vararg_type.substitute(substitution),
40+
kwonly_arg_types = None if self.kwonly_arg_types is None else [ty.substitute(substitution) for ty in self.kwonly_arg_types],
41+
kwarg_type = None if self.kwarg_type is None else self.kwarg_type.substitute(substitution))
42+
1543
def __str__(self):
1644
comma_separated_bits = [unicode(v) for v in self.arg_types]
1745

@@ -31,15 +59,17 @@ def __init__(self, fn, args):
3159
self.fn = fn
3260
self.args = args
3361

62+
def substitute(self, substitution):
63+
return TypeApplication(self.fn.substitute(substitution), [ty.substitute(substitution) for ty in self.args])
64+
3465
class UnionType(Type):
3566
def __init__(self, types):
3667
self.types = types
3768

38-
3969
used_vars = defaultdict(lambda: 0)
4070
def fresh(prefix=None):
4171
global used_vars
42-
prefix = prefix or '?X'
72+
prefix = prefix or 'X'
4373
used_vars[prefix] = used_vars[prefix] + 1
44-
return prefix + str(used_vars[prefix])
74+
return TypeVariable(prefix + str(used_vars[prefix]))
4575

0 commit comments

Comments
 (0)