• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 18409503285

10 Oct 2025 02:25PM UTC coverage: 92.68% (-0.2%) from 92.835%
18409503285

push

github

nielstron
Version bump

1265 of 1480 branches covered (85.47%)

Branch coverage included in aggregate %.

4800 of 5064 relevant lines covered (94.79%)

2.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.74
/opshin/type_inference.py
1
"""
2
An aggressive type inference based on the work of Aycock [1].
3
It only allows a subset of legal python operations which
4
allow us to infer the type of all involved variables
5
statically.
6
Using this we can resolve overloaded functions when translating Python
7
into UPLC where there is no dynamic type checking.
8
Additionally, this conveniently implements an additional layer of
9
security into the Smart Contract by checking type correctness.
10

11

12
[1]: https://legacy.python.org/workshops/2000-01/proceedings/papers/aycock/aycock.html
13
"""
14

15
import re
3✔
16
import ast
3✔
17
from ast import *
3✔
18
import typing
3✔
19
from collections import defaultdict
3✔
20
from copy import copy
3✔
21
from hashlib import sha256
3✔
22

23
from frozenlist2 import frozenlist
3✔
24
from ordered_set import OrderedSet
3✔
25
from pycardano import PlutusData
3✔
26
from typing import Union
3✔
27
import pluthon as plt
3✔
28
from .typed_ast import *
3✔
29
from .util import (
3✔
30
    CompilingNodeTransformer,
31
    distinct,
32
    TypedNodeVisitor,
33
    OPSHIN_LOGGER,
34
    custom_fix_missing_locations,
35
    read_vars,
36
    externally_bound_vars,
37
)
38
from .fun_impls import PythonBuiltInTypes
3✔
39
from .rewrite.rewrite_cast_condition import SPECIAL_BOOL
3✔
40
from .type_impls import (
3✔
41
    Type,
42
    ByteStringType,
43
    IntegerType,
44
    StringType,
45
    AnyType,
46
    BoolType,
47
    InstanceType,
48
    RecordType,
49
    PolymorphicFunctionType,
50
    Record,
51
    BoolInstanceType,
52
    IntegerInstanceType,
53
    UnitInstanceType,
54
    ByteStringInstanceType,
55
    StringInstanceType,
56
    ListType,
57
    DictType,
58
    UnionType,
59
    PairType,
60
    TypeInferenceError,
61
    UnitType,
62
    ATOMIC_TYPES,
63
    ClassType,
64
    TupleType,
65
    PolymorphicFunctionInstanceType,
66
    FunctionType,
67
)
68

69
# from frozendict import frozendict
70

71

72
# Names available before any user code is processed: the built-in classes that
# may appear in annotations, plus every entry of PythonBuiltInTypes whose type
# is a PolymorphicFunctionType, registered under its Python name.
INITIAL_SCOPE = dict(
    # class annotations
    bytes=ByteStringType(),
    bytearray=ByteStringType(),
    int=IntegerType(),
    bool=BoolType(),
    str=StringType(),
    Anything=AnyType(),
)

INITIAL_SCOPE.update(
    (builtin.name, builtin_typ)
    for builtin, builtin_typ in PythonBuiltInTypes.items()
    if isinstance(builtin_typ.typ, PolymorphicFunctionType)
)
89

90
# Maps an operator node class to the dunder method implementing it on a
# user-defined class. ``in`` and ``not in`` both map to ``__contains__`` and
# ``not`` maps to ``__bool__``; dunder_override negates the call result for
# ``not in`` and ``not``.
DUNDER_MAP = dict(
    [
        # ast.Compare:
        (ast.Eq, "__eq__"),
        (ast.NotEq, "__ne__"),
        (ast.Lt, "__lt__"),
        (ast.LtE, "__le__"),
        (ast.Gt, "__gt__"),
        (ast.GtE, "__ge__"),
        # ast.Is and ast.IsNot have no dunder
        (ast.In, "__contains__"),
        (ast.NotIn, "__contains__"),
        # ast.BinOp:
        (ast.Add, "__add__"),
        (ast.Sub, "__sub__"),
        (ast.Mult, "__mul__"),
        (ast.Div, "__truediv__"),
        (ast.FloorDiv, "__floordiv__"),
        (ast.Mod, "__mod__"),
        (ast.Pow, "__pow__"),
        (ast.MatMult, "__matmul__"),
        # ast.UnaryOp (ast.UAdd has no entry):
        (ast.USub, "__neg__"),
        (ast.Not, "__bool__"),
        (ast.Invert, "__invert__"),
        # ast.BoolOp:
        (ast.And, "__and__"),
        (ast.Or, "__or__"),
    ]
)
120

121
# Reflected dunders, tried on the right operand when no forward override from
# DUNDER_MAP applies to the left operand.
DUNDER_REVERSE_MAP = dict(
    [
        (ast.Add, "__radd__"),
        (ast.Sub, "__rsub__"),
        (ast.Mult, "__rmul__"),
        (ast.Div, "__rtruediv__"),
        (ast.FloorDiv, "__rfloordiv__"),
        (ast.Mod, "__rmod__"),
        (ast.Pow, "__rpow__"),
        (ast.LShift, "__rlshift__"),
        (ast.RShift, "__rrshift__"),
        (ast.And, "__rand__"),
        (ast.Or, "__ror__"),
    ]
)

# Every dunder method name a user-defined class is allowed to declare.
ALL_DUNDERS = {*DUNDER_MAP.values(), *DUNDER_REVERSE_MAP.values()}
3✔
136

137

138
def record_from_plutusdata(c: PlutusData):
    """Build the ``Record`` describing the class of the PlutusData constant ``c``.

    The record is named after ``c``'s class, uses its ``CONSTR_ID`` as the
    constructor id, and has one field per instance attribute. Each field's type
    is inferred from the attribute value via ``constant_type``.
    """
    class_name = c.__class__.__name__
    constr_id = c.CONSTR_ID
    field_types = frozenlist(
        [(attr, constant_type(val)) for attr, val in c.__dict__.items()]
    )
    return Record(
        name=class_name,
        orig_name=class_name,
        constructor=constr_id,
        fields=field_types,
    )
145

146

147
def constant_type(c):
    """Infer the OpShin instance type of the Python constant ``c``.

    Supported are bool, int, None, bytes, str, non-empty lists and dicts whose
    elements (keys/values) have compatible types, and PlutusData instances.

    :raises AssertionError: for empty lists or dicts
    :raises ValueError: if list elements or dict keys/values have incompatible
        types (raised from ``find_max_type``)
    :raises NotImplementedError: for any other Python type
    """
    # order matters: bool must be tested before int since bool subclasses int
    scalar_types = (
        (bool, BoolInstanceType),
        (int, IntegerInstanceType),
        (type(None), UnitInstanceType),
        (bytes, ByteStringInstanceType),
        (str, StringInstanceType),
    )
    for py_type, instance_type in scalar_types:
        if isinstance(c, py_type):
            return instance_type

    if isinstance(c, list):
        assert len(c) > 0, "Lists must be non-empty"
        elem_types = [constant_type(x) for x in c]
        elem_typ = find_max_type([InstanceType(t) for t in elem_types])
        if elem_typ is None:
            raise ValueError(
                f"All elements in a list must have a compatible type, found typs {tuple(t.python_type() for t in elem_types)}"
            )
        return InstanceType(ListType(elem_typ))

    if isinstance(c, dict):
        assert len(c) > 0, "Dicts must be non-empty"

        key_types = [constant_type(k) for k in c.keys()]
        value_types = [constant_type(v) for v in c.values()]
        key_typ = find_max_type([InstanceType(t) for t in key_types])
        value_typ = find_max_type([InstanceType(t) for t in value_types])
        if key_typ is None:
            raise ValueError(
                f"All keys in a dict must have a compatible type, found typs {tuple(t.python_type() for t in key_types)}"
            )
        if value_typ is None:
            raise ValueError(
                f"All values in a dict must have a compatible type, found typs {tuple(t.python_type() for t in value_types)}"
            )
        return InstanceType(DictType(key_typ, value_typ))

    if isinstance(c, PlutusData):
        return InstanceType(RecordType(record=record_from_plutusdata(c)))

    raise NotImplementedError(f"Type {type(c)} not supported")
186

187

188
# variable name -> type
TypeMap = typing.Dict[str, Type]
# (types assured if an expression is True, types assured if it is False)
TypeMapPair = typing.Tuple[TypeMap, TypeMap]
3✔
190

191

192
def union_types(*ts: Type):
    """Combine class types into a single type that covers all of them.

    If the deduplicated types reduce to one type, or one of them is ``>=``
    every other, that type is returned as-is. Otherwise nested ``UnionType``
    members are flattened (keeping first-seen order, dropping duplicates) and a
    new ``UnionType`` is returned.

    Only ``RecordType`` (PlutusData), ``IntegerType``, ``ByteStringType``,
    ``List[Anything]`` and ``Dict[Anything, Anything]`` may be combined, and
    the records must have pairwise distinct constructor ids.

    :raises AssertionError: if no types are given or they cannot form a union
    """
    ts = OrderedSet(ts)
    assert ts, "Union must combine multiple classes"
    # If all types are the same, just return the type
    if len(ts) == 1:
        return ts[0]
    # If there is a type that is compatible with all other types, choose the maximum
    for t in ts:
        if all(t >= tp for tp in ts):
            return t
    # Flatten encountered union types. ``to_process`` is used as a stack, so it
    # is filled in reverse to pop the types in their original order.
    all_ts = []
    to_process = list(reversed(ts))
    while to_process:
        t = to_process.pop()
        if isinstance(t, UnionType):
            # extend in place: list.extend returns None, so assigning its result
            # back would discard the union's members and every pending type
            to_process.extend(reversed(t.typs))
            continue
        assert isinstance(
            t, (RecordType, IntegerType, ByteStringType, ListType, DictType)
        ), f"Union must combine multiple PlutusData, int, bytes, List[Anything] or Dict[Anything,Anything] but found {t.python_type()}"
        if isinstance(t, ListType):
            assert isinstance(t.typ, InstanceType) and isinstance(
                t.typ.typ, AnyType
            ), "Union must contain only lists of Any, i.e. List[Anything]"
        if isinstance(t, DictType):
            assert (
                isinstance(t.key_typ, InstanceType)
                and isinstance(t.key_typ.typ, AnyType)
                and isinstance(t.value_typ, InstanceType)
                and isinstance(t.value_typ.typ, AnyType)
            ), "Union must contain only dicts of Any, i.e. Dict[Anything, Anything]"
        all_ts.append(t)
    union_set = OrderedSet(all_ts)
    assert distinct(
        [
            e.record.constructor
            for e in union_set
            if not isinstance(e, (ByteStringType, IntegerType, ListType, DictType))
        ]
    ), (
        "Union must combine PlutusData classes with unique CONSTR_ID, but found duplicates: "
        + str(
            {
                e.record.orig_name: e.record.constructor
                for e in union_set
                if isinstance(e, RecordType)
            }
        )
    )
    return UnionType(frozenlist(union_set))
3✔
243

244

245
def intersection_types(*ts: Type):
    """Return the types shared by all arguments.

    A single distinct argument is returned unchanged. Otherwise every
    non-union argument is treated as a one-member union. The result is a
    ``UnionType`` of the members present in every argument, in the order of the
    first argument; it may be empty.
    """
    unique = OrderedSet(ts)
    if len(unique) == 1:
        return unique[0]
    as_unions = [
        t if isinstance(t, UnionType) else UnionType(frozenlist([t])) for t in unique
    ]
    assert as_unions, "Must have at least one type to intersect"
    first, *rest = as_unions
    common = OrderedSet(first.typs)
    for other in rest:
        common.intersection_update(other.typs)
    return UnionType(frozenlist(common))
3✔
255

256

257
def find_max_type(elts: typing.List[InstanceType]):
    """Return the most general type among ``elts``.

    An empty input yields ``InstanceType(AnyType())``. If some element's
    ``.typ`` is ``>=`` the ``.typ`` of all elements, that ``.typ`` is returned.
    Otherwise an ``InstanceType`` of the union of the elements' ``.typ.typ`` is
    returned. NOTE(review): elements are therefore expected to wrap instance
    types, as ``constant_type`` passes them.

    :raises ValueError: if no common type exists and no union can be formed
    """
    if not elts:
        return InstanceType(AnyType())
    unique_elts = OrderedSet(elts)
    for candidate in elts:
        candidate_typ = candidate.typ
        if all(candidate_typ >= other.typ for other in unique_elts):
            return candidate_typ
    # no single maximum exists: try to derive a union type
    try:
        return InstanceType(union_types(*(e.typ.typ for e in unique_elts)))
    except AssertionError:
        # if this fails, we have a list with incompatible types
        raise ValueError(
            f"All elements must have a compatible type, found typs {tuple(e.typ.python_type() for e in elts)}"
        )
3✔
277

278

279
class TypeCheckVisitor(TypedNodeVisitor):
    """
    Generates the types to which objects are cast due to a boolean expression
    It returns a tuple of dictionaries which are a name -> type mapping
    for variable names that are assured to have a specific type if this expression
    is True/False respectively
    """

    def __init__(self, allow_isinstance_anything=False):
        # when set, isinstance checks on variables of type Anything are allowed
        self.allow_isinstance_anything = allow_isinstance_anything

    def generic_visit(self, node: AST) -> TypeMapPair:
        # fall back to checks already attached to the node, if any
        return getattr(node, "typechecks", ({}, {}))

    def visit_Call(self, node: Call) -> TypeMapPair:
        """Derive narrowings from ``isinstance(name, Class)`` calls.

        Calls to the SPECIAL_BOOL wrapper are transparent; any other call
        yields no narrowing.
        """
        if isinstance(node.func, Name) and node.func.orig_id == SPECIAL_BOOL:
            return self.visit(node.args[0])
        if not (isinstance(node.func, Name) and node.func.orig_id == "isinstance"):
            return ({}, {})
        # special case for Union
        if not isinstance(node.args[0], Name):
            OPSHIN_LOGGER.warning(
                "Target 0 of an isinstance cast must be a variable name for type casting to work. You can still proceed, but the inferred type of the isinstance cast will not be accurate."
            )
            return ({}, {})
        assert isinstance(node.args[1], Name) or isinstance(
            node.args[1].typ, (ListType, DictType)
        ), "Target 1 of an isinstance cast must be a class name"
        target_class: RecordType = node.args[1].typ
        inst = node.args[0]
        inst_class = inst.typ
        assert isinstance(
            inst_class, InstanceType
        ), "Can only cast instances, not classes"
        if isinstance(inst_class.typ, UnionType):
            assert (
                target_class in inst_class.typ.typs
            ), "Trying to cast an instance of Union type to non-instance of union type"
            # if the check fails, the variable is one of the remaining members
            union_without_target_class = union_types(
                *(x for x in inst_class.typ.typs if x != target_class)
            )
        elif isinstance(inst_class.typ, AnyType) and self.allow_isinstance_anything:
            union_without_target_class = AnyType()
        else:
            assert (
                inst_class.typ == target_class
            ), "Can only cast instances of Union types of PlutusData or cast the same class. If you know what you are doing, enable the flag '--allow-isinstance-anything'"
            union_without_target_class = target_class
        varname = node.args[0].id
        return ({varname: target_class}, {varname: union_without_target_class})

    def visit_BoolOp(self, node: BoolOp) -> TypeMapPair:
        """Combine the narrowings of the operands of ``and`` / ``or``."""
        res = {}
        inv_res = {}
        checks = [self.visit(v) for v in node.values]
        checked_types = defaultdict(list)
        inv_checked_types = defaultdict(list)
        for c, inv_c in checks:
            for v, t in c.items():
                checked_types[v].append(t)
            for v, t in inv_c.items():
                inv_checked_types[v].append(t)
        if isinstance(node.op, And):
            # a conjunction is just the intersection
            for v, ts in checked_types.items():
                res[v] = intersection_types(*ts)
            # if the conjunction fails, its any of the respective reverses, but only if the type is checked in every conjunction
            for v, ts in inv_checked_types.items():
                if len(ts) < len(checks):
                    continue
                inv_res[v] = union_types(*ts)
        elif isinstance(node.op, Or):
            # a disjunction is just the union, but some type must be checked in every disjunction
            for v, ts in checked_types.items():
                if len(ts) < len(checks):
                    continue
                res[v] = union_types(*ts)
            # if the disjunction fails, then it must be in the intersection of the inverses
            for v, ts in inv_checked_types.items():
                inv_res[v] = intersection_types(*ts)
        else:
            raise NotImplementedError(f"Unsupported boolean operator {node.op}")
        return (res, inv_res)

    def visit_UnaryOp(self, node: UnaryOp) -> TypeMapPair:
        """Propagate the operand's narrowings through a unary operator."""
        (res, inv_res) = self.visit(node.operand)
        if isinstance(node.op, Not):
            # ``not`` swaps the True and False outcomes
            return (inv_res, res)
        if isinstance(node.op, ast.Invert):
            # ~True == -2 and ~False == -1 are both truthy, so the result does
            # not reveal the operand's outcome: no narrowing is sound
            return ({}, {})
        # unary + and - preserve the truthiness of a bool operand
        return (res, inv_res)
×
369

370

371
def merge_scope(s1: typing.Dict[str, Type], s2: typing.Dict[str, Type]):
    """Merge the variable scopes of two branches into one.

    Variables bound in only one scope keep their type. Variables bound in both
    get an instance of the union of their two class types.

    :raises AssertionError: if a shared variable is bound to a class (not an
        instance) type in either scope, or its two types cannot form a union
    """
    keys = OrderedSet(s1.keys()).union(s2.keys())
    merged = {}
    for k in keys:
        if k not in s1:
            merged[k] = s2[k]
        elif k not in s2:
            merged[k] = s1[k]
        else:
            try:
                # report the first of the two types that is not an instance type
                class_typ = next(
                    (t for t in (s1[k], s2[k]) if not isinstance(t, InstanceType)),
                    None,
                )
                assert (
                    class_typ is None
                ), f"Can only merge instance types, found class type '{class_typ.python_type()}' for '{k}'"
                merged[k] = InstanceType(union_types(s1[k].typ, s2[k].typ))
            except AssertionError as e:
                raise AssertionError(
                    f"Can not merge scopes after branching, conflicting types for '{k}': '{e}'"
                ) from e
    return merged
3✔
390

391

392
class AggressiveTypeInferencer(CompilingNodeTransformer):
3✔
393
    step = "Static Type Inference"
3✔
394

395
    def __init__(self, allow_isinstance_anything=False):
        self.allow_isinstance_anything = allow_isinstance_anything
        self.FUNCTION_ARGUMENT_REGISTRY = {}
        self.wrapped = []

        # A stack of dictionaries for storing scoped knowledge of variable types.
        # The outermost scope is a copy of INITIAL_SCOPE: set_variable_type writes
        # to self.scopes[-1], and without the copy such writes would mutate the
        # shared module-level dict and leak into later compilations.
        self.scopes = [dict(INITIAL_SCOPE)]
3✔
402

403
    # Obtain the type of a variable name in the current scope
    def variable_type(self, name: str) -> Type:
        """Return the type bound to ``name`` in the innermost scope defining it.

        :raises TypeInferenceError: if no open scope defines ``name``. If some
            scope defines a variable with the same original (un-renamed) name,
            the error explains that it may be accessed before a redefinition.
        """
        for scope in reversed(self.scopes):
            if name in scope:
                return scope[name]
        orig_name = map_to_orig_name(name)
        # try to find an outer scope where the variable name maps to the original
        # name; the scan runs inner to outer, so the outermost match is reported
        outer_scope_type = None
        for scope in reversed(self.scopes):
            for key, key_typ in scope.items():
                if map_to_orig_name(key) == orig_name:
                    outer_scope_type = key_typ
        if outer_scope_type is None:
            # If the variable is not found in any scope, raise an error
            raise TypeInferenceError(
                f"Variable '{orig_name}' not initialized at access. You need to define it before using it the first time."
            )
        raise TypeInferenceError(
            f"Variable '{orig_name}' not initialized at access.\n"
            f"Note that you may be trying to access variable '{orig_name}' of type '{outer_scope_type.python_type()}' in an outer scope and later redefine it. This is not allowed.\n"
            "This can happen for example if you redefine a (renamed) imported function but try to use it before the redefinition."
        )
426

427
    def is_defined_in_current_scope(self, name: str) -> bool:
        """Return whether ``name`` resolves via ``variable_type``.

        Despite the method name, the whole scope stack is consulted, not only
        the innermost scope.
        """
        try:
            self.variable_type(name)
        except TypeInferenceError:
            return False
        return True
×
433

434
    def enter_scope(self):
        """Push a new, empty innermost scope onto the scope stack."""
        new_scope = dict()
        self.scopes.append(new_scope)
3✔
436

437
    def exit_scope(self):
        """Pop the innermost scope, discarding every type recorded in it."""
        self.scopes.pop()
3✔
439

440
    def set_variable_type(self, name: str, typ: Type, force=False):
        """Record ``typ`` as the type of ``name`` in the innermost scope.

        Unless ``force`` is set, rebinding a variable to a different type is
        only accepted when the existing type is ``>=`` the new one; the existing
        (broader) type is then kept unchanged.

        :raises TypeInferenceError: if the new type is incompatible with the
            type already recorded in the innermost scope
        """
        scope = self.scopes[-1]
        conflicting = not force and name in scope and scope[name] != typ
        if conflicting:
            existing = scope[name]
            if existing >= typ:
                # the specified type is broader, we pass on this
                return
            raise TypeInferenceError(
                f"Type '{existing.python_type()}' of variable '{map_to_orig_name(name)}' in local scope does not match inferred type '{typ.python_type()}'"
            )
        scope[name] = typ
3✔
449

450
    def implement_typechecks(self, typchecks: TypeMap):
        """Narrow each variable in ``typchecks`` to an instance of its class.

        The new type is written (with ``force``) into the innermost scope.
        Returns a mapping from each variable to the ``.typ`` of the type it had
        before narrowing (presumably so the caller can restore it afterwards).
        """
        previous_types = {}
        for var_name, narrowed_class in typchecks.items():
            previous_types[var_name] = self.variable_type(var_name).typ
            self.set_variable_type(var_name, InstanceType(narrowed_class), force=True)
        return previous_types
3✔
456

457
    def dunder_override(self, node: Union[BinOp, Compare, UnaryOp]):
        """Rewrite an operator on a user-defined class into a dunder method call.

        First the forward dunder from DUNDER_MAP is looked up on the (left)
        operand's record class, e.g. ``a + b`` -> ``a.__add__(b)``. For ``in`` /
        ``not in`` the container on the right is the receiver. If no forward
        override applies, the reflected dunder from DUNDER_REVERSE_MAP is looked
        up on the single right operand, e.g. ``b.__radd__(a)``. ``not in`` and
        ``not`` become the negated ``__contains__`` / ``__bool__`` call.

        Returns the type-inferred call node, or None if no override applies.
        The operand sub-expressions are visited in either case.
        """
        # Check for potential dunder_method override
        operand = None
        operation = None
        args = []
        if isinstance(node, UnaryOp):
            operand = self.visit(node.operand)
            operation = node.op
        elif isinstance(node, BinOp):
            operand = self.visit(node.left)
            operation = node.op
            args = [self.visit(node.right)]
        elif isinstance(node, Compare):
            operation = node.ops[0]
            if isinstance(operation, (ast.In, ast.NotIn)):
                # membership tests dispatch on the container, i.e. the right side
                operand = self.visit(node.comparators[0])
                args = [self.visit(node.left)]
            else:
                operand = self.visit(node.left)
                args = [self.visit(c) for c in node.comparators]
            assert len(node.ops) == 1, "Only support one op at a time"

        def method_defined(method_name: str) -> bool:
            return any(method_name in scope for scope in self.scopes)

        def method_call(receiver, class_name: str, dunder: str, call_args):
            # methods are hoisted into functions named "<Class>_+_<method>"
            call = ast.Call(
                func=ast.Attribute(
                    value=receiver,
                    attr=dunder,
                    ctx=ast.Load(),
                ),
                args=call_args,
                keywords=[],
            )
            call.func.orig_id = f"{class_name}.{dunder}"
            call.func.id = f"{class_name}_+_{dunder}"
            return self.visit_Call(call)

        operand_type = operand.typ
        if (
            operation.__class__ in DUNDER_MAP
            and isinstance(operand_type, InstanceType)
            and isinstance(operand_type.typ, RecordType)
        ):
            dunder = DUNDER_MAP[operation.__class__]
            operand_class_name = operand_type.typ.record.name
            if method_defined(f"{operand_class_name}_+_{dunder}"):
                call = method_call(operand, operand_class_name, dunder, args)
                if (dunder == "__contains__" and isinstance(operation, ast.NotIn)) or (
                    dunder == "__bool__" and isinstance(operation, ast.Not)
                ):
                    # we need to negate the result
                    return TypedUnaryOp(
                        op=ast.Not(), operand=call, typ=BoolInstanceType
                    )
                return call
        # if this is not supported, try the reverse dunder
        # note we assume 1, i.e. allow only a single right operand
        right_op_typ = args[0].typ if len(args) == 1 else None
        if (
            operation.__class__ in DUNDER_REVERSE_MAP
            and isinstance(right_op_typ, InstanceType)
            and isinstance(right_op_typ.typ, RecordType)
        ):
            dunder = DUNDER_REVERSE_MAP[operation.__class__]
            right_class_name = right_op_typ.typ.record.name
            if method_defined(f"{right_class_name}_+_{dunder}"):
                return method_call(args[0], right_class_name, dunder, [operand])
        return None
3✔
534

535
    def type_from_annotation(self, ann: expr):
        """Translate a type annotation AST node into an OpShin class type.

        Supported forms: the constant ``None`` (unit type), string constants
        naming a known record class, plain class names (including ``Self``),
        ``Union[...]``, ``List[...]``, ``Dict[..., ...]``, ``Tuple[...]``, and a
        missing annotation (Python ``None``), which yields ``AnyType``.

        :raises TypeInferenceError: for unknown or disallowed class names
        :raises AssertionError: for malformed generic annotations
        :raises NotImplementedError: for unsupported annotation forms
        """
        if isinstance(ann, Constant):
            if ann.value is None:
                return UnitType()
            # string annotation: look for a record class with that original name
            for scope in reversed(self.scopes):
                for value in scope.values():
                    if (
                        isinstance(value, RecordType)
                        and value.record.orig_name == ann.value
                    ):
                        return value

        if isinstance(ann, Name):
            if ann.id in ATOMIC_TYPES:
                return ATOMIC_TYPES[ann.id]
            if ann.id == "Self":
                # NOTE(review): ``idSelf_new`` is not assigned anywhere in the
                # visible part of this file and the line is uncovered by tests;
                # confirm the attribute name set by the rewrite passes.
                v_t = self.variable_type(ann.idSelf_new)
            elif ann.id in ["Union", "List", "Dict"]:
                raise TypeInferenceError(
                    f"Annotation {ann.id} is not allowed as a variable type, use List[Anything], Dict[Anything, Anything] or Union[...] instead"
                )
            else:
                v_t = self.variable_type(ann.id)
            if isinstance(v_t, ClassType):
                return v_t
            raise TypeInferenceError(
                f"Class name {ann.orig_id} not initialized before annotating variable"
            )
        if isinstance(ann, Subscript):
            assert isinstance(
                ann.value, Name
            ), "Only Union, Dict, List and Tuple are allowed as Generic types"
            if ann.value.orig_id == "Union":
                if isinstance(ann.slice, Name):
                    elts = [ann.slice]
                elif isinstance(ann.slice, ast.Tuple):
                    elts = ann.slice.elts
                else:
                    raise TypeInferenceError(
                        "Union must combine several classes, use Union[Class1, Class2, ...]"
                    )

                def describe(e) -> str:
                    # source name of an annotation node for error messages;
                    # nested (non-name) annotations carry no orig_id
                    return getattr(e, "orig_id", e.__class__.__name__)

                # only allow List[Anything] and Dict[Anything, Anything] in unions
                for elt in elts:
                    if not (isinstance(elt, Subscript) and isinstance(elt.value, Name)):
                        # other forms are validated by the recursive call below
                        continue
                    if elt.value.id == "List":
                        assert (
                            isinstance(elt.slice, Name)
                            and elt.slice.orig_id == "Anything"
                        ), f"Only List[Anything] is supported in Unions. Received List[{describe(elt.slice)}]."
                    if elt.value.id == "Dict":
                        dict_params = (
                            elt.slice.elts
                            if isinstance(elt.slice, ast.Tuple)
                            else [elt.slice]
                        )
                        assert all(
                            isinstance(e, Name) and e.orig_id == "Anything"
                            for e in dict_params
                        ), f"Only Dict[Anything, Anything] is supported in Unions. Received Dict[{', '.join(describe(e) for e in dict_params)}]."
                ann_types = frozenlist([self.type_from_annotation(e) for e in elts])
                # flatten encountered union types
                ann_types = frozenlist(
                    [
                        member
                        for t in ann_types
                        for member in (
                            tuple(t.typs) if isinstance(t, UnionType) else (t,)
                        )
                    ]
                )
                # check for unique constr_ids
                constr_ids = [
                    record.record.constructor
                    for record in ann_types
                    if isinstance(record, RecordType)
                ]
                assert len(constr_ids) == len(set(constr_ids)), (
                    "Union must combine PlutusData classes with unique CONSTR_ID, but found duplicates: "
                    + str(
                        {
                            e.record.orig_name: e.record.constructor
                            for e in ann_types
                            if isinstance(e, RecordType)
                        }
                    )
                )
                return union_types(*ann_types)
            if ann.value.orig_id == "List":
                ann_type = self.type_from_annotation(ann.slice)
                assert isinstance(
                    ann_type, ClassType
                ), "List must have a single type as parameter"
                assert not isinstance(
                    ann_type, TupleType
                ), "List can currently not hold tuples"
                return ListType(InstanceType(ann_type))
            if ann.value.orig_id == "Dict":
                assert isinstance(ann.slice, Tuple), "Dict must combine two classes"
                assert len(ann.slice.elts) == 2, "Dict must combine two classes"
                ann_types = self.type_from_annotation(
                    ann.slice.elts[0]
                ), self.type_from_annotation(ann.slice.elts[1])
                assert all(
                    isinstance(e, ClassType) for e in ann_types
                ), "Dict must combine two classes"
                assert not any(
                    isinstance(e, TupleType) for e in ann_types
                ), "Dict can currently not hold tuples"
                return DictType(*(InstanceType(a) for a in ann_types))
            if ann.value.orig_id == "Tuple":
                assert isinstance(
                    ann.slice, Tuple
                ), "Tuple must combine several classes"
                ann_types = [self.type_from_annotation(e) for e in ann.slice.elts]
                assert all(
                    isinstance(e, ClassType) for e in ann_types
                ), "Tuple must combine classes"
                return TupleType(frozenlist([InstanceType(a) for a in ann_types]))
            raise NotImplementedError(
                "Only Union, Dict, List and Tuple are allowed as Generic types"
            )
        if ann is None:
            return AnyType()
        raise NotImplementedError(f"Annotation type {ann.__class__} is not supported")
654

655
    def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST:
3✔
656
        additional_functions = []
3✔
657
        for n in node_seq:
3✔
658
            if not isinstance(n, ast.ClassDef):
3✔
659
                continue
3✔
660
            non_method_attributes = []
3✔
661
            for attribute in n.body:
3✔
662
                if not isinstance(attribute, ast.FunctionDef):
3✔
663
                    non_method_attributes.append(attribute)
3✔
664
                    continue
3✔
665
                func = copy(attribute)
3✔
666
                if func.name[0:2] == "__" and func.name[-2:] == "__":
3✔
667
                    assert (
3✔
668
                        func.name in ALL_DUNDERS
669
                    ), f"The following Dunder methods are supported {sorted(ALL_DUNDERS)}. Received {func.name} which is not supported"
670
                func.name = f"{n.name}_+_{attribute.name}"
3✔
671

672
                def does_literally_reference_self(arg):
3✔
673
                    if arg is None:
3✔
674
                        return False
3✔
675
                    if isinstance(arg, Name) and arg.id == n.name:
3✔
676
                        return True
3✔
677
                    if (
3✔
678
                        isinstance(arg, ast.Subscript)
679
                        and isinstance(arg.value, Name)
680
                        and arg.value.id in ("Union", "List", "Dict")
681
                    ):
682
                        # Only possible for List, Dict and Union
683
                        if any(
3✔
684
                            does_literally_reference_self(e) for e in arg.slice.elts
685
                        ):
686
                            return True
3✔
687
                    return False
3✔
688

689
                for arg in func.args.args:
3✔
690
                    assert not does_literally_reference_self(
3✔
691
                        arg.annotation
692
                    ), f"Argument '{arg.arg}' of method '{attribute.name}' in class '{n.name}' literally references the class itself. This is not allowed. If you want to reference the class itself, use 'Self' as type annotation."
693
                assert not does_literally_reference_self(
3✔
694
                    func.returns
695
                ), f"Return type of method '{attribute.name}' in class '{n.name}' literally references the class itself. This is not allowed. If you want to reference the class itself, use 'Self' as type annotation."
696
                ann = ast.Name(id=n.name, ctx=ast.Load())
3✔
697
                if len(func.args.args) == 0:
3✔
698
                    raise TypeError(
3✔
699
                        f"Method '{attribute.orig_name}' in class '{n.orig_name}' must have at least one argument (self)"
700
                    )
701
                custom_fix_missing_locations(ann, attribute.args.args[0])
3✔
702
                if func.args.args[0].orig_arg != "self":
3✔
703
                    OPSHIN_LOGGER.warning(
3✔
704
                        f"The first argument of method '{attribute.name}' in class '{n.orig_name}' should be named 'self', but found '{func.args.args[0].orig_arg}'. This is not enforced, but recommended."
705
                    )
706
                if func.args.args[0].annotation is not None and not (
3✔
707
                    isinstance(func.args.args[0].annotation, Name)
708
                    and func.args.args[0].annotation.id == "Self"
709
                ):
710
                    raise TypeError(
3✔
711
                        f"The first argument of method '{attribute.name}' in class '{n.name}' must either not be annotated or be annotated with 'Self' to indicate that it is the instance of the class."
712
                    )
713
                ann.orig_id = attribute.args.args[0].orig_arg
3✔
714
                func.args.args[0].annotation = ann
3✔
715
                additional_functions.append(func)
3✔
716
            n.body = non_method_attributes
3✔
717
        if additional_functions:
3✔
718
            last = node_seq.pop()
3✔
719
            node_seq.extend(additional_functions)
3✔
720
            node_seq.append(last)
3✔
721

722
        stmts = []
3✔
723
        prevtyps = {}
3✔
724
        for n in node_seq:
3✔
725
            stmt = self.visit(n)
3✔
726
            stmts.append(stmt)
3✔
727
            # if an assert is amng the statements apply the isinstance cast
728
            if isinstance(stmt, Assert):
3✔
729
                typchecks, _ = TypeCheckVisitor(self.allow_isinstance_anything).visit(
3✔
730
                    stmt.test
731
                )
732
                # for the time after this assert, the variable has the specialized type
733
                prevtyps.update(self.implement_typechecks(typchecks))
3✔
734
        self.implement_typechecks(prevtyps)
3✔
735
        return stmts
3✔
736

737
    def visit_ClassDef(self, node: ClassDef) -> TypedClassDef:
        """Register a record class and return its typed definition.

        The class is turned into a record via ``RecordReader``, and the
        resulting ``RecordType`` is bound to the class name. The record fields
        are registered in ``FUNCTION_ARGUMENT_REGISTRY`` so that the
        constructor can later be called with keyword arguments.
        """
        record = RecordReader(self).extract(node)
        record_typ = RecordType(record)
        self.set_variable_type(node.name, record_typ)
        constructor_args = []
        for field_name, field_typ in record.fields:
            constructor_args.append(
                typedarg(arg=field_name, typ=field_typ, orig_arg=field_name)
            )
        self.FUNCTION_ARGUMENT_REGISTRY[node.name] = constructor_args
        typed_class = copy(node)
        typed_class.class_typ = record_typ
        return typed_class
748

749
    def visit_Constant(self, node: Constant) -> TypedConstant:
        """Attach the type of a literal constant.

        Float, complex and Ellipsis literals have no on-chain counterpart and
        are rejected.
        """
        unsupported_literal_types = [float, complex, type(...)]
        assert (
            type(node.value) not in unsupported_literal_types
        ), "Float, complex numbers and ellipsis currently not supported"
        typed_const = copy(node)
        typed_const.typ = constant_type(node.value)
        return typed_const
758

759
    def visit_NoneType(self, node: None) -> TypedConstant:
        """Type an absent node as the constant ``None``.

        Presumably reached when ``visit`` is handed ``None`` itself (dispatch
        on the class name ``NoneType``). The result is a fresh ``Constant``
        holding ``None``.
        """
        none_const = Constant(value=None)
        none_const.typ = constant_type(None)
        return none_const
763

764
    def visit_Tuple(self, node: Tuple) -> TypedTuple:
        """Type a tuple literal as a ``TupleType`` of its element types."""
        typed_tuple = copy(node)
        typed_elts = list(map(self.visit, node.elts))
        typed_tuple.elts = typed_elts
        elt_typs = frozenlist([elt.typ for elt in typed_elts])
        typed_tuple.typ = InstanceType(TupleType(elt_typs))
        return typed_tuple
769

770
    def visit_List(self, node: List) -> TypedList:
        """Type a list literal as a ``ListType`` of its most general element type.

        Every element must be an instance (not a class).
        """
        typed_list = copy(node)
        typed_list.elts = [self.visit(elt) for elt in node.elts]
        class_typs = [
            elt.typ for elt in typed_list.elts if not isinstance(elt.typ, InstanceType)
        ]
        assert (
            not class_typs
        ), f"All list elements must be instances of a class, found class types {', '.join(t.python_type() for t in class_typs)}"
        # the list's element type is the maximal type among its elements
        element_typ = find_max_type(typed_list.elts)
        typed_list.typ = InstanceType(ListType(element_typ))
        return typed_list
780

781
    def visit_Dict(self, node: Dict) -> TypedDict:
        """Type a dict literal as a ``DictType`` of its maximal key and value types.

        Both keys and values must be instances (not classes), mirroring the
        check ``visit_List`` performs on list elements.
        """
        tt = copy(node)
        tt.keys = [self.visit(k) for k in node.keys]
        tt.values = [self.visit(v) for v in node.values]
        assert all(
            isinstance(e.typ, InstanceType) for e in tt.keys
        ), f"All keys of a dict must be instances of a class, found class types {', '.join(e.typ.python_type() for e in tt.keys if not isinstance(e.typ, InstanceType))}"
        # Values need the same check as keys: otherwise a class type (e.g. a
        # record class used as a value) would end up as the dict's value type.
        assert all(
            isinstance(e.typ, InstanceType) for e in tt.values
        ), f"All values of a dict must be instances of a class, found class types {', '.join(e.typ.python_type() for e in tt.values if not isinstance(e.typ, InstanceType))}"
        # try to derive a max type
        k_typ = find_max_type(tt.keys)
        v_typ = find_max_type(tt.values)
        tt.typ = InstanceType(DictType(k_typ, v_typ))
        return tt
793

794
    def visit_Assign(self, node: Assign) -> TypedAssign:
        """Type a plain (possibly chained) assignment ``a = b = value``.

        Only variable names may be assigned. If the value is marked with
        ``is_tuple_with_deconstruction``, it must be a tuple or pair whose size
        matches the number of deconstructed names.
        """
        typed_assign = copy(node)
        typed_assign.value = self.visit(node.value)
        # Bind every target before visiting the targets, so the lookups done
        # while visiting them find the new type. set_variable_type also checks
        # compatibility with an earlier binding (e.g. inside a function).
        for target in node.targets:
            assert isinstance(
                target, Name
            ), "Can only assign to variable names (e.g., x = 5). OpShin does not allow assigning to tuple deconstructors (e.g., a, b = (1, 2)) or to dicts, lists, or members (e.g., x[0] = 1; x.foo = 1)"
            self.set_variable_type(target.id, typed_assign.value.typ)
        typed_assign.targets = [self.visit(target) for target in node.targets]

        if not hasattr(typed_assign.value, "is_tuple_with_deconstruction"):
            return typed_assign

        # deconstructed tuples: the number of names must match the tuple size
        value_typ = typed_assign.value.typ
        n_names = typed_assign.value.is_tuple_with_deconstruction
        assert isinstance(value_typ, InstanceType) and isinstance(
            value_typ.typ, (TupleType, PairType)
        ), f"Tuple deconstruction expected a tuple type, found '{value_typ.python_type()}'"
        if isinstance(value_typ.typ, PairType):
            assert (
                n_names == 2
            ), f"Too many values to unpack or not enough values to unpack. Tuple deconstruction required assigning to 2 elements found '{n_names}'"
        else:
            assert n_names == len(
                value_typ.typ.typs
            ), f"Too many values to unpack or not enough values to unpack. Tuple deconstruction required tuple with {n_names} elements found '{value_typ.python_type()}'"
        return typed_assign
820

821
    def visit_AnnAssign(self, node: AnnAssign) -> TypedAnnAssign:
        """Type an annotated assignment ``name: T = value``.

        Empty list and empty dict literals cannot be typed on their own. When
        the annotation is a list (resp. dict) type, they are not visited and
        instead take the annotated type. Otherwise the value is typed normally.
        The value's type must be related to the annotated type in either
        direction (sub- or supertype).
        """
        typed_ass = copy(node)
        typed_ass.annotation = self.type_from_annotation(node.annotation)
        value = node.value
        is_empty_list = (isinstance(value, Constant) and value.value == []) or (
            isinstance(value, List) and value.elts == []
        )
        is_empty_dict = (isinstance(value, Constant) and value.value == {}) or (
            isinstance(value, Dict) and value.keys == [] and value.values == []
        )
        takes_annotated_typ = (
            isinstance(typed_ass.annotation, ListType) and is_empty_list
        ) or (isinstance(typed_ass.annotation, DictType) and is_empty_dict)
        if takes_annotated_typ:
            # Empty lists / dicts are only allowed in annotated assignments,
            # their element types come from the annotation.
            typed_ass.value = copy(value)
            typed_ass.value.typ = InstanceType(typed_ass.annotation)
        else:
            typed_ass.value = self.visit(value)
        assert isinstance(
            node.target, Name
        ), "Can only assign to variable names, no type deconstruction"
        # set_variable_type checks compatibility with an earlier binding of the
        # same name (e.g. inside a function)
        self.set_variable_type(node.target.id, InstanceType(typed_ass.annotation))
        typed_ass.target = self.visit(node.target)
        annotated_typ = InstanceType(typed_ass.annotation)
        value_typ = typed_ass.value.typ
        assert (
            value_typ >= annotated_typ or annotated_typ >= value_typ
        ), "Can only cast between related types"
        return typed_ass
855

856
    def visit_If(self, node: If) -> TypedIf:
        """Type an ``if`` statement with flow-sensitive type narrowing.

        The typechecks that ``TypeCheckVisitor`` extracts from the condition
        narrow variable types inside the body. Their inverses narrow them inside
        the ``else`` branch. Both branches start from the scope as it was before
        the statement, and their resulting scopes are merged.
        """
        typed_if = copy(node)
        typed_if.test = self.visit(node.test)
        assert (
            typed_if.test.typ == BoolInstanceType
        ), "Branching condition must have boolean type"
        body_checks, else_checks = TypeCheckVisitor(
            self.allow_isinstance_anything
        ).visit(typed_if.test)

        def visit_branch(checks, statements):
            # Narrow the checked variables only for this branch. While narrowed
            # they are listed in self.wrapped, which visit_Name uses to flag
            # their uses as is_wrapped.
            narrowed = self.implement_typechecks(checks)
            self.wrapped.extend(narrowed.keys())
            typed_statements = self.visit_sequence(statements)
            self.wrapped = [x for x in self.wrapped if x not in narrowed.keys()]
            return typed_statements

        scope_before = copy(self.scopes[-1])
        typed_if.body = visit_branch(body_checks, node.body)
        scope_after_body = copy(self.scopes[-1])
        # the else branch must not observe anything done in the body
        self.scopes[-1] = scope_before
        typed_if.orelse = visit_branch(else_checks, node.orelse)
        scope_after_else = self.scopes[-1]
        self.scopes[-1] = merge_scope(scope_after_body, scope_after_else)
        return typed_if
885

886
    def visit_While(self, node: While) -> TypedWhile:
        """Type a ``while`` loop with flow-sensitive type narrowing.

        Inside the loop body, the typechecks from the condition narrow variable
        types. Inside the ``else`` clause, their inverses apply. Both parts
        start from the scope before the loop, and their resulting scopes are
        merged afterwards.
        """
        typed_loop = copy(node)
        typed_loop.test = self.visit(node.test)
        assert (
            typed_loop.test.typ == BoolInstanceType
        ), "Branching condition must have boolean type"
        checks_true, checks_false = TypeCheckVisitor(
            self.allow_isinstance_anything
        ).visit(typed_loop.test)

        def visit_narrowed(typchecks, statements):
            # narrow for this part only; wrapped names are flagged by visit_Name
            narrowed = self.implement_typechecks(typchecks)
            self.wrapped.extend(narrowed.keys())
            result = self.visit_sequence(statements)
            self.wrapped = [x for x in self.wrapped if x not in narrowed.keys()]
            return result

        saved_scope = copy(self.scopes[-1])
        typed_loop.body = visit_narrowed(checks_true, node.body)
        body_scope = copy(self.scopes[-1])
        # revert changes before the else clause
        self.scopes[-1] = saved_scope
        typed_loop.orelse = visit_narrowed(checks_false, node.orelse)
        self.scopes[-1] = merge_scope(body_scope, self.scopes[-1])
        return typed_loop
912

913
    def visit_For(self, node: For) -> TypedFor:
        """Type a ``for`` loop over a list or a homogeneous, non-empty tuple.

        The loop variable is bound to the element type of the iterated object.
        Only plain variable names are supported as loop targets.

        Raises:
            NotImplementedError: for tuple targets, or iterables that are
                neither lists nor tuples.
            AssertionError: for class (non-instance) iterables, empty or
                heterogeneous tuples, or non-name loop targets.
        """
        typed_for = copy(node)
        typed_for.iter = self.visit(node.iter)
        if isinstance(node.target, Tuple):
            raise NotImplementedError(
                "Tuple deconstruction in for loops is not supported yet"
            )
        itertyp = typed_for.iter.typ
        assert isinstance(
            itertyp, InstanceType
        ), "Can only iterate over instances, not classes"
        if isinstance(itertyp.typ, TupleType):
            assert itertyp.typ.typs, "Iterating over an empty tuple is not allowed"
            vartyp = itertyp.typ.typs[0]
            assert all(
                itertyp.typ.typs[0] == t for t in itertyp.typ.typs
            ), f"Iterating through a tuple requires the same type for each element, found tuple of type {itertyp.typ.python_type()}"
        elif isinstance(itertyp.typ, ListType):
            vartyp = itertyp.typ.typ
        else:
            raise NotImplementedError(
                "Type inference for loops over non-list and non-tuple objects is not supported"
            )
        # Targets such as `x.foo` or `x[0]` have no `.id`; reject them with a
        # clear message instead of crashing with an AttributeError below.
        assert isinstance(
            node.target, Name
        ), "Can only use variable names as loop targets (e.g., for x in xs). OpShin does not allow looping into dicts, lists, or members (e.g., for x[0] in xs; for x.foo in xs)"
        self.set_variable_type(node.target.id, vartyp)
        typed_for.target = self.visit(node.target)
        typed_for.body = self.visit_sequence(node.body)
        typed_for.orelse = self.visit_sequence(node.orelse)
        return typed_for
942

943
    def visit_Name(self, node: Name) -> TypedName:
        """Type a variable reference by looking it up in the scopes.

        The typing aliases ``List`` and ``Dict`` are not present in any scope.
        They are typed directly as list / dict of ``Any`` instead of being
        looked up. Names currently narrowed by a typecheck (listed in
        ``self.wrapped``) are flagged with ``is_wrapped``.
        """
        typed_name = copy(node)
        typing_aliases = {
            "List": lambda: ListType(InstanceType(AnyType())),
            "Dict": lambda: DictType(
                InstanceType(AnyType()), InstanceType(AnyType())
            ),
        }
        make_alias_typ = typing_aliases.get(node.orig_id)
        if make_alias_typ is not None:
            typed_name.typ = make_alias_typ()
        else:
            # the rhs of an assignment has been evaluated before we get here
            typed_name.typ = self.variable_type(node.id)
        if node.id in self.wrapped:
            typed_name.is_wrapped = True
        return typed_name
956

957
    def visit_keyword(self, node: keyword) -> Typedkeyword:
        """Type a keyword argument by typing its value expression."""
        typed_kw = copy(node)
        typed_value = self.visit(node.value)
        typed_kw.value = typed_value
        return typed_kw
961

962
    def visit_Compare(self, node: Compare) -> Union[TypedCompare, TypedCall]:
        """Type a comparison.

        If ``dunder_override`` produces a replacement node (a call of a
        user-defined comparison method), that node is returned. Otherwise all
        operands are typed and the comparison is typed as ``bool``.
        """
        overridden = self.dunder_override(node)
        if overridden is not None:
            return overridden
        typed_cmp = copy(node)
        typed_cmp.left = self.visit(node.left)
        typed_cmp.comparators = list(map(self.visit, node.comparators))
        typed_cmp.typ = BoolInstanceType
        return typed_cmp
972

973
    def visit_arg(self, node: arg) -> typedarg:
        """Type a function parameter from its annotation and bind it in scope."""
        typed_arg = copy(node)
        param_typ = InstanceType(self.type_from_annotation(node.annotation))
        typed_arg.typ = param_typ
        self.set_variable_type(typed_arg.arg, param_typ)
        return typed_arg
978

979
    def visit_arguments(self, node: arguments) -> typedarguments:
        """Type a function's parameter list.

        Only plain positional parameters (``node.args``) are supported. Each one
        is typed by visiting it, which binds it in the current scope.

        Raises:
            NotImplementedError: for keyword-only parameters, ``**kwargs``,
                default values, ``*args`` or positional-only parameters.
        """
        if node.kw_defaults or node.kwarg or node.kwonlyargs or node.defaults:
            raise NotImplementedError(
                "Keyword arguments and defaults not supported yet"
            )
        # ``*args`` and parameters before ``/`` are stored outside ``node.args``.
        # Without this check they would silently be left out of the typed
        # signature, which is built from ``args.args`` only.
        if getattr(node, "vararg", None) is not None or getattr(
            node, "posonlyargs", None
        ):
            raise NotImplementedError(
                "Variadic (*args) and positional-only arguments not supported yet"
            )
        ta = copy(node)
        ta.args = [self.visit(a) for a in node.args]
        return ta
987

988
    def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
        """Type a function definition.

        The function type is built from the parameter annotations, the return
        annotation and the current types of the externally bound variables. The
        function is bound in its own scope to allow recursion. If the bound
        variable types change while the body is visited, the whole definition
        is typed again. Functions decorated only with ``@wraps_builtin`` keep
        their body untouched. Finally the function type is bound in the
        enclosing scope, and its parameters are registered for keyword calls.
        """
        typed_fn = copy(node)
        decorators = node.decorator_list
        wraps_builtin = decorators and all(
            isinstance(d, Name) and d.orig_id == "wraps_builtin" for d in decorators
        )
        assert (
            not decorators or wraps_builtin
        ), f"Functions may not have decorators other than literal @wraps_builtin, found other decorators at {node.orig_name}."

        # Annotations flagged with `idSelf` (presumably `Self` annotations of
        # methods) take over the class name annotated on the first argument.
        fn_params = typed_fn.args.args
        for param in fn_params:
            if hasattr(param.annotation, "idSelf"):
                param.annotation.id = fn_params[0].annotation.id
        if hasattr(node.returns, "idSelf"):
            typed_fn.returns.id = fn_params[0].annotation.id

        def current_bound_vars():
            return {
                v: self.variable_type(v)
                for v in externally_bound_vars(node)
                if v not in ["List", "Dict"]
            }

        self.enter_scope()
        typed_fn.args = self.visit(node.args)
        param_typs = frozenlist([p.typ for p in typed_fn.args.args])
        return_typ = InstanceType(self.type_from_annotation(typed_fn.returns))
        functyp = FunctionType(
            param_typs,
            return_typ,
            bound_vars=current_bound_vars(),
            bind_self=node.name if node.name in read_vars(node) else None,
        )
        typed_fn.typ = InstanceType(functyp)
        if not wraps_builtin:
            # the function type must be visible inside the body for recursion
            self.set_variable_type(node.name, typed_fn.typ)
            typed_fn.body = self.visit_sequence(node.body)
            # visiting the body may have changed the bound variables' types;
            # node was modified in place, so simply type it again
            if current_bound_vars() != typed_fn.typ.typ.bound_vars:
                self.exit_scope()
                return self.visit_FunctionDef(node)
            # the returned values must match the annotated return type
            ReturnExtractor(functyp.rettyp).check_fulfills(typed_fn)
        # (the body of functions wrapping a builtin is ignored entirely)

        self.exit_scope()
        # the function type is needed outside for usage
        self.set_variable_type(node.name, typed_fn.typ)
        self.FUNCTION_ARGUMENT_REGISTRY[node.name] = node.args.args
        return typed_fn
1046

1047
    def visit_Module(self, node: Module) -> TypedModule:
        """Type a whole module inside a fresh top-level scope."""
        typed_module = copy(node)
        self.enter_scope()
        typed_module.body = self.visit_sequence(node.body)
        self.exit_scope()
        return typed_module
1053

1054
    def visit_Expr(self, node: Expr) -> TypedExpr:
        """Type an expression statement by typing the wrapped expression."""
        typed_stmt = copy(node)
        typed_stmt.value = self.visit(node.value)
        return typed_stmt
1058

1059
    def visit_BinOp(self, node: BinOp) -> Union[TypedBinOp, TypedCall]:
        """Type a binary operation.

        A replacement from ``dunder_override`` takes precedence. Otherwise the
        result type comes from the left operand's ``binop_type`` for the
        operator and the right operand's type.
        """
        overridden = self.dunder_override(node)
        if overridden is not None:
            return overridden
        typed_binop = copy(node)
        typed_binop.left = self.visit(node.left)
        typed_binop.right = self.visit(node.right)
        left_typ = typed_binop.left.typ
        op_typ: FunctionType = left_typ.binop_type(
            typed_binop.op, typed_binop.right.typ
        )
        typed_binop.typ = op_typ.rettyp
        return typed_binop
1070

1071
    def visit_BoolOp(self, node: BoolOp) -> TypedBoolOp:
        """Type a short-circuiting ``and`` / ``or`` expression.

        Each operand is evaluated only if the previous ones did not decide the
        result. So after each operand, its typechecks are applied for the
        following operands: the checks themselves for ``and``, their inverses
        for ``or``. All operands must be ``bool``.
        """
        typed_boolop = copy(node)
        if isinstance(node.op, And):
            # later operands run only if this one was true
            use_inverse = False
        elif isinstance(node.op, Or):
            # later operands run only if this one was false
            use_inverse = True
        else:
            raise NotImplementedError(f"Boolean operator {node.op} not supported")

        typed_values = []
        narrowed_so_far = {}
        for operand in node.values:
            typed_operand = self.visit(operand)
            typed_values.append(typed_operand)
            checks, inverse_checks = TypeCheckVisitor(
                self.allow_isinstance_anything
            ).visit(typed_operand)
            narrowed = self.implement_typechecks(
                inverse_checks if use_inverse else checks
            )
            self.wrapped.extend(narrowed.keys())
            narrowed_so_far.update(narrowed)
        # clean up the wrapped variables once all operands are processed
        for var in narrowed_so_far.keys():
            if var in self.wrapped:
                self.wrapped.remove(var)
        self.implement_typechecks(narrowed_so_far)
        typed_boolop.values = typed_values

        typed_boolop.typ = BoolInstanceType
        assert all(
            BoolInstanceType >= v.typ for v in typed_boolop.values
        ), f"All values compared must be bools, found {', '.join(e.typ.python_type() for e in typed_boolop.values)}"
        return typed_boolop
1117

1118
    def visit_UnaryOp(self, node: UnaryOp) -> TypedUnaryOp:
        """Type a unary operation.

        A replacement from ``dunder_override`` takes precedence. Otherwise the
        result type comes from ``unop_type`` of the operand's underlying type.
        """
        overridden = self.dunder_override(node)
        if overridden is not None:
            return overridden
        typed_unop = copy(node)
        typed_unop.operand = self.visit(node.operand)
        operand_class_typ = typed_unop.operand.typ.typ
        typed_unop.typ = operand_class_typ.unop_type(node.op).rettyp
        return typed_unop
1126

1127
    def visit_Subscript(self, node: Subscript) -> TypedSubscript:
        """Type a subscript expression ``value[index]`` or ``value[lower:upper]``.

        Supported forms:

        * ``Union[...]``, ``Dict[...]`` and ``List[...]``, resolved to a type via
          ``type_from_annotation``;
        * tuples and pairs, indexed by an integer literal (optionally negated);
        * lists and bytes, indexed by an integer or sliced with integer bounds
          (a missing lower bound becomes ``0``, a missing upper bound
          ``len(value)``);
        * dicts, indexed (not sliced) by a key compatible with the key type.

        Raises:
            TypeInferenceError: if the subscript cannot be typed.
            AssertionError: on out-of-range literal indices or ill-typed indices.
        """
        ts = copy(node)
        # special case: Subscript of Union / Dict / List and atomic types
        if isinstance(ts.value, Name) and ts.value.orig_id in [
            "Union",
            "Dict",
            "List",
        ]:
            ts.value = ts.typ = self.type_from_annotation(ts)
            return ts

        def fold_negative_int_literal():
            # `x[-1]` is parsed as UnaryOp(USub, Constant(1)); fold it into
            # Constant(-1). Other operands (e.g. `x[-i]`) are left as they are
            # and rejected below as not inferable, instead of crashing on `.value`.
            if (
                isinstance(ts.slice, UnaryOp)
                and isinstance(ts.slice.op, USub)
                and isinstance(ts.slice.operand, Constant)
                and isinstance(ts.slice.operand.value, int)
            ):
                ts.slice = self.visit(Constant(-ts.slice.operand.value))

        ts.value = self.visit(node.value)
        assert isinstance(ts.value.typ, InstanceType), "Can only subscript instances"
        if isinstance(ts.value.typ.typ, TupleType):
            assert (
                ts.value.typ.typ.typs
            ), "Accessing elements from the empty tuple is not allowed"
            fold_negative_int_literal()
            if isinstance(ts.slice, Constant) and isinstance(ts.slice.value, int):
                num_elts = len(ts.value.typ.typ.typs)
                # Valid indices are [-len, len). The lower bound must be checked
                # too, otherwise e.g. -5 on a pair-sized tuple hits a bare
                # IndexError in the lookup below.
                assert (
                    -num_elts <= ts.slice.value < num_elts
                ), f"Subscript index out of bounds for tuple. Accessing index {ts.slice.value} in tuple with {len(ts.value.typ.typ.typs)} elements ({ts.value.typ.python_type()})"
                ts.typ = ts.value.typ.typ.typs[ts.slice.value]
            else:
                raise TypeInferenceError(
                    f"Could not infer type of subscript of typ {ts.value.typ.python_type()}"
                )
        elif isinstance(ts.value.typ.typ, PairType):
            fold_negative_int_literal()
            if isinstance(ts.slice, Constant) and isinstance(ts.slice.value, int):
                assert (
                    -3 < ts.slice.value < 2
                ), f"Can only access -2, -1, 0 or 1 index in pairs, found {ts.slice.value}"
                # 0 / -2 select the left element, 1 / -1 the right element
                ts.typ = (
                    ts.value.typ.typ.l_typ
                    if ts.slice.value % 2 == 0
                    else ts.value.typ.typ.r_typ
                )
            else:
                raise TypeInferenceError(
                    f"Could not infer type of subscript of typ {ts.value.typ.python_type()}"
                )
        elif isinstance(ts.value.typ.typ, ListType):
            if not isinstance(ts.slice, Slice):
                ts.typ = ts.value.typ.typ.typ
                ts.slice = self.visit(node.slice)
                assert (
                    ts.slice.typ == IntegerInstanceType
                ), f"List indices must be integers, found {ts.slice.typ.python_type()} for list {ts.value.typ.python_type()}"
            else:
                # slicing a list yields a list of the same type
                ts.typ = ts.value.typ
                if ts.slice.lower is None:
                    ts.slice.lower = Constant(0)
                ts.slice.lower = self.visit(node.slice.lower)
                assert (
                    ts.slice.lower.typ == IntegerInstanceType
                ), f"Lower slice indices for lists must be integers, found {ts.slice.lower.typ.python_type()} for list {ts.value.typ.python_type()}"
                if ts.slice.upper is None:
                    ts.slice.upper = Call(
                        func=Name(id="len", ctx=Load()), args=[ts.value], keywords=[]
                    )
                    ts.slice.upper.func.orig_id = "len"
                ts.slice.upper = self.visit(node.slice.upper)
                assert (
                    ts.slice.upper.typ == IntegerInstanceType
                ), f"Upper slice indices for lists must be integers, found {ts.slice.upper.typ.python_type()} for list {ts.value.typ.python_type()}"
        elif isinstance(ts.value.typ.typ, ByteStringType):
            if not isinstance(ts.slice, Slice):
                # indexing bytes yields a single byte as an integer
                ts.typ = IntegerInstanceType
                ts.slice = self.visit(node.slice)
                assert (
                    ts.slice.typ == IntegerInstanceType
                ), f"Bytes indices must be integers, found {ts.slice.typ.python_type()}."
            else:
                ts.typ = ByteStringInstanceType
                if ts.slice.lower is None:
                    ts.slice.lower = Constant(0)
                ts.slice.lower = self.visit(node.slice.lower)
                assert (
                    ts.slice.lower.typ == IntegerInstanceType
                ), f"Lower slice indices for bytes must be integers, found {ts.slice.lower.typ.python_type()}"
                if ts.slice.upper is None:
                    ts.slice.upper = Call(
                        func=Name(id="len", ctx=Load()), args=[ts.value], keywords=[]
                    )
                    ts.slice.upper.func.orig_id = "len"
                ts.slice.upper = self.visit(node.slice.upper)
                assert (
                    ts.slice.upper.typ == IntegerInstanceType
                ), f"Upper slice indices for bytes must be integers, found {ts.slice.upper.typ.python_type()}"
        elif isinstance(ts.value.typ.typ, DictType):
            if not isinstance(ts.slice, Slice):
                ts.slice = self.visit(node.slice)
                assert (
                    ts.value.typ.typ.key_typ >= ts.slice.typ
                ), f"Dict subscript must have dict key type {ts.value.typ.typ.key_typ.python_type()} but has type {ts.slice.typ.python_type()}"
                ts.typ = ts.value.typ.typ.value_typ
            else:
                raise TypeInferenceError(
                    f"Could not infer type of subscript of dict with a slice."
                )
        else:
            raise TypeInferenceError(
                f"Could not infer type of subscript of typ {ts.value.typ.python_type()}"
            )
        return ts
1235

1236
    def visit_Call(self, node: Call) -> TypedCall:
3✔
1237
        tc = copy(node)
3✔
1238
        if node.keywords:
3✔
1239
            assert (
3✔
1240
                node.func.id in self.FUNCTION_ARGUMENT_REGISTRY
1241
            ), "Keyword arguments can only be used with user defined functions"
1242
            keywords = copy(node.keywords)
3✔
1243
            reg_args = self.FUNCTION_ARGUMENT_REGISTRY[node.func.id]
3✔
1244
            args = []
3✔
1245
            for i, a in enumerate(reg_args):
3✔
1246
                if len(node.args) > i:
3✔
1247
                    args.append(self.visit(node.args[i]))
3✔
1248
                else:
1249
                    candidates = [
3✔
1250
                        (idx, keyword)
1251
                        for idx, keyword in enumerate(keywords)
1252
                        if keyword.arg == a.orig_arg
1253
                    ]
1254
                    assert (
3✔
1255
                        len(candidates) == 1
1256
                    ), f"There should be one keyword or positional argument for the arg {a.orig_arg} but found {len(candidates)}"
1257
                    args.append(self.visit(candidates[0][1].value))
3✔
1258
                    keywords.pop(candidates[0][0])
3✔
1259
            assert (
3✔
1260
                len(keywords) == 0
1261
            ), f"Could not match the keywords {[keyword.arg for keyword in keywords]} to any argument"
1262
            tc.args = args
3✔
1263
            tc.keywords = []
3✔
1264
        else:
1265
            tc.args = [self.visit(a) for a in node.args]
3✔
1266

1267
        # might be isinstance
1268
        # Subscripts are not allowed in isinstance calls
1269
        if (
3✔
1270
            isinstance(tc.func, Name)
1271
            and tc.func.orig_id == "isinstance"
1272
            and isinstance(tc.args[1], Subscript)
1273
        ):
1274
            raise TypeError(
3✔
1275
                "Subscripted generics cannot be used with class and instance checks"
1276
            )
1277

1278
        # Need to handle the presence of PlutusData classes
1279
        if (
3✔
1280
            isinstance(tc.func, Name)
1281
            and tc.func.orig_id == "isinstance"
1282
            and not isinstance(
1283
                tc.args[1].typ, (ByteStringType, IntegerType, ListType, DictType)
1284
            )
1285
        ):
1286
            if (
3✔
1287
                isinstance(tc.args[0].typ, InstanceType)
1288
                and isinstance(tc.args[0].typ.typ, AnyType)
1289
                and not self.allow_isinstance_anything
1290
            ):
1291
                raise AssertionError(
3✔
1292
                    "OpShin does not permit checking the instance of raw Anything/Datum objects as this only checks the equality of the constructor id and nothing more. "
1293
                    "If you are certain of what you are doing, please use the flag '--allow-isinstance-anything'."
1294
                )
1295
            tc.typechecks = TypeCheckVisitor(self.allow_isinstance_anything).visit(tc)
3✔
1296

1297
        # Check for expanded Union funcs
1298
        if isinstance(node.func, ast.Name):
3✔
1299
            expanded_unions = {
3✔
1300
                k: v
1301
                for scope in self.scopes
1302
                for k, v in scope.items()
1303
                if k.startswith(f"{node.func.orig_id}+")
1304
            }
1305
            for k, v in expanded_unions.items():
3✔
1306
                argtyps = v.typ.argtyps
3✔
1307
                if len(tc.args) != len(argtyps):
3!
1308
                    continue
×
1309
                for a, ap in zip(tc.args, argtyps):
3✔
1310
                    if ap != a.typ:
3✔
1311
                        break
3✔
1312
                else:
1313
                    node.func = ast.Name(
3✔
1314
                        id=k, orig_id=f"unknown orig_id for {k}", ctx=ast.Load()
1315
                    )
1316
                    break
3✔
1317

1318
        subbed_method = False
3✔
1319
        if isinstance(tc.func, Attribute):
3✔
1320
            # might be a method, test whether the variable is a record and if the method exists
1321
            accessed_var = self.visit(tc.func.value)
3✔
1322
            if (
3✔
1323
                isinstance(accessed_var.typ, InstanceType)
1324
                and isinstance(accessed_var.typ.typ, RecordType)
1325
                and tc.func.attr != "to_cbor"
1326
            ):
1327
                class_name = accessed_var.typ.typ.record.name
3✔
1328
                method_name = f"{class_name}_+_{tc.func.attr}"
3✔
1329
                # If method_name found then use this.
1330
                if self.is_defined_in_current_scope(method_name):
3!
1331
                    n = ast.Name(id=method_name, ctx=ast.Load())
3✔
1332
                    n.orig_id = node.func.attr
3✔
1333
                    tc.func = self.visit(n)
3✔
1334
                    tc.func.orig_id = node.func.attr
3✔
1335
                    tc.args.insert(0, accessed_var)
3✔
1336
                    subbed_method = True
3✔
1337

1338
        if not subbed_method:
3✔
1339
            tc.func = self.visit(node.func)
3✔
1340

1341
        # might be a class
1342
        if isinstance(tc.func.typ, ClassType):
3✔
1343
            tc.func.typ = tc.func.typ.constr_type()
3✔
1344
        # type might only turn out after the initialization (note the constr could be polymorphic)
1345
        if isinstance(tc.func.typ, InstanceType) and isinstance(
3✔
1346
            tc.func.typ.typ, PolymorphicFunctionType
1347
        ):
1348
            tc.func.typ = PolymorphicFunctionInstanceType(
3✔
1349
                tc.func.typ.typ.polymorphic_function.type_from_args(
1350
                    [a.typ for a in tc.args]
1351
                ),
1352
                tc.func.typ.typ.polymorphic_function,
1353
            )
1354
        if isinstance(tc.func.typ, InstanceType) and isinstance(
3!
1355
            tc.func.typ.typ, FunctionType
1356
        ):
1357
            functyp = tc.func.typ.typ
3✔
1358
            assert len(tc.args) == len(
3✔
1359
                functyp.argtyps
1360
            ), f"Signature of function does not match number of arguments. Expected {len(functyp.argtyps)} arguments but got {len(tc.args)} arguments."
1361
            # all arguments need to be subtypes of the parameter type
1362
            for i, (a, ap) in enumerate(zip(tc.args, functyp.argtyps)):
3✔
1363
                assert (
3✔
1364
                    ap >= a.typ
1365
                ), f"Signature of function does not match arguments in argument {i}. Expected this type: {ap.python_type()} but got {a.typ.python_type()}."
1366
            tc.typ = functyp.rettyp
3✔
1367
            return tc
3✔
1368
        raise TypeInferenceError("Could not infer type of call")
×
1369

1370
    def visit_Pass(self, node: Pass) -> TypedPass:
        """Type a ``pass`` statement.

        ``pass`` has no value, so a shallow copy of the node is returned
        without attaching any type information.
        """
        return copy(node)
1373

1374
    def visit_Return(self, node: Return) -> TypedReturn:
        """Type a ``return`` statement.

        The returned expression is typed and the statement takes over that
        expression's type. Conformance with the function's annotated return
        type is not checked in this method.
        """
        # NOTE(review): a bare ``return`` has ``node.value is None``; presumably
        # an earlier rewrite turns it into ``return None`` — confirm.
        typed_return = copy(node)
        typed_value = self.visit(node.value)
        typed_return.value = typed_value
        typed_return.typ = typed_value.typ
        return typed_return
1379

1380
    def visit_Attribute(self, node: Attribute) -> TypedAttribute:
        """Type an attribute access ``value.attr``.

        The owner expression is typed first. The owner's type then determines
        the type of the accessed attribute via ``attribute_type``.
        """
        typed_attr = copy(node)
        typed_owner = self.visit(node.value)
        typed_attr.value = typed_owner
        # the owner type knows which fields/attributes exist and their types
        typed_attr.typ = typed_owner.typ.attribute_type(node.attr)
        return typed_attr
1387

1388
    def visit_Assert(self, node: Assert) -> TypedAssert:
        """Type an ``assert`` statement.

        The condition must be a boolean. The optional message, if present,
        must be a string.
        """
        typed_assert = copy(node)
        typed_assert.test = self.visit(node.test)
        assert (
            typed_assert.test.typ == BoolInstanceType
        ), "Assertions must result in a boolean type"
        if node.msg is None:
            # no message to check
            return typed_assert
        typed_assert.msg = self.visit(node.msg)
        assert (
            typed_assert.msg.typ == StringInstanceType
        ), "Assertions must has a string message (or None)"
        return typed_assert
1400

1401
    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> RawPlutoExpr:
        """Pass a raw Pluto expression through unchanged.

        Its type cannot be inferred, so the node must already carry a ``typ``.
        """
        assert (
            node.typ is not None
        ), "Raw Pluto Expression is missing type annotation"
        return node
1404

1405
    def visit_IfExp(self, node: IfExp) -> TypedIfExp:
        """Type a conditional expression ``body if test else orelse``.

        The test must be a boolean. Type checks that ``TypeCheckVisitor``
        extracts from the test narrow variable types while ``body`` is typed.
        The inverse checks narrow them while ``orelse`` is typed. Previous
        types are restored after each branch.

        The result type is whichever branch type is a supertype of the other.
        Otherwise it is the union of both (instance) branch types.

        :raises TypeInferenceError: if the branch types cannot be combined.
        """
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        assert node_cp.test.typ == BoolInstanceType, "Comparison must have type boolean"
        typchecks, inv_typchecks = TypeCheckVisitor(
            self.allow_isinstance_anything
        ).visit(node_cp.test)

        # narrow types for the `body` branch
        prevtyps = self.implement_typechecks(typchecks)
        self.wrapped.extend(prevtyps.keys())
        node_cp.body = self.visit(node.body)
        self.wrapped = [x for x in self.wrapped if x not in prevtyps.keys()]
        # restore the pre-narrowing types before typing the other branch
        self.implement_typechecks(prevtyps)

        # narrow types (with the inverse checks) for the `orelse` branch
        prevtyps = self.implement_typechecks(inv_typchecks)
        self.wrapped.extend(prevtyps.keys())
        node_cp.orelse = self.visit(node.orelse)
        self.wrapped = [x for x in self.wrapped if x not in prevtyps.keys()]
        self.implement_typechecks(prevtyps)

        body_typ = node_cp.body.typ
        orelse_typ = node_cp.orelse.typ
        if body_typ >= orelse_typ:
            node_cp.typ = body_typ
        elif orelse_typ >= body_typ:
            node_cp.typ = orelse_typ
        else:
            error_msg = "Branches of if-expression must return compatible types."
            # Explicit check instead of `assert`, so the intended error is also
            # raised when Python runs with -O (which strips asserts).
            if not (
                isinstance(body_typ, InstanceType)
                and isinstance(orelse_typ, InstanceType)
            ):
                raise TypeInferenceError(error_msg)
            try:
                node_cp.typ = InstanceType(union_types(body_typ.typ, orelse_typ.typ))
            except AssertionError as e:
                # any AssertionError while building the union (as caught before)
                # is reported as incompatible branch types
                raise TypeInferenceError(error_msg) from e
        return node_cp
1440

1441
    def visit_comprehension(self, g: comprehension) -> typedcomprehension:
        """Type one ``for target in iter if ...`` clause of a comprehension.

        Only list instances may be iterated, and the target must be a plain
        variable name. The target is bound to the list's element type before
        the target and the ``if`` conditions are typed.

        :raises NotImplementedError: for tuple or other non-name targets, and
            for iteration over anything but lists.
        """
        new_g = copy(g)
        if isinstance(g.target, Tuple):
            raise NotImplementedError(
                "Type deconstruction in comprehensions is not supported yet"
            )
        if not isinstance(g.target, Name):
            # e.g. `for x.attr in ...`; would otherwise fail with an opaque
            # AttributeError on `g.target.id` below
            raise NotImplementedError(
                "Only plain variable names are supported as comprehension targets"
            )
        new_g.iter = self.visit(g.iter)
        itertyp = new_g.iter.typ
        assert isinstance(
            itertyp, InstanceType
        ), "Can only iterate over instances, not classes"
        if isinstance(itertyp.typ, ListType):
            vartyp = itertyp.typ.typ
        else:
            raise NotImplementedError(
                "Iterating over non-list objects is not (yet) supported"
            )
        # bind the loop variable to the element type before typing the target
        # and conditions that may reference it
        self.set_variable_type(g.target.id, vartyp)
        new_g.target = self.visit(g.target)
        new_g.ifs = [self.visit(i) for i in g.ifs]
        return new_g
1462

1463
    def visit_ListComp(self, node: ListComp) -> TypedListComp:
        """Type a list comprehension ``[elt for ... if ...]``.

        Generators are typed inside a fresh scope. Type narrowing derived from
        every ``if`` condition of every generator is applied while the element
        expression is typed. The result is a list of the element's type.
        """
        typed_listcomp = copy(node)
        # the comprehension gets its own variable scope
        self.enter_scope()
        # generators first: they bind the loop variables used by `elt`
        typed_listcomp.generators = [self.visit(gen) for gen in node.generators]

        conditions = [cond for gen in typed_listcomp.generators for cond in gen.ifs]
        narrowing = {}
        for cond in conditions:
            positive_checks, _ = TypeCheckVisitor(
                self.allow_isinstance_anything
            ).visit(cond)
            narrowing.update(positive_checks)

        prev_types = self.implement_typechecks(narrowing)
        self.wrapped.extend(prev_types.keys())
        typed_listcomp.elt = self.visit(node.elt)
        self.wrapped = [name for name in self.wrapped if name not in prev_types.keys()]

        # NOTE(review): narrowed types are not explicitly restored here;
        # presumably exit_scope discards them since they were set inside the
        # comprehension scope — confirm.
        self.exit_scope()
        typed_listcomp.typ = InstanceType(ListType(typed_listcomp.elt.typ))
        return typed_listcomp
1492

1493
    def visit_DictComp(self, node: DictComp) -> TypedDictComp:
        """Type a dict comprehension ``{key: value for ... if ...}``.

        Generators are typed inside a fresh scope. Type narrowing derived from
        every ``if`` condition of every generator is applied while the key and
        value expressions are typed. The result is a dict of key/value types.
        """
        typed_dictcomp = copy(node)
        # the comprehension gets its own variable scope
        self.enter_scope()
        # generators first: they bind the loop variables used by key/value
        typed_dictcomp.generators = [self.visit(gen) for gen in node.generators]

        conditions = [cond for gen in typed_dictcomp.generators for cond in gen.ifs]
        narrowing = {}
        for cond in conditions:
            positive_checks, _ = TypeCheckVisitor(
                self.allow_isinstance_anything
            ).visit(cond)
            narrowing.update(positive_checks)

        prev_types = self.implement_typechecks(narrowing)
        self.wrapped.extend(prev_types.keys())
        typed_dictcomp.key = self.visit(node.key)
        typed_dictcomp.value = self.visit(node.value)
        self.wrapped = [name for name in self.wrapped if name not in prev_types.keys()]

        # NOTE(review): narrowed types are not explicitly restored here;
        # presumably exit_scope discards them — confirm.
        self.exit_scope()
        key_typ = typed_dictcomp.key.typ
        value_typ = typed_dictcomp.value.typ
        typed_dictcomp.typ = InstanceType(DictType(key_typ, value_typ))
        return typed_dictcomp
1525

1526
    def visit_FormattedValue(self, node: FormattedValue) -> TypedFormattedValue:
        """Type a ``{value}`` placeholder of an f-string.

        Only "no conversion" and ``!s`` are accepted, with no format spec. The
        placeholder always yields a string.
        """
        typed_node = copy(node)
        typed_node.value = self.visit(node.value)
        # ast stores the conversion flag as ord(char); -1 means "none given"
        permitted_conversions = (-1, ord("s"))
        assert (
            node.conversion in permitted_conversions
        ), "Only string formatting is allowed but got repr or ascii formatting."
        assert (
            node.format_spec is None
        ), "No format specification is allowed but got formatting specifiers (i.e. decimals)."
        typed_node.typ = StringInstanceType
        return typed_node
1538

1539
    def visit_JoinedStr(self, node: JoinedStr) -> TypedJoinedStr:
        """Type an f-string: type each part; the whole is always a string."""
        typed_node = copy(node)
        typed_node.values = list(map(self.visit, node.values))
        typed_node.typ = StringInstanceType
        return typed_node
1544

1545
    def visit_ImportFrom(self, node: ImportFrom) -> ImportFrom:
        """Accept only ``from opshin.bridge import ...`` and return it unchanged."""
        allowed_module = "opshin.bridge"
        assert node.module == allowed_module, "Trying to import from invalid location"
        return node
1548

1549
    def generic_visit(self, node: AST) -> TypedAST:
        """Reject every node kind that has no dedicated ``visit_*`` handler."""
        node_kind = node.__class__
        raise NotImplementedError(
            f"Cannot infer type of non-implemented node {node_kind}"
        )
1553

1554

1555
class RecordReader(NodeVisitor):
    """
    Extract a ``Record`` (name, constructor id, typed fields) from a class
    definition.

    The class body may contain only the following:

    - annotated fields without default values;
    - an optional integer ``CONSTR_ID`` (annotated or not);
    - ``pass``;
    - constant expressions such as docstrings.

    Anything else, including nested class definitions, raises
    ``NotImplementedError``.
    """

    name: str
    orig_name: str
    constructor: typing.Optional[int]
    attributes: typing.List[typing.Tuple[str, Type]]
    _type_inferencer: AggressiveTypeInferencer

    def __init__(self, type_inferencer: AggressiveTypeInferencer):
        self.constructor = None
        self.attributes = []
        self._type_inferencer = type_inferencer
        # True while the record's own body is visited; used to reject nested
        # class definitions instead of letting them overwrite the record name
        self._in_class_body = False

    def extract(self, c: ClassDef) -> Record:
        """Read class ``c`` and build its ``Record``.

        If no explicit ``CONSTR_ID`` was given, a deterministic constructor id
        is derived. It is the sha256 hash of the record's pluthon type string
        (computed without the constructor), reduced modulo 2**32.
        """
        self.visit(c)
        if self.constructor is None:
            det_string = RecordType(
                Record(self.name, self.orig_name, 0, frozenlist(self.attributes))
            ).pluthon_type(skip_constructor=True)
            det_hash = sha256(str(det_string).encode("utf8")).hexdigest()
            self.constructor = int(det_hash, 16) % 2**32
        return Record(
            self.name, self.orig_name, self.constructor, frozenlist(self.attributes)
        )

    def _set_constructor(self, value: typing.Optional[AST]) -> None:
        """Validate and store the value assigned to ``CONSTR_ID``."""
        assert isinstance(
            value, Constant
        ), "CONSTR_ID must be assigned a constant integer"
        # bool is a subclass of int, but True/False are not meaningful constructor ids
        assert isinstance(value.value, int) and not isinstance(
            value.value, bool
        ), "CONSTR_ID must be assigned an integer"
        self.constructor = value.value

    def visit_AnnAssign(self, node: AnnAssign) -> None:
        """Record an annotated field, or an annotated ``CONSTR_ID``."""
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        typ = self._type_inferencer.type_from_annotation(node.annotation)
        if node.target.id == "CONSTR_ID":
            assert typ == IntegerType(), "CONSTR_ID must be assigned an integer"
            self._set_constructor(node.value)
            return
        assert (
            node.value is None
        ), f"PlutusData attribute {node.target.id} may not have a default value"
        assert not isinstance(
            typ, TupleType
        ), "Records can currently not hold tuples"
        self.attributes.append((node.target.id, InstanceType(typ)))

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Read the record's name and visit each statement of its body."""
        if getattr(self, "_in_class_body", False):
            # a nested class would otherwise overwrite name/orig_name and merge
            # its fields into the enclosing record
            raise NotImplementedError(
                f"Can not compile {ast.dump(node)} inside of a class"
            )
        self.name = node.name
        self.orig_name = node.orig_name
        self._in_class_body = True
        try:
            for s in node.body:
                self.visit(s)
        finally:
            self._in_class_body = False

    def visit_Pass(self, node: Pass) -> None:
        pass

    def visit_Assign(self, node: Assign) -> None:
        """Only ``CONSTR_ID = <int>`` may appear without a type annotation."""
        assert len(node.targets) == 1, "Record elements must be assigned one by one"
        target = node.targets[0]
        assert isinstance(target, Name), "Record elements must have named attributes"
        assert (
            target.id == "CONSTR_ID"
        ), "Type annotations may only be omitted for CONSTR_ID"
        self._set_constructor(node.value)

    def visit_Expr(self, node: Expr) -> None:
        """Allow bare constants (e.g. docstrings); reject any other expression."""
        assert isinstance(
            node.value, Constant
        ), "Only comments are allowed inside classes"
        return None

    def generic_visit(self, node: AST) -> None:
        raise NotImplementedError(f"Can not compile {ast.dump(node)} inside of a class")
1639

1640

1641
def map_to_orig_name(name: str) -> str:
    """Strip one trailing ``_<digits>`` suffix from ``name``.

    >>> map_to_orig_name("foo_12")
    'foo'
    """
    trailing_counter = r"_\d+$"
    return re.sub(trailing_counter, "", name)
1643

1644

1645
class ReturnExtractor(TypedNodeVisitor):
3✔
1646
    """
1647
    Utility to check that all paths end in Return statements with the proper type
1648

1649
    Returns whether there is no remaining path
1650
    """
1651

1652
    def __init__(self, func_rettyp: Type):
        """Store the annotated return type that every ``return`` must satisfy."""
        self.func_rettyp = func_rettyp
1654

1655
    def visit_sequence(self, nodes: typing.List[TypedAST]) -> bool:
        """Visit statements in order until one covers all remaining paths.

        Statements after the first fully-returning one are unreachable and are
        not visited. Returns the result of the last visited statement, or
        ``False`` for an empty sequence.
        """
        covered = False
        for stmt in nodes:
            covered = self.visit(stmt)
            if covered:
                return covered
        return covered
1662

1663
    def visit_If(self, node: If) -> bool:
        """An ``if`` covers all paths only if both of its branches do.

        Both branches are always visited, so every ``return`` inside them is
        checked against the annotated return type. The previous
        ``a and b`` form skipped the ``else`` branch whenever the ``if`` body
        did not return on all paths.
        """
        body_covered = self.visit_sequence(node.body)
        orelse_covered = self.visit_sequence(node.orelse)
        return body_covered and orelse_covered
1665

1666
    def visit_For(self, node: For) -> bool:
        """A ``for`` loop can only guarantee a return via its ``else`` branch.

        The body may run zero times. Returns inside it are still type-checked,
        but they cannot count as covering all paths.
        """
        self.visit_sequence(node.body)  # result deliberately ignored
        # NOTE(review): Python skips the `else` branch when the loop exits via
        # `break`; this assumes `break` is not supported/used here — confirm.
        return self.visit_sequence(node.orelse)
1672

1673
    def visit_While(self, node: While) -> bool:
        """A ``while`` loop can only guarantee a return via its ``else`` branch.

        The body may run zero times. Returns inside it are still type-checked,
        but they cannot count as covering all paths.
        """
        # The body simply has to be checked but has no influence on whether all
        # paths are covered, because it might never be visited.
        self.visit_sequence(node.body)
        # NOTE(review): Python skips the `else` branch when the loop exits via
        # `break`; this assumes `break` is not supported/used here — confirm.
        return self.visit_sequence(node.orelse)
1679

1680
    def visit_Return(self, node: Return) -> bool:
        """Check a ``return`` against the annotated type; it ends every path."""
        expected = self.func_rettyp
        assert (
            expected >= node.typ
        ), f"Function annotated return type does not match actual return type, expected {expected.python_type()} but got {node.typ.python_type()}"
        return True
1685

1686
    def check_fulfills(self, node: FunctionDef):
3✔
1687
        all_paths_covered = self.visit_sequence(node.body)
3✔
1688
        if not all_paths_covered:
3✔
1689
            assert (
3✔
1690
                self.func_rettyp >= NoneInstanceType
1691
            ), f"Function '{node.name}' has no return statement but is supposed to return not-None value"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc