• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 695

pending completion
695

Pull #126

travis-ci-com

web-flow
Merge 0f2304891 into 137444d4e
Pull Request #126: Fix upcasting by using annotated assignment

15 of 15 new or added lines in 3 files covered. (100.0%)

3276 of 3548 relevant lines covered (92.33%)

3.69 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.03
/opshin/type_inference.py
1
from copy import copy
4✔
2
import ast
4✔
3

4
from .typed_ast import *
4✔
5
from .util import PythonBuiltInTypes, CompilingNodeTransformer
4✔
6

7
"""
An aggressive type inference based on the work of Aycock [1].
It only allows a subset of legal python operations which
allow us to infer the type of all involved variables
statically.
Using this we can resolve overloaded functions when translating Python
into UPLC where there is no dynamic type checking.
Additionally, this conveniently implements an additional layer of
security into the Smart Contract by checking type correctness.


[1]: https://legacy.python.org/workshops/2000-01/proceedings/papers/aycock/aycock.html
"""


# Bindings visible before any user code is visited: the names usable as class
# annotations, mapped to their class types.
INITIAL_SCOPE = {
    # class annotations
    "bytes": ByteStringType(),
    "int": IntegerType(),
    "bool": BoolType(),
    "str": StringType(),
    "Anything": AnyType(),
}

# Additionally expose every builtin whose type wraps a PolymorphicFunctionType,
# keyed by the builtin's `.name` (the keys of PythonBuiltInTypes provide `.name`).
INITIAL_SCOPE.update(
    (builtin.name, builtin_typ)
    for builtin, builtin_typ in PythonBuiltInTypes.items()
    if isinstance(builtin_typ.typ, PolymorphicFunctionType)
)
43

44

45
class AggressiveTypeInferencer(CompilingNodeTransformer):
    """Statically infer the type of every expression in a Python AST.

    Every ``visit_*`` method returns a shallow copy of the visited node annotated
    with a ``typ`` attribute (class definitions get ``class_typ`` instead), so the
    resulting tree can be compiled without dynamic type checks. Constructs that
    cannot be typed raise ``TypeInferenceError`` or ``NotImplementedError``, or
    fail an ``assert`` with a descriptive message.
    """

    step = "Static Type Inference"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A stack of dictionaries for storing scoped knowledge of variable types.
        # It is created per instance: a class-level list would be shared by every
        # inferencer, so scopes left on the stack by an aborted run (an exception
        # skips exit_scope) would leak into later runs. The builtin scope is
        # copied so that bindings written while no scope has been entered cannot
        # modify the module-level INITIAL_SCOPE.
        self.scopes = [dict(INITIAL_SCOPE)]

    def variable_type(self, name: str) -> Type:
        """Return the type bound to ``name`` by the innermost scope defining it.

        :raises TypeInferenceError: if no scope on the stack defines ``name``.
        """
        for scope in reversed(self.scopes):
            if name in scope:
                return scope[name]
        raise TypeInferenceError(f"Variable {name} not initialized at access")

    def enter_scope(self):
        """Push a new, empty innermost scope."""
        self.scopes.append({})

    def exit_scope(self):
        """Pop and discard the innermost scope."""
        self.scopes.pop()

    def set_variable_type(self, name: str, typ: Type, force=False):
        """Bind ``name`` to ``typ`` in the innermost scope.

        Unless ``force`` is set, re-binding a name that already has a different
        type in the innermost scope is only tolerated when the existing type is
        broader (``existing >= typ``); the broader type is then kept.

        :raises TypeInferenceError: if the existing binding is incompatible.
        """
        current = self.scopes[-1]
        if not force and name in current and current[name] != typ:
            if current[name] >= typ:
                # the specified type is broader, we pass on this
                return
            raise TypeInferenceError(
                f"Type {current[name]} of variable {name} in local scope does not match inferred type {typ}"
            )
        current[name] = typ

    def type_from_annotation(self, ann: expr):
        """Translate an annotation expression into a class type.

        Supported annotations:

        * ``None`` -> ``UnitType()``
        * names in ``ATOMIC_TYPES`` and names bound to a ``ClassType``
        * ``Union[A, B, ...]`` of PlutusData records with distinct constructors
        * ``List[T]``, ``Dict[K, V]`` and ``Tuple[A, B, ...]``
        * a missing annotation (``ann is None``) -> ``AnyType()``

        :raises TypeInferenceError: if a name is unbound or not bound to a class.
        :raises NotImplementedError: for any other annotation form.
        """
        if isinstance(ann, Constant):
            if ann.value is None:
                return UnitType()
        if isinstance(ann, Name):
            if ann.id in ATOMIC_TYPES:
                return ATOMIC_TYPES[ann.id]
            v_t = self.variable_type(ann.id)
            if isinstance(v_t, ClassType):
                return v_t
            raise TypeInferenceError(
                f"Class name {ann.id} not initialized before annotating variable"
            )
        if isinstance(ann, Subscript):
            assert isinstance(
                ann.value, Name
            ), "Only Union, Dict, List and Tuple are allowed as Generic types"
            if ann.value.id == "Union":
                assert isinstance(
                    ann.slice, Tuple
                ), "Union must combine multiple classes"
                ann_types = [self.type_from_annotation(e) for e in ann.slice.elts]
                assert all(
                    isinstance(e, RecordType) for e in ann_types
                ), "Union must combine multiple PlutusData classes"
                assert distinct(
                    [e.record.constructor for e in ann_types]
                ), "Union must combine PlutusData classes with unique constructors"
                return UnionType(FrozenFrozenList(ann_types))
            if ann.value.id == "List":
                ann_type = self.type_from_annotation(ann.slice)
                assert isinstance(
                    ann_type, ClassType
                ), "List must have a single type as parameter"
                assert not isinstance(
                    ann_type, TupleType
                ), "List can currently not hold tuples"
                return ListType(InstanceType(ann_type))
            if ann.value.id == "Dict":
                assert isinstance(ann.slice, Tuple), "Dict must combine two classes"
                assert len(ann.slice.elts) == 2, "Dict must combine two classes"
                ann_types = self.type_from_annotation(
                    ann.slice.elts[0]
                ), self.type_from_annotation(ann.slice.elts[1])
                assert all(
                    isinstance(e, ClassType) for e in ann_types
                ), "Dict must combine two classes"
                assert not any(
                    isinstance(e, TupleType) for e in ann_types
                ), "Dict can currently not hold tuples"
                return DictType(*(InstanceType(a) for a in ann_types))
            if ann.value.id == "Tuple":
                assert isinstance(
                    ann.slice, Tuple
                ), "Tuple must combine several classes"
                ann_types = [self.type_from_annotation(e) for e in ann.slice.elts]
                assert all(
                    isinstance(e, ClassType) for e in ann_types
                ), "Tuple must combine classes"
                return TupleType(FrozenFrozenList([InstanceType(a) for a in ann_types]))
            raise NotImplementedError(
                "Only Union, Dict, List and Tuple are allowed as Generic types"
            )
        if ann is None:
            return AnyType()
        raise NotImplementedError(f"Annotation type {ann.__class__} is not supported")

    def visit_ClassDef(self, node: ClassDef) -> TypedClassDef:
        """Register a PlutusData class definition and attach its ``RecordType``."""
        class_record = RecordReader.extract(node, self)
        typ = RecordType(class_record)
        self.set_variable_type(node.name, typ)
        typed_node = copy(node)
        typed_node.class_typ = typ
        return typed_node

    def visit_Constant(self, node: Constant) -> TypedConstant:
        """Type a literal via ``ATOMIC_TYPES``; ``None`` gets ``NoneInstanceType``."""
        tc = copy(node)
        assert type(node.value) not in [
            float,
            complex,
            type(...),
        ], "Float, complex numbers and ellipsis currently not supported"
        if tc.value is None:
            tc.typ = NoneInstanceType
        else:
            tc.typ = InstanceType(ATOMIC_TYPES[type(node.value).__name__])
        return tc

    def visit_Tuple(self, node: Tuple) -> TypedTuple:
        """Type a tuple literal with the element-wise types of its entries."""
        tt = copy(node)
        tt.elts = [self.visit(e) for e in node.elts]
        # Same FrozenFrozenList representation as Tuple[...] annotations produce
        tt.typ = InstanceType(TupleType(FrozenFrozenList([e.typ for e in tt.elts])))
        return tt

    def visit_List(self, node: List) -> TypedList:
        """Type a list literal; every element must fit the first element's type.

        :raises TypeInferenceError: for the empty list, whose element type is unknown.
        """
        tt = copy(node)
        tt.elts = [self.visit(e) for e in node.elts]
        if not tt.elts:
            raise TypeInferenceError("Can not infer the element type of an empty list")
        l_typ = tt.elts[0].typ
        assert all(
            l_typ >= e.typ for e in tt.elts
        ), "All elements of a list must have the same type"
        tt.typ = InstanceType(ListType(l_typ))
        return tt

    def visit_Dict(self, node: Dict) -> TypedDict:
        """Type a dict literal; keys/values must fit the first key/value's type.

        :raises TypeInferenceError: for the empty dict, whose types are unknown.
        """
        tt = copy(node)
        tt.keys = [self.visit(k) for k in node.keys]
        tt.values = [self.visit(v) for v in node.values]
        if not tt.keys:
            raise TypeInferenceError(
                "Can not infer the key and value types of an empty dict"
            )
        k_typ = tt.keys[0].typ
        assert all(k_typ >= k.typ for k in tt.keys), "All keys must have the same type"
        v_typ = tt.values[0].typ
        assert all(
            v_typ >= v.typ for v in tt.values
        ), "All values must have the same type"
        tt.typ = InstanceType(DictType(k_typ, v_typ))
        return tt

    def visit_Assign(self, node: Assign) -> TypedAssign:
        """Type an assignment; every target must be a plain variable name."""
        typed_ass = copy(node)
        typed_ass.value: TypedExpression = self.visit(node.value)
        # Make sure to first set the type of each target name so we can load it when visiting it
        for t in node.targets:
            assert isinstance(
                t, Name
            ), "Can only assign to variable names, no type deconstruction"
            self.set_variable_type(t.id, typed_ass.value.typ)
        typed_ass.targets = [self.visit(t) for t in node.targets]
        return typed_ass

    def visit_AnnAssign(self, node: AnnAssign) -> TypedAnnAssign:
        """Type an annotated assignment, which may up- or down-cast the value.

        The target is bound (forcefully) to the annotated type; the value's type
        must be related to it in either direction.
        """
        typed_ass = copy(node)
        typed_ass.value: TypedExpression = self.visit(node.value)
        typed_ass.annotation = self.type_from_annotation(node.annotation)
        assert isinstance(
            node.target, Name
        ), "Can only assign to variable names, no type deconstruction"
        self.set_variable_type(
            node.target.id, InstanceType(typed_ass.annotation), force=True
        )
        typed_ass.target = self.visit(node.target)
        assert (
            typed_ass.value.typ >= InstanceType(typed_ass.annotation)
            or InstanceType(typed_ass.annotation) >= typed_ass.value.typ
        ), "Can only cast between related types"
        return typed_ass

    def visit_If(self, node: If) -> TypedIf:
        """Type an ``if`` statement.

        ``if isinstance(x, C):`` with ``x`` an instance of a Union of PlutusData
        classes is rewritten into ``x.CONSTR_ID == <constructor of C>`` and ``x``
        is narrowed to ``C`` while typing the body. Any other condition must be
        a boolean.
        """
        typed_if = copy(node)
        if (
            isinstance(typed_if.test, Call)
            and isinstance(typed_if.test.func, Name)
            and typed_if.test.func.id == "isinstance"
        ):
            tc = typed_if.test
            # special case for Union
            assert isinstance(
                tc.args[0], Name
            ), "Target 0 of an isinstance cast must be a variable name"
            assert isinstance(
                tc.args[1], Name
            ), "Target 1 of an isinstance cast must be a class name"
            target_class: RecordType = self.variable_type(tc.args[1].id)
            target_inst = self.visit(tc.args[0])
            target_inst_class = target_inst.typ
            assert isinstance(
                target_inst_class, InstanceType
            ), "Can only cast instances, not classes"
            assert isinstance(
                target_inst_class.typ, UnionType
            ), "Can only cast instances of Union types of PlutusData"
            assert isinstance(target_class, RecordType), "Can only cast to PlutusData"
            assert (
                target_class in target_inst_class.typ.typs
            ), f"Trying to cast an instance of Union type to non-instance of union type"
            typed_if.test = self.visit(
                Compare(
                    left=Attribute(tc.args[0], "CONSTR_ID"),
                    ops=[Eq()],
                    comparators=[Constant(target_class.record.constructor)],
                )
            )
            # for the time of this if branch set the variable type to the specialized type
            self.set_variable_type(
                tc.args[0].id, InstanceType(target_class), force=True
            )
            typed_if.body = [self.visit(s) for s in node.body]
            # restore the union type so the else branch (and later code) sees it
            self.set_variable_type(tc.args[0].id, target_inst_class, force=True)
        else:
            typed_if.test = self.visit(node.test)
            assert (
                typed_if.test.typ == BoolInstanceType
            ), "Branching condition must have boolean type"
            typed_if.body = [self.visit(s) for s in node.body]
        typed_if.orelse = [self.visit(s) for s in node.orelse]
        return typed_if

    def visit_While(self, node: While) -> TypedWhile:
        """Type a ``while`` loop; its condition must be a boolean."""
        typed_while = copy(node)
        typed_while.test = self.visit(node.test)
        assert (
            typed_while.test.typ == BoolInstanceType
        ), "Branching condition must have boolean type"
        typed_while.body = [self.visit(s) for s in node.body]
        typed_while.orelse = [self.visit(s) for s in node.orelse]
        return typed_while

    def _iteration_type(self, itertyp: Type) -> Type:
        """Return the element type obtained by iterating over a value of ``itertyp``.

        Lists yield their element type; tuples are only iterable when non-empty
        and homogeneous.

        :raises NotImplementedError: for any other iterable type.
        """
        assert isinstance(
            itertyp, InstanceType
        ), "Can only iterate over instances, not classes"
        if isinstance(itertyp.typ, TupleType):
            assert itertyp.typ.typs, "Iterating over an empty tuple is not allowed"
            vartyp = itertyp.typ.typs[0]
            assert all(
                vartyp == t for t in itertyp.typ.typs
            ), "Iterating through a tuple requires the same type for each element"
            return vartyp
        if isinstance(itertyp.typ, ListType):
            return itertyp.typ.typ
        raise NotImplementedError(
            "Type inference for loops over non-list objects is not supported"
        )

    def visit_For(self, node: For) -> TypedFor:
        """Type a ``for`` loop over a list or homogeneous tuple."""
        typed_for = copy(node)
        typed_for.iter = self.visit(node.iter)
        if isinstance(node.target, Tuple):
            raise NotImplementedError(
                "Type deconstruction in for loops is not supported yet"
            )
        vartyp = self._iteration_type(typed_for.iter.typ)
        self.set_variable_type(node.target.id, vartyp)
        typed_for.target = self.visit(node.target)
        typed_for.body = [self.visit(s) for s in node.body]
        typed_for.orelse = [self.visit(s) for s in node.orelse]
        return typed_for

    def visit_Name(self, node: Name) -> TypedName:
        """Type a variable reference with its currently bound type."""
        tn = copy(node)
        # Make sure that the rhs of an assign is evaluated first
        tn.typ = self.variable_type(node.id)
        return tn

    def visit_Compare(self, node: Compare) -> TypedCompare:
        """Type a comparison as boolean; operand types are checked at compile time."""
        typed_cmp = copy(node)
        typed_cmp.left = self.visit(node.left)
        typed_cmp.comparators = [self.visit(s) for s in node.comparators]
        typed_cmp.typ = BoolInstanceType
        # the actual required types are being taken care of in the implementation
        return typed_cmp

    def visit_arg(self, node: arg) -> typedarg:
        """Type a function argument from its annotation and bind it in scope."""
        ta = copy(node)
        ta.typ = InstanceType(self.type_from_annotation(node.annotation))
        self.set_variable_type(ta.arg, ta.typ)
        return ta

    def visit_arguments(self, node: arguments) -> typedarguments:
        """Type positional arguments; keyword arguments and defaults are rejected."""
        if node.kw_defaults or node.kwarg or node.kwonlyargs or node.defaults:
            raise NotImplementedError(
                "Keyword arguments and defaults not supported yet"
            )
        ta = copy(node)
        ta.args = [self.visit(a) for a in node.args]
        return ta

    def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
        """Type a function definition in its own scope.

        The function is bound inside its scope (enabling recursion) and again in
        the enclosing scope afterwards. If the last statement is a ``return``,
        its type must fit the annotated return type; otherwise the function must
        be annotated to return ``None``.
        """
        tfd = copy(node)
        assert not node.decorator_list, "Functions may not have decorators"
        self.enter_scope()
        tfd.args = self.visit(node.args)
        functyp = FunctionType(
            [t.typ for t in tfd.args.args],
            InstanceType(self.type_from_annotation(tfd.returns)),
        )
        tfd.typ = InstanceType(functyp)
        # We need the function type inside for recursion
        self.set_variable_type(node.name, tfd.typ)
        tfd.body = [self.visit(s) for s in node.body]
        # Check that return type and annotated return type match
        if not isinstance(node.body[-1], Return):
            assert (
                functyp.rettyp == NoneInstanceType
            ), f"Function '{node.name}' has no return statement but is supposed to return not-None value"
        else:
            assert (
                functyp.rettyp >= tfd.body[-1].typ
            ), f"Function '{node.name}' annotated return type does not match actual return type"
        self.exit_scope()
        # We need the function type outside for usage
        self.set_variable_type(node.name, tfd.typ)
        return tfd

    def visit_Module(self, node: Module) -> TypedModule:
        """Type every top-level statement inside a fresh module scope."""
        self.enter_scope()
        tm = copy(node)
        tm.body = [self.visit(n) for n in node.body]
        self.exit_scope()
        return tm

    def visit_Expr(self, node: Expr) -> TypedExpr:
        """Type an expression statement."""
        tn = copy(node)
        tn.value = self.visit(node.value)
        return tn

    def visit_BinOp(self, node: BinOp) -> TypedBinOp:
        """Type a binary operation; both operands must have the same type."""
        tb = copy(node)
        tb.left = self.visit(node.left)
        tb.right = self.visit(node.right)
        # TODO the outcome of the operation may depend on the input types
        assert (
            tb.left.typ == tb.right.typ
        ), "Inputs to a binary operation need to have the same type"
        tb.typ = tb.left.typ
        return tb

    def visit_BoolOp(self, node: BoolOp) -> TypedBoolOp:
        """Type ``and``/``or``; every operand must be a boolean."""
        tt = copy(node)
        tt.values = [self.visit(e) for e in node.values]
        tt.typ = BoolInstanceType
        assert all(
            BoolInstanceType >= e.typ for e in tt.values
        ), "All values compared must be bools"
        return tt

    def visit_UnaryOp(self, node: UnaryOp) -> TypedUnaryOp:
        """Type a unary operation with the type of its operand."""
        tu = copy(node)
        tu.operand = self.visit(node.operand)
        tu.typ = tu.operand.typ
        return tu

    def visit_Subscript(self, node: Subscript) -> TypedSubscript:
        """Type a subscript expression ``value[slice]``.

        ``Union[...]``, ``Dict[...]`` and ``List[...]`` are type expressions and
        are resolved via :meth:`type_from_annotation`. Otherwise the subscripted
        value must be an instance of a tuple, pair, list or bytes type.
        """
        ts = copy(node)
        # special case: Subscript of Union / Dict / List and atomic types
        if isinstance(ts.value, Name) and ts.value.id in [
            "Union",
            "Dict",
            "List",
        ]:
            ts.value = ts.typ = self.type_from_annotation(ts)
            return ts

        ts.value = self.visit(node.value)
        assert isinstance(ts.value.typ, InstanceType), "Can only subscript instances"
        value_typ = ts.value.typ.typ
        if isinstance(value_typ, TupleType):
            assert (
                value_typ.typs
            ), "Accessing elements from the empty tuple is not allowed"
            if all(value_typ.typs[0] == t for t in value_typ.typs):
                ts.typ = value_typ.typs[0]
            elif isinstance(ts.slice, Constant) and isinstance(ts.slice.value, int):
                ts.typ = value_typ.typs[ts.slice.value]
            else:
                raise TypeInferenceError(
                    f"Could not infer type of subscript of typ {value_typ.__class__}"
                )
        elif isinstance(value_typ, PairType):
            if isinstance(ts.slice, Constant) and isinstance(ts.slice.value, int):
                ts.typ = value_typ.l_typ if ts.slice.value == 0 else value_typ.r_typ
            else:
                raise TypeInferenceError(
                    f"Could not infer type of subscript of typ {value_typ.__class__}"
                )
        elif isinstance(value_typ, ListType):
            ts.typ = value_typ.typ
            ts.slice = self.visit(node.slice)
            assert ts.slice.typ == IntegerInstanceType, "List indices must be integers"
        elif isinstance(value_typ, ByteStringType):
            if isinstance(node.slice, Slice):
                ts.typ = ByteStringInstanceType
                # Work on a copy of the slice: the shallow copy of the subscript
                # still shares it with the untyped input tree, which must not be
                # modified. Missing bounds default to 0 and len(value).
                ts.slice = copy(node.slice)
                lower = node.slice.lower
                if lower is None:
                    lower = Constant(0)
                ts.slice.lower = self.visit(lower)
                assert (
                    ts.slice.lower.typ == IntegerInstanceType
                ), "lower slice indices for bytes must be integers"
                upper = node.slice.upper
                if upper is None:
                    upper = Call(
                        func=Name(id="len", ctx=Load()), args=[node.value], keywords=[]
                    )
                ts.slice.upper = self.visit(upper)
                assert (
                    ts.slice.upper.typ == IntegerInstanceType
                ), "upper slice indices for bytes must be integers"
            else:
                ts.typ = IntegerInstanceType
                ts.slice = self.visit(node.slice)
                assert (
                    ts.slice.typ == IntegerInstanceType
                ), "bytes indices must be integers"
        elif isinstance(value_typ, DictType):
            # TODO could be implemented with potentially just erroring. It might be desired to avoid this though.
            raise TypeInferenceError(
                f"Could not infer type of subscript of dict. Use 'get' with a default value instead."
            )
        else:
            raise TypeInferenceError(
                f"Could not infer type of subscript of typ {value_typ.__class__}"
            )
        return ts

    def visit_Call(self, node: Call) -> TypedCall:
        """Type a call of a function, a class constructor or a polymorphic builtin.

        Every argument type must fit (``>=``) the corresponding parameter type.

        :raises TypeInferenceError: if the callee is not a function.
        """
        assert not node.keywords, "Keyword arguments are not supported yet"
        tc = copy(node)
        tc.args = [self.visit(a) for a in node.args]
        tc.func = self.visit(node.func)
        # might be a cast
        if isinstance(tc.func.typ, ClassType):
            tc.func.typ = tc.func.typ.constr_type()
        # type might only turn out after the initialization (note the constr could be polymorphic)
        if isinstance(tc.func.typ, InstanceType) and isinstance(
            tc.func.typ.typ, PolymorphicFunctionType
        ):
            tc.func.typ = PolymorphicFunctionInstanceType(
                tc.func.typ.typ.polymorphic_function.type_from_args(
                    [a.typ for a in tc.args]
                ),
                tc.func.typ.typ.polymorphic_function,
            )
        if isinstance(tc.func.typ, InstanceType) and isinstance(
            tc.func.typ.typ, FunctionType
        ):
            functyp = tc.func.typ.typ
            assert len(tc.args) == len(
                functyp.argtyps
            ), f"Signature of function does not match number of arguments. Expected {len(functyp.argtyps)} arguments with these types: {functyp.argtyps}"
            # all arguments need to be supertypes of the given type
            assert all(
                ap >= a.typ for a, ap in zip(tc.args, functyp.argtyps)
            ), f"Signature of function does not match arguments. Expected {len(functyp.argtyps)} arguments with these types: {functyp.argtyps}"
            tc.typ = functyp.rettyp
            return tc
        raise TypeInferenceError("Could not infer type of call")

    def visit_Pass(self, node: Pass) -> TypedPass:
        """``pass`` carries no type; return an unchanged copy."""
        tp = copy(node)
        return tp

    def visit_Return(self, node: Return) -> TypedReturn:
        """Type a ``return`` statement with the type of its value."""
        tp = copy(node)
        tp.value = self.visit(node.value)
        tp.typ = tp.value.typ
        return tp

    def visit_Attribute(self, node: Attribute) -> TypedAttribute:
        """Type an attribute access via the owner type's ``attribute_type``."""
        tp = copy(node)
        tp.value = self.visit(node.value)
        owner = tp.value.typ
        # accesses to field
        tp.typ = owner.attribute_type(node.attr)
        return tp

    def visit_Assert(self, node: Assert) -> TypedAssert:
        """Type an ``assert``; the test must be a boolean, the message a string."""
        ta = copy(node)
        ta.test = self.visit(node.test)
        assert (
            ta.test.typ == BoolInstanceType
        ), "Assertions must result in a boolean type"
        if ta.msg is not None:
            ta.msg = self.visit(node.msg)
            assert (
                ta.msg.typ == StringInstanceType
            ), "Assertions must has a string message (or None)"
        return ta

    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> RawPlutoExpr:
        """Raw Pluto expressions must already carry their type; returned as-is."""
        assert node.typ is not None, "Raw Pluto Expression is missing type annotation"
        return node

    def visit_IfExp(self, node: IfExp) -> TypedIfExp:
        """Type a conditional expression with the broader of its branch types.

        :raises TypeInferenceError: if neither branch type contains the other.
        """
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        assert node_cp.test.typ == BoolInstanceType, "Comparison must have type boolean"
        node_cp.body = self.visit(node.body)
        node_cp.orelse = self.visit(node.orelse)
        if node_cp.body.typ >= node_cp.orelse.typ:
            node_cp.typ = node_cp.body.typ
        elif node_cp.orelse.typ >= node_cp.body.typ:
            node_cp.typ = node_cp.orelse.typ
        else:
            raise TypeInferenceError(
                "Branches of if-expression must return compatible types"
            )
        return node_cp

    def visit_comprehension(self, g: comprehension) -> typedcomprehension:
        """Type one ``for ... in ... if ...`` clause and bind its loop variable."""
        new_g = copy(g)
        if isinstance(g.target, Tuple):
            raise NotImplementedError(
                "Type deconstruction in for loops is not supported yet"
            )
        new_g.iter = self.visit(g.iter)
        vartyp = self._iteration_type(new_g.iter.typ)
        self.set_variable_type(g.target.id, vartyp)
        new_g.target = self.visit(g.target)
        new_g.ifs = [self.visit(i) for i in g.ifs]
        return new_g

    def visit_ListComp(self, node: ListComp) -> TypedListComp:
        """Type a list comprehension in its own scope."""
        typed_listcomp = copy(node)
        # inside the comprehension is a separate scope
        self.enter_scope()
        # first evaluate generators for assigned variables
        typed_listcomp.generators = [self.visit(s) for s in node.generators]
        # then evaluate elements
        typed_listcomp.elt = self.visit(node.elt)
        self.exit_scope()
        typed_listcomp.typ = InstanceType(ListType(typed_listcomp.elt.typ))
        return typed_listcomp

    def generic_visit(self, node: AST) -> TypedAST:
        """Reject every node type without a dedicated ``visit_*`` method."""
        raise NotImplementedError(
            f"Cannot infer type of non-implemented node {node.__class__}"
        )
611

612

613
class RecordReader(NodeVisitor):
    """Read the record layout of a PlutusData class definition.

    Use :meth:`extract` to turn a ``ClassDef`` into a ``Record``. The record
    holds the class name, its constructor id (``CONSTR_ID``, default 0) and its
    annotated fields in declaration order. Only the following are accepted in
    the class body:

    * annotated fields without default values
    * the ``CONSTR_ID`` assignment, with or without an ``int`` annotation
    * ``pass``
    * constant expressions such as docstrings
    """

    # name of the visited class (set by visit_ClassDef)
    name: str
    # constructor id from CONSTR_ID; 0 if the class does not set one
    constructor: int
    # (field name, instance type) pairs in declaration order
    attributes: typing.List[typing.Tuple[str, Type]]
    # inferencer used to resolve field annotations
    _type_inferencer: AggressiveTypeInferencer

    def __init__(self, type_inferencer: AggressiveTypeInferencer):
        self.constructor = 0
        self.attributes = []
        self._type_inferencer = type_inferencer

    @classmethod
    def extract(cls, c: ClassDef, type_inferencer: AggressiveTypeInferencer) -> Record:
        """Return the ``Record`` described by the class definition ``c``."""
        f = cls(type_inferencer)
        f.visit(c)
        return Record(f.name, f.constructor, FrozenFrozenList(f.attributes))

    def _set_constructor(self, value) -> None:
        """Store the value assigned to ``CONSTR_ID``, which must be an int literal."""
        assert isinstance(
            value, Constant
        ), "CONSTR_ID must be assigned a constant integer"
        assert isinstance(value.value, int), "CONSTR_ID must be assigned an integer"
        self.constructor = value.value

    def visit_AnnAssign(self, node: AnnAssign) -> None:
        """Record an annotated field, or read an annotated ``CONSTR_ID: int = n``."""
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        typ = self._type_inferencer.type_from_annotation(node.annotation)
        if node.target.id != "CONSTR_ID":
            assert (
                node.value is None
            ), f"PlutusData attribute {node.target.id} may not have a default value"
            assert not isinstance(
                typ, TupleType
            ), "Records can currently not hold tuples"
            self.attributes.append(
                (
                    node.target.id,
                    InstanceType(typ),
                )
            )
            return
        # type_from_annotation returns type *instances* (the same kind of value
        # INITIAL_SCOPE binds "int" to), so compare against an instance:
        # comparing with the IntegerType class itself can never succeed.
        assert typ == IntegerType(), "CONSTR_ID must be assigned an integer"
        self._set_constructor(node.value)

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Remember the class name and visit every statement of its body."""
        self.name = node.name
        for s in node.body:
            self.visit(s)

    def visit_Pass(self, node: Pass) -> None:
        """``pass`` contributes nothing to the record."""
        pass

    def visit_Assign(self, node: Assign) -> None:
        """Read an unannotated ``CONSTR_ID = n`` assignment."""
        assert len(node.targets) == 1, "Record elements must be assigned one by one"
        target = node.targets[0]
        assert isinstance(target, Name), "Record elements must have named attributes"
        assert (
            target.id == "CONSTR_ID"
        ), "Type annotations may only be omitted for CONSTR_ID"
        self._set_constructor(node.value)

    def visit_Expr(self, node: Expr) -> None:
        """Allow constant expressions (e.g. docstrings) and ignore them."""
        assert isinstance(
            node.value, Constant
        ), "Only comments are allowed inside classes"
        return None

    def generic_visit(self, node: AST) -> None:
        """Reject every other statement inside a PlutusData class body."""
        raise NotImplementedError(f"Can not compile {ast.dump(node)} inside of a class")
×
689

690

691
def typed_ast(ast: AST):
    """Return a type-annotated copy of the given tree.

    A fresh ``AggressiveTypeInferencer`` is created for every call. Note that
    the parameter name shadows the ``ast`` module inside this function.
    """
    inferencer = AggressiveTypeInferencer()
    typed_tree = inferencer.visit(ast)
    return typed_tree
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc