• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 865

pending completion
865

push

travis-ci-com

nielstron
Bump opshin version

1 of 1 new or added line in 1 file covered. (100.0%)

3728 of 4026 relevant lines covered (92.6%)

3.7 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.83
/opshin/type_inference.py
1
from copy import copy
4✔
2
import ast
4✔
3

4
from .typed_ast import *
4✔
5
from .util import PythonBuiltInTypes, CompilingNodeTransformer
4✔
6

7
# from frozendict import frozendict
8

9

10
"""
2✔
11
An aggressive type inference based on the work of Aycock [1].
12
It only allows a subset of legal python operations which
13
allow us to infer the type of all involved variables
14
statically.
15
Using this we can resolve overloaded functions when translating Python
16
into UPLC where there is no dynamic type checking.
17
Additionally, this conveniently implements an additional layer of
18
security into the Smart Contract by checking type correctness.
19

20

21
[1]: https://legacy.python.org/workshops/2000-01/proceedings/papers/aycock/aycock.html
22
"""
23

24

25
# Names that are bound before any user code is visited: the atomic class
# annotations, followed by every builtin whose type is a polymorphic function
# (later entries win on a name clash, exactly as dict.update would).
INITIAL_SCOPE = {
    # class annotations
    "bytes": ByteStringType(),
    "int": IntegerType(),
    "bool": BoolType(),
    "str": StringType(),
    "Anything": AnyType(),
    # polymorphic builtin functions
    **{
        builtin.name: builtin_typ
        for builtin, builtin_typ in PythonBuiltInTypes.items()
        if isinstance(builtin_typ.typ, PolymorphicFunctionType)
    },
}
43

44

45
class AggressiveTypeInferencer(CompilingNodeTransformer):
    """Infer static types for a Python AST and return a typed copy of it.

    Every ``visit_*`` method returns a shallow copy of the visited node with the
    inferred type stored in its ``typ`` attribute (class definitions store it in
    ``class_typ``). Constructs outside the supported subset raise
    ``NotImplementedError`` / ``TypeInferenceError`` or fail an ``assert`` with an
    explanatory message.
    """

    step = "Static Type Inference"

    # A stack of dictionaries for storing scoped knowledge of variable types.
    # The class-level default is kept for backwards compatibility only:
    # __init__ gives every instance its own stack, so separate runs never share
    # scopes (e.g. scopes left behind by a run that raised before reaching
    # exit_scope) and the module-level INITIAL_SCOPE is never written to.
    scopes = [INITIAL_SCOPE]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scopes = [dict(INITIAL_SCOPE)]

    # Obtain the type of a variable name in the current scope
    def variable_type(self, name: str) -> Type:
        """Return the type bound to ``name``, searching innermost scope first.

        :raises TypeInferenceError: if ``name`` is not bound in any scope.
        """
        for scope in reversed(self.scopes):
            if name in scope:
                return scope[name]
        raise TypeInferenceError(f"Variable {name} not initialized at access")

    def enter_scope(self):
        """Push a fresh, empty innermost scope."""
        self.scopes.append({})

    def exit_scope(self):
        """Discard the innermost scope."""
        self.scopes.pop()

    def set_variable_type(self, name: str, typ: Type, force=False):
        """Bind ``name`` to ``typ`` in the innermost scope.

        Without ``force``, re-binding a name of the local scope to a different
        type is only accepted when the existing type is broader (``>=``), in
        which case the broader type is kept. Otherwise a TypeInferenceError is
        raised. With ``force`` the binding is overwritten unconditionally.
        """
        if not force and name in self.scopes[-1] and self.scopes[-1][name] != typ:
            if self.scopes[-1][name] >= typ:
                # the specified type is broader, we pass on this
                return
            raise TypeInferenceError(
                f"Type {self.scopes[-1][name]} of variable {name} in local scope does not match inferred type {typ}"
            )
        self.scopes[-1][name] = typ

    def type_from_annotation(self, ann: expr):
        """Translate a type annotation AST node into a class type.

        Supports ``None``, atomic type names, names of previously defined
        classes and the generics ``Union``, ``List``, ``Dict`` and ``Tuple``.
        A missing annotation (``ann is None``) yields ``AnyType()``.
        """
        if isinstance(ann, Constant):
            if ann.value is None:
                return UnitType()
        if isinstance(ann, Name):
            if ann.id in ATOMIC_TYPES:
                return ATOMIC_TYPES[ann.id]
            v_t = self.variable_type(ann.id)
            if isinstance(v_t, ClassType):
                return v_t
            raise TypeInferenceError(
                f"Class name {ann.id} not initialized before annotating variable"
            )
        if isinstance(ann, Subscript):
            assert isinstance(
                ann.value, Name
            ), "Only Union, Dict and List are allowed as Generic types"
            if ann.value.id == "Union":
                assert isinstance(
                    ann.slice, Tuple
                ), "Union must combine multiple classes"
                ann_types = [self.type_from_annotation(e) for e in ann.slice.elts]
                assert all(
                    isinstance(e, RecordType) for e in ann_types
                ), "Union must combine multiple PlutusData classes"
                assert distinct(
                    [e.record.constructor for e in ann_types]
                ), "Union must combine PlutusData classes with unique constructors"
                return UnionType(FrozenFrozenList(ann_types))
            if ann.value.id == "List":
                ann_type = self.type_from_annotation(ann.slice)
                assert isinstance(
                    ann_type, ClassType
                ), "List must have a single type as parameter"
                assert not isinstance(
                    ann_type, TupleType
                ), "List can currently not hold tuples"
                return ListType(InstanceType(ann_type))
            if ann.value.id == "Dict":
                assert isinstance(ann.slice, Tuple), "Dict must combine two classes"
                assert len(ann.slice.elts) == 2, "Dict must combine two classes"
                ann_types = self.type_from_annotation(
                    ann.slice.elts[0]
                ), self.type_from_annotation(ann.slice.elts[1])
                assert all(
                    isinstance(e, ClassType) for e in ann_types
                ), "Dict must combine two classes"
                assert not any(
                    isinstance(e, TupleType) for e in ann_types
                ), "Dict can currently not hold tuples"
                return DictType(*(InstanceType(a) for a in ann_types))
            if ann.value.id == "Tuple":
                assert isinstance(
                    ann.slice, Tuple
                ), "Tuple must combine several classes"
                ann_types = [self.type_from_annotation(e) for e in ann.slice.elts]
                assert all(
                    isinstance(e, ClassType) for e in ann_types
                ), "Tuple must combine classes"
                return TupleType(FrozenFrozenList([InstanceType(a) for a in ann_types]))
            raise NotImplementedError(
                "Only Union, Dict and List are allowed as Generic types"
            )
        if ann is None:
            return AnyType()
        raise NotImplementedError(f"Annotation type {ann.__class__} is not supported")

    def visit_ClassDef(self, node: ClassDef) -> TypedClassDef:
        """Read the class as a PlutusData record and bind its name to the record type."""
        class_record = RecordReader.extract(node, self)
        typ = RecordType(class_record)
        self.set_variable_type(node.name, typ)
        typed_node = copy(node)
        typed_node.class_typ = typ
        return typed_node

    def visit_Constant(self, node: Constant) -> TypedConstant:
        """Type a literal via ATOMIC_TYPES; floats, complex numbers and ``...`` are rejected."""
        tc = copy(node)
        assert type(node.value) not in [
            float,
            complex,
            type(...),
        ], "Float, complex numbers and ellipsis currently not supported"
        tc.typ = InstanceType(ATOMIC_TYPES[type(node.value).__name__])
        return tc

    def visit_Tuple(self, node: Tuple) -> TypedTuple:
        """Type a tuple literal as a TupleType of its element types."""
        tt = copy(node)
        tt.elts = [self.visit(e) for e in node.elts]
        tt.typ = InstanceType(TupleType([e.typ for e in tt.elts]))
        return tt

    def visit_List(self, node: List) -> TypedList:
        """Type a list literal; the first element's type must cover all elements."""
        tt = copy(node)
        tt.elts = [self.visit(e) for e in node.elts]
        # the element type is taken from the first element, so it must exist
        assert tt.elts, "Can not infer the type of an empty list literal"
        l_typ = tt.elts[0].typ
        assert all(
            l_typ >= e.typ for e in tt.elts
        ), "All elements of a list must have the same type"
        tt.typ = InstanceType(ListType(l_typ))
        return tt

    def visit_Dict(self, node: Dict) -> TypedDict:
        """Type a dict literal from its first key and value; all others must be covered."""
        tt = copy(node)
        tt.keys = [self.visit(k) for k in node.keys]
        tt.values = [self.visit(v) for v in node.values]
        # key and value types are taken from the first entry, so it must exist
        assert tt.keys, "Can not infer the type of an empty dict literal"
        k_typ = tt.keys[0].typ
        assert all(k_typ >= k.typ for k in tt.keys), "All keys must have the same type"
        v_typ = tt.values[0].typ
        assert all(
            v_typ >= v.typ for v in tt.values
        ), "All values must have the same type"
        tt.typ = InstanceType(DictType(k_typ, v_typ))
        return tt

    def visit_Assign(self, node: Assign) -> TypedAssign:
        """Type the right-hand side and bind every (plain name) target to its type."""
        typed_ass = copy(node)
        typed_ass.value: TypedExpression = self.visit(node.value)
        # Make sure to first set the type of each target name so we can load it when visiting it
        for t in node.targets:
            assert isinstance(
                t, Name
            ), "Can only assign to variable names, no type deconstruction"
            self.set_variable_type(t.id, typed_ass.value.typ)
        typed_ass.targets = [self.visit(t) for t in node.targets]
        return typed_ass

    def visit_AnnAssign(self, node: AnnAssign) -> TypedAnnAssign:
        """Bind the target to its annotated type (forcibly), acting as a cast.

        The value's type and the annotated type must be related in either
        direction of ``>=``.
        """
        typed_ass = copy(node)
        typed_ass.value: TypedExpression = self.visit(node.value)
        typed_ass.annotation = self.type_from_annotation(node.annotation)
        assert isinstance(
            node.target, Name
        ), "Can only assign to variable names, no type deconstruction"
        self.set_variable_type(
            node.target.id, InstanceType(typed_ass.annotation), force=True
        )
        typed_ass.target = self.visit(node.target)
        assert (
            typed_ass.value.typ >= InstanceType(typed_ass.annotation)
            or InstanceType(typed_ass.annotation) >= typed_ass.value.typ
        ), "Can only cast between related types"
        return typed_ass

    def visit_If(self, node: If) -> TypedIf:
        """Type an if statement.

        ``if isinstance(x, C):`` on a Union-typed variable ``x`` is rewritten to
        a constructor-id comparison, and ``x`` is narrowed to ``C`` while the
        body is visited.
        """
        typed_if = copy(node)
        if (
            isinstance(typed_if.test, Call)
            and isinstance(typed_if.test.func, Name)
            and typed_if.test.func.id == "isinstance"
        ):
            tc = typed_if.test
            # special case for Union
            assert (
                len(tc.args) == 2
            ), "isinstance must be called with exactly two arguments"
            assert isinstance(
                tc.args[0], Name
            ), "Target 0 of an isinstance cast must be a variable name"
            assert isinstance(
                tc.args[1], Name
            ), "Target 1 of an isinstance cast must be a class name"
            target_class: RecordType = self.variable_type(tc.args[1].id)
            target_inst = self.visit(tc.args[0])
            target_inst_class = target_inst.typ
            assert isinstance(
                target_inst_class, InstanceType
            ), "Can only cast instances, not classes"
            assert isinstance(
                target_inst_class.typ, UnionType
            ), "Can only cast instances of Union types of PlutusData"
            assert isinstance(target_class, RecordType), "Can only cast to PlutusData"
            assert (
                target_class in target_inst_class.typ.typs
            ), "Trying to cast an instance of Union type to non-instance of union type"
            typed_if.test = self.visit(
                Compare(
                    left=Attribute(tc.args[0], "CONSTR_ID"),
                    ops=[Eq()],
                    comparators=[Constant(target_class.record.constructor)],
                )
            )
            # for the time of this if branch set the variable type to the specialized type
            self.set_variable_type(
                tc.args[0].id, InstanceType(target_class), force=True
            )
            typed_if.body = [self.visit(s) for s in node.body]
            # restore the union type for the else branch and everything after
            self.set_variable_type(tc.args[0].id, target_inst_class, force=True)
        else:
            typed_if.test = self.visit(node.test)
            assert (
                typed_if.test.typ == BoolInstanceType
            ), "Branching condition must have boolean type"
            typed_if.body = [self.visit(s) for s in node.body]
        typed_if.orelse = [self.visit(s) for s in node.orelse]
        return typed_if

    def visit_While(self, node: While) -> TypedWhile:
        """Type a while loop; its condition must be a bool."""
        typed_while = copy(node)
        typed_while.test = self.visit(node.test)
        assert (
            typed_while.test.typ == BoolInstanceType
        ), "Branching condition must have boolean type"
        typed_while.body = [self.visit(s) for s in node.body]
        typed_while.orelse = [self.visit(s) for s in node.orelse]
        return typed_while

    def _iteration_element_type(self, itertyp: Type) -> Type:
        """Return the type of the elements produced by iterating a value of ``itertyp``.

        Lists yield their element type. Tuples are only iterable when non-empty
        and homogeneous. Anything else raises NotImplementedError.
        Shared by ``visit_For`` and ``visit_comprehension``.
        """
        assert isinstance(
            itertyp, InstanceType
        ), "Can only iterate over instances, not classes"
        if isinstance(itertyp.typ, TupleType):
            assert itertyp.typ.typs, "Iterating over an empty tuple is not allowed"
            vartyp = itertyp.typ.typs[0]
            assert all(
                vartyp == t for t in itertyp.typ.typs
            ), "Iterating through a tuple requires the same type for each element"
            return vartyp
        if isinstance(itertyp.typ, ListType):
            return itertyp.typ.typ
        raise NotImplementedError(
            "Type inference for loops over non-list objects is not supported"
        )

    def visit_For(self, node: For) -> TypedFor:
        """Type a for loop, binding the loop variable to the iterable's element type."""
        typed_for = copy(node)
        typed_for.iter = self.visit(node.iter)
        if isinstance(node.target, Tuple):
            raise NotImplementedError(
                "Type deconstruction in for loops is not supported yet"
            )
        vartyp = self._iteration_element_type(typed_for.iter.typ)
        self.set_variable_type(node.target.id, vartyp)
        typed_for.target = self.visit(node.target)
        typed_for.body = [self.visit(s) for s in node.body]
        typed_for.orelse = [self.visit(s) for s in node.orelse]
        return typed_for

    def visit_Name(self, node: Name) -> TypedName:
        """Type a name reference by looking it up in the scope stack."""
        tn = copy(node)
        # Make sure that the rhs of an assign is evaluated first
        tn.typ = self.variable_type(node.id)
        return tn

    def visit_Compare(self, node: Compare) -> TypedCompare:
        """Type a comparison as bool; operand types are checked by the compiler."""
        typed_cmp = copy(node)
        typed_cmp.left = self.visit(node.left)
        typed_cmp.comparators = [self.visit(s) for s in node.comparators]
        typed_cmp.typ = BoolInstanceType
        # the actual required types are being taken care of in the implementation
        return typed_cmp

    def visit_arg(self, node: arg) -> typedarg:
        """Type a function parameter from its annotation and bind it in the current scope."""
        ta = copy(node)
        ta.typ = InstanceType(self.type_from_annotation(node.annotation))
        self.set_variable_type(ta.arg, ta.typ)
        return ta

    def visit_arguments(self, node: arguments) -> typedarguments:
        """Type a parameter list; only plain positional parameters are supported."""
        if node.kw_defaults or node.kwarg or node.kwonlyargs or node.defaults:
            raise NotImplementedError(
                "Keyword arguments and defaults not supported yet"
            )
        # *args and positional-only parameters are not part of node.args and
        # would otherwise be silently dropped from the function signature
        if node.vararg or getattr(node, "posonlyargs", None):
            raise NotImplementedError(
                "Variadic and positional-only arguments not supported yet"
            )
        ta = copy(node)
        ta.args = [self.visit(a) for a in node.args]
        return ta

    def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
        """Type a function definition and bind its name to the function type.

        Only the last statement of the body is checked against the annotated
        return type (or against None if it is not a return statement).
        """
        tfd = copy(node)
        assert not node.decorator_list, "Functions may not have decorators"
        self.enter_scope()
        tfd.args = self.visit(node.args)
        functyp = FunctionType(
            [t.typ for t in tfd.args.args],
            InstanceType(self.type_from_annotation(tfd.returns)),
        )
        tfd.typ = InstanceType(functyp)
        # We need the function type inside for recursion
        self.set_variable_type(node.name, tfd.typ)
        tfd.body = [self.visit(s) for s in node.body]
        # Check that return type and annotated return type match
        if not isinstance(node.body[-1], Return):
            assert (
                functyp.rettyp >= NoneInstanceType
            ), f"Function '{node.name}' has no return statement but is supposed to return not-None value"
        else:
            assert (
                functyp.rettyp >= tfd.body[-1].typ
            ), f"Function '{node.name}' annotated return type does not match actual return type"
        self.exit_scope()
        # We need the function type outside for usage
        self.set_variable_type(node.name, tfd.typ)
        return tfd

    def visit_Module(self, node: Module) -> TypedModule:
        """Type every top-level statement inside a dedicated module scope."""
        self.enter_scope()
        tm = copy(node)
        tm.body = [self.visit(n) for n in node.body]
        self.exit_scope()
        return tm

    def visit_Expr(self, node: Expr) -> TypedExpr:
        """Type an expression statement by typing its value."""
        tn = copy(node)
        tn.value = self.visit(node.value)
        return tn

    def visit_BinOp(self, node: BinOp) -> TypedBinOp:
        """Type a binary operation; both operands must share one type, which is the result."""
        tb = copy(node)
        tb.left = self.visit(node.left)
        tb.right = self.visit(node.right)
        # TODO the outcome of the operation may depend on the input types
        assert (
            tb.left.typ == tb.right.typ
        ), "Inputs to a binary operation need to have the same type"
        tb.typ = tb.left.typ
        return tb

    def visit_BoolOp(self, node: BoolOp) -> TypedBoolOp:
        """Type ``and``/``or`` chains; every operand must be a bool."""
        tt = copy(node)
        tt.values = [self.visit(e) for e in node.values]
        tt.typ = BoolInstanceType
        assert all(
            BoolInstanceType >= e.typ for e in tt.values
        ), "All values compared must be bools"
        return tt

    def visit_UnaryOp(self, node: UnaryOp) -> TypedUnaryOp:
        """Type a unary operation as having its operand's type."""
        tu = copy(node)
        tu.operand = self.visit(node.operand)
        tu.typ = tu.operand.typ
        return tu

    def visit_Subscript(self, node: Subscript) -> TypedSubscript:
        """Type indexing/slicing of tuples, pairs, lists, bytes and dicts.

        ``Union[...]``, ``Dict[...]`` and ``List[...]`` used as expressions are
        resolved as annotations instead.
        """
        ts = copy(node)
        # special case: Subscript of Union / Dict / List and atomic types
        if isinstance(ts.value, Name) and ts.value.id in [
            "Union",
            "Dict",
            "List",
        ]:
            ts.value = ts.typ = self.type_from_annotation(ts)
            return ts

        ts.value = self.visit(node.value)
        assert isinstance(ts.value.typ, InstanceType), "Can only subscript instances"
        if isinstance(ts.value.typ.typ, TupleType):
            assert (
                ts.value.typ.typ.typs
            ), "Accessing elements from the empty tuple is not allowed"
            if all(ts.value.typ.typ.typs[0] == t for t in ts.value.typ.typ.typs):
                ts.typ = ts.value.typ.typ.typs[0]
            elif isinstance(ts.slice, Constant) and isinstance(ts.slice.value, int):
                ts.typ = ts.value.typ.typ.typs[ts.slice.value]
            else:
                raise TypeInferenceError(
                    f"Could not infer type of subscript of typ {ts.value.typ.typ.__class__}"
                )
        elif isinstance(ts.value.typ.typ, PairType):
            if isinstance(ts.slice, Constant) and isinstance(ts.slice.value, int):
                ts.typ = (
                    ts.value.typ.typ.l_typ
                    if ts.slice.value == 0
                    else ts.value.typ.typ.r_typ
                )
            else:
                raise TypeInferenceError(
                    f"Could not infer type of subscript of typ {ts.value.typ.typ.__class__}"
                )
        elif isinstance(ts.value.typ.typ, ListType):
            ts.typ = ts.value.typ.typ.typ
            ts.slice = self.visit(node.slice)
            assert ts.slice.typ == IntegerInstanceType, "List indices must be integers"
        elif isinstance(ts.value.typ.typ, ByteStringType):
            if not isinstance(node.slice, Slice):
                ts.typ = IntegerInstanceType
                ts.slice = self.visit(node.slice)
                assert (
                    ts.slice.typ == IntegerInstanceType
                ), "bytes indices must be integers"
            else:
                ts.typ = ByteStringInstanceType
                # only lower/upper are type-checked and normalised below; a step
                # would pass through unchecked, so reject it explicitly
                assert (
                    node.slice.step is None
                ), "Slicing bytes with a step is not supported"
                # work on a copy so that the input AST is not mutated
                ts.slice = copy(node.slice)
                # b[:u] is treated as b[0:u]
                lower = (
                    node.slice.lower if node.slice.lower is not None else Constant(0)
                )
                ts.slice.lower = self.visit(lower)
                assert (
                    ts.slice.lower.typ == IntegerInstanceType
                ), "lower slice indices for bytes must be integers"
                # b[l:] is treated as b[l:len(b)]
                upper = node.slice.upper
                if upper is None:
                    upper = Call(
                        func=Name(id="len", ctx=Load()), args=[node.value], keywords=[]
                    )
                ts.slice.upper = self.visit(upper)
                assert (
                    ts.slice.upper.typ == IntegerInstanceType
                ), "upper slice indices for bytes must be integers"
        elif isinstance(ts.value.typ.typ, DictType):
            # TODO could be implemented with potentially just erroring. It might be desired to avoid this though.
            if not isinstance(ts.slice, Slice):
                ts.slice = self.visit(node.slice)
                assert (
                    ts.slice.typ == ts.value.typ.typ.key_typ
                ), f"Dict subscript must have dict key type {ts.value.typ.typ.key_typ} but has type {ts.slice.typ}"
                ts.typ = ts.value.typ.typ.value_typ
            else:
                raise TypeInferenceError(
                    "Could not infer type of subscript of dict with a slice."
                )
        else:
            raise TypeInferenceError(
                f"Could not infer type of subscript of typ {ts.value.typ.__class__}"
            )
        return ts

    def visit_Call(self, node: Call) -> TypedCall:
        """Type a call of a function, polymorphic builtin or class constructor.

        Arity must match and every declared parameter type must be ``>=`` the
        given argument type. The call is typed with the function's return type.
        """
        assert not node.keywords, "Keyword arguments are not supported yet"
        tc = copy(node)
        tc.args = [self.visit(a) for a in node.args]
        tc.func = self.visit(node.func)
        # might be a cast
        if isinstance(tc.func.typ, ClassType):
            tc.func.typ = tc.func.typ.constr_type()
        # type might only turn out after the initialization (note the constr could be polymorphic)
        if isinstance(tc.func.typ, InstanceType) and isinstance(
            tc.func.typ.typ, PolymorphicFunctionType
        ):
            tc.func.typ = PolymorphicFunctionInstanceType(
                tc.func.typ.typ.polymorphic_function.type_from_args(
                    [a.typ for a in tc.args]
                ),
                tc.func.typ.typ.polymorphic_function,
            )
        if isinstance(tc.func.typ, InstanceType) and isinstance(
            tc.func.typ.typ, FunctionType
        ):
            functyp = tc.func.typ.typ
            assert len(tc.args) == len(
                functyp.argtyps
            ), f"Signature of function does not match number of arguments. Expected {len(functyp.argtyps)} arguments with these types: {functyp.argtyps}"
            # all arguments need to be supertypes of the given type
            assert all(
                ap >= a.typ for a, ap in zip(tc.args, functyp.argtyps)
            ), f"Signature of function does not match arguments. Expected {len(functyp.argtyps)} arguments with these types: {functyp.argtyps}"
            tc.typ = functyp.rettyp
            return tc
        raise TypeInferenceError("Could not infer type of call")

    def visit_Pass(self, node: Pass) -> TypedPass:
        """``pass`` carries no type information."""
        tp = copy(node)
        return tp

    def visit_Return(self, node: Return) -> TypedReturn:
        """Type a return statement with the type of its value."""
        tp = copy(node)
        tp.value = self.visit(node.value)
        tp.typ = tp.value.typ
        return tp

    def visit_Attribute(self, node: Attribute) -> TypedAttribute:
        """Type attribute access by asking the owner's type for the attribute's type."""
        tp = copy(node)
        tp.value = self.visit(node.value)
        owner = tp.value.typ
        # accesses to field
        tp.typ = owner.attribute_type(node.attr)
        return tp

    def visit_Assert(self, node: Assert) -> TypedAssert:
        """Type an assert; the test must be a bool and the optional message a str."""
        ta = copy(node)
        ta.test = self.visit(node.test)
        assert (
            ta.test.typ == BoolInstanceType
        ), "Assertions must result in a boolean type"
        if ta.msg is not None:
            ta.msg = self.visit(node.msg)
            assert (
                ta.msg.typ == StringInstanceType
            ), "Assertions must has a string message (or None)"
        return ta

    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> RawPlutoExpr:
        """Raw Pluto expressions carry their own type and are returned unchanged."""
        assert node.typ is not None, "Raw Pluto Expression is missing type annotation"
        return node

    def visit_IfExp(self, node: IfExp) -> TypedIfExp:
        """Type ``a if c else b`` with the broader of the two branch types."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        assert node_cp.test.typ == BoolInstanceType, "Comparison must have type boolean"
        node_cp.body = self.visit(node.body)
        node_cp.orelse = self.visit(node.orelse)
        if node_cp.body.typ >= node_cp.orelse.typ:
            node_cp.typ = node_cp.body.typ
        elif node_cp.orelse.typ >= node_cp.body.typ:
            node_cp.typ = node_cp.orelse.typ
        else:
            raise TypeInferenceError(
                "Branches of if-expression must return compatible types"
            )
        return node_cp

    def visit_comprehension(self, g: comprehension) -> typedcomprehension:
        """Type one ``for ... in ... if ...`` clause, binding its loop variable."""
        new_g = copy(g)
        if isinstance(g.target, Tuple):
            raise NotImplementedError(
                "Type deconstruction in for loops is not supported yet"
            )
        new_g.iter = self.visit(g.iter)
        vartyp = self._iteration_element_type(new_g.iter.typ)
        self.set_variable_type(g.target.id, vartyp)
        new_g.target = self.visit(g.target)
        new_g.ifs = [self.visit(i) for i in g.ifs]
        return new_g

    def visit_ListComp(self, node: ListComp) -> TypedListComp:
        """Type a list comprehension as a list of its element expression's type."""
        typed_listcomp = copy(node)
        # inside the comprehension is a seperate scope
        self.enter_scope()
        # first evaluate generators for assigned variables
        typed_listcomp.generators = [self.visit(s) for s in node.generators]
        # then evaluate elements
        typed_listcomp.elt = self.visit(node.elt)
        self.exit_scope()
        typed_listcomp.typ = InstanceType(ListType(typed_listcomp.elt.typ))
        return typed_listcomp

    def generic_visit(self, node: AST) -> TypedAST:
        """Reject every node type without a dedicated visitor."""
        raise NotImplementedError(
            f"Cannot infer type of non-implemented node {node.__class__}"
        )
615

616

617
class RecordReader(NodeVisitor):
    """Collect the name, constructor id and typed fields of a PlutusData class.

    The class body may only contain:
    - annotated attributes without default values;
    - ``CONSTR_ID`` assignments, annotated or not;
    - ``pass``;
    - bare constant expressions (docstrings).

    Anything else raises NotImplementedError.
    """

    name: str
    constructor: int
    attributes: typing.List[typing.Tuple[str, Type]]
    _type_inferencer: AggressiveTypeInferencer

    def __init__(self, type_inferencer: AggressiveTypeInferencer):
        # the constructor id defaults to 0 unless CONSTR_ID is assigned
        self.constructor = 0
        self.attributes = []
        # used to resolve attribute annotations into types
        self._type_inferencer = type_inferencer

    @classmethod
    def extract(cls, c: ClassDef, type_inferencer: AggressiveTypeInferencer) -> Record:
        """Read class definition ``c`` and return the corresponding Record."""
        f = cls(type_inferencer)
        f.visit(c)
        return Record(f.name, f.constructor, FrozenFrozenList(f.attributes))

    def _set_constructor(self, value: expr) -> None:
        """Validate that ``value`` is a constant int and store it as the constructor id."""
        assert isinstance(
            value, Constant
        ), "CONSTR_ID must be assigned a constant integer"
        assert isinstance(value.value, int), "CONSTR_ID must be assigned an integer"
        self.constructor = value.value

    def visit_AnnAssign(self, node: AnnAssign) -> None:
        """Record an annotated field, or read an annotated ``CONSTR_ID: int = n``."""
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        typ = self._type_inferencer.type_from_annotation(node.annotation)
        if node.target.id != "CONSTR_ID":
            assert (
                node.value is None
            ), f"PlutusData attribute {node.target.id} may not have a default value"
            assert not isinstance(
                typ, TupleType
            ), "Records can currently not hold tuples"
            self.attributes.append(
                (
                    node.target.id,
                    InstanceType(typ),
                )
            )
            return
        # type_from_annotation returns type *instances* (e.g. IntegerType()),
        # so compare with isinstance rather than against the class object
        assert isinstance(typ, IntegerType), "CONSTR_ID must be assigned an integer"
        self._set_constructor(node.value)

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Remember the class name and read every statement of the body."""
        self.name = node.name
        for s in node.body:
            self.visit(s)

    def visit_Pass(self, node: Pass) -> None:
        """``pass`` is allowed and ignored."""
        pass

    def visit_Assign(self, node: Assign) -> None:
        """Read an unannotated ``CONSTR_ID = n`` assignment."""
        assert len(node.targets) == 1, "Record elements must be assigned one by one"
        target = node.targets[0]
        assert isinstance(target, Name), "Record elements must have named attributes"
        assert (
            target.id == "CONSTR_ID"
        ), "Type annotations may only be omitted for CONSTR_ID"
        self._set_constructor(node.value)

    def visit_Expr(self, node: Expr) -> None:
        """Allow bare constant expressions (docstrings) and ignore them."""
        assert isinstance(
            node.value, Constant
        ), "Only comments are allowed inside classes"
        return None

    def generic_visit(self, node: AST) -> None:
        """Reject every other statement inside a class body."""
        raise NotImplementedError(f"Can not compile {ast.dump(node)} inside of a class")
693

694

695
def typed_ast(ast: AST):
    """Run aggressive static type inference over ``ast`` and return the typed tree.

    A fresh AggressiveTypeInferencer is created for every call.
    """
    inferencer = AggressiveTypeInferencer()
    typed_tree = inferencer.visit(ast)
    return typed_tree
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc