• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 806

pending completion
806

push

travis-ci-com

nielstron
Update docs

3519 of 3781 relevant lines covered (93.07%)

3.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.5
/opshin/compiler.py
1
import logging
4✔
2
from logging import getLogger
4✔
3
from ast import fix_missing_locations
4✔
4

5
from .optimize.optimize_remove_comments import OptimizeRemoveDeadconstants
4✔
6
from .rewrite.rewrite_augassign import RewriteAugAssign
4✔
7
from .rewrite.rewrite_forbidden_overwrites import RewriteForbiddenOverwrites
4✔
8
from .rewrite.rewrite_import import RewriteImport
4✔
9
from .rewrite.rewrite_import_dataclasses import RewriteImportDataclasses
4✔
10
from .rewrite.rewrite_import_hashlib import RewriteImportHashlib
4✔
11
from .rewrite.rewrite_import_plutusdata import RewriteImportPlutusData
4✔
12
from .rewrite.rewrite_import_typing import RewriteImportTyping
4✔
13
from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins
4✔
14
from .rewrite.rewrite_inject_builtin_constr import RewriteInjectBuiltinsConstr
4✔
15
from .rewrite.rewrite_orig_name import RewriteOrigName
4✔
16
from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff
4✔
17
from .rewrite.rewrite_scoping import RewriteScoping
4✔
18
from .rewrite.rewrite_subscript38 import RewriteSubscript38
4✔
19
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
4✔
20
from .rewrite.rewrite_zero_ary import RewriteZeroAry
4✔
21
from .optimize.optimize_remove_pass import OptimizeRemovePass
4✔
22
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars
4✔
23
from .optimize.optimize_varlen import OptimizeVarlen
4✔
24
from .type_inference import *
4✔
25
from .util import CompilingNodeTransformer, PowImpl
4✔
26
from .typed_ast import transform_ext_params_map, transform_output_map, RawPlutoExpr
4✔
27

28

29
# Module-level logger used for compiler diagnostics (e.g. validator signature warnings).
_LOGGER = getLogger(__name__)

# Name of the lambda parameter through which every compiled statement and
# expression receives the current state map.
STATEMONAD = "s"


# operator class -> left operand type -> right operand type -> implementation
BinOpMap = {
    # "+" is overloaded for integers, bytestrings and strings; in every case
    # the right operand must have the same type as the left one.
    Add: {
        IntegerInstanceType: {IntegerInstanceType: plt.AddInteger},
        ByteStringInstanceType: {ByteStringInstanceType: plt.AppendByteString},
        StringInstanceType: {StringInstanceType: plt.AppendString},
    },
    # The remaining arithmetic operators are only defined for int (op) int.
    **{
        operator_cls: {IntegerInstanceType: {IntegerInstanceType: implementation}}
        for operator_cls, implementation in (
            (Sub, plt.SubtractInteger),
            (Mult, plt.MultiplyInteger),
            (FloorDiv, plt.DivideInteger),
            (Mod, plt.ModInteger),
            (Pow, PowImpl),
        )
    },
}

# boolean operator class -> binary Pluto combinator
BoolOpMap = {And: plt.And, Or: plt.Or}

# unary operator class -> operand type -> implementation
UnaryOpMap = {
    Not: {BoolInstanceType: plt.Not},
    # unary minus is compiled as 0 - operand
    USub: {
        IntegerInstanceType: lambda operand: plt.SubtractInteger(
            plt.Integer(0), operand
        )
    },
}

# exact Python type of a literal -> constructor of the matching Pluto constant
ConstantMap = {
    str: plt.Text,
    bytes: plt.ByteString,
    int: plt.Integer,
    bool: plt.Bool,
    type(None): lambda _: plt.Unit(),
}
90

91

92
def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
    """
    Wrap a compiled validator so it can double function as minting and spending script.

    The returned lambda takes ``pass_through`` leading parameters (``v0``, ``v1``, ...)
    that are forwarded to ``x`` unchanged. These are followed by two more
    arguments, ``a0`` and ``a1``.
    If ``a1`` is constructor-tagged data with constructor id 0 (the script
    context), the validator is called as ``p(Unit, a0, a1)``. Otherwise it is
    called as ``p(a0, a1)`` and returned partially applied.
    """
    forwarded_names = [f"v{i}" for i in range(pass_through)]
    # constructor id 0 identifies the script context; for any other kind of
    # data the check yields False
    second_is_context = plt.DelayedChooseData(
        plt.Var("a1"),
        plt.EqualsInteger(plt.Constructor(plt.Var("a1")), plt.Integer(0)),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
    )
    # only two arguments were given: plug in Unit as the missing first argument
    call_with_unit = plt.Apply(plt.Var("p"), plt.Unit(), plt.Var("a0"), plt.Var("a1"))
    # three arguments are expected: bind the first two, the third follows
    call_partially = plt.Apply(plt.Var("p"), plt.Var("a0"), plt.Var("a1"))
    partially_bound_validator = plt.Apply(
        x, *[plt.Var(name) for name in forwarded_names]
    )
    return plt.Lambda(
        forwarded_names + ["a0", "a1"],
        plt.Let(
            [("p", partially_bound_validator)],
            plt.Ite(second_is_context, call_with_unit, call_partially),
        ),
    )
119

120

121
def extend_statemonad(
    names: typing.List[str],
    values: typing.List[plt.AST],
    old_statemonad: plt.FunctionalMap,
):
    """
    Return ``old_statemonad`` extended with ``names[i]`` bound to ``values[i]``.

    The values are passed as arguments to a lambda whose body performs the map
    extension. Under UPLC's eager evaluation, each value is therefore fully
    evaluated before it is stored, as in imperative languages.
    """
    assert len(names) == len(values), "Unequal amount of names and values passed in"
    param_names = [f"a{i}" for i in range(len(names))]
    extended_state = plt.FunctionalMapExtend(
        old_statemonad, names, [plt.Var(param) for param in param_names]
    )
    return plt.Apply(plt.Lambda(param_names, extended_state), *values)


# State map that the compiled module body is applied to (constructed without arguments).
INITIAL_STATE = plt.FunctionalMap()
141

142

143
class UPLCCompiler(CompilingNodeTransformer):
4✔
144
    """
145
    Expects a TypedAST and returns UPLC/Pluto like code
146
    """
147

148
    step = "Compiling python statements to UPLC"
4✔
149

150
    def __init__(self, force_three_params=False, validator_function_name="validator"):
        """
        :param force_three_params: when True, ``visit_Module`` wraps the validator
            so it can double function as minting and spending script, and raises
            if that wrapping is not possible
        :param validator_function_name: original name of the module-level function
            compiled as the script entry point
        """
        self.validator_function_name = validator_function_name
        self.force_three_params = force_three_params
153

154
    def visit_sequence(self, node_seq: typing.List[typedstmt]) -> plt.AST:
        """
        Compile a list of statements into one function over the state monad.

        Each compiled statement is applied to the state produced by the previous
        one. An empty list therefore yields a lambda that returns its state unchanged.
        """
        state = plt.Var(STATEMONAD)
        for compiled_stmt in map(self.visit, node_seq):
            state = plt.Apply(compiled_stmt, state)
        return plt.Lambda([STATEMONAD], state)
160

161
    def visit_BinOp(self, node: TypedBinOp) -> plt.AST:
        """
        Compile a binary operation via ``BinOpMap``.

        The implementation is looked up by operator class, then left operand
        type, then right operand type. Both operands are evaluated in the
        current state.

        :raises NotImplementedError: if the operator, or the combination of
            operand types, has no implementation
        """
        op_name = type(node.op).__name__
        opmap = BinOpMap.get(type(node.op))
        if opmap is None:
            raise NotImplementedError(f"Operation {op_name} is not implemented")
        opmap2 = opmap.get(node.left.typ)
        if opmap2 is None:
            raise NotImplementedError(
                f"Operation {op_name} is not implemented for left type {node.left.typ}"
            )
        op = opmap2.get(node.right.typ)
        # the right operand's type must be supported as well; check the looked-up
        # implementation itself so an unsupported combination raises here
        # instead of failing later with "NoneType is not callable"
        if op is None:
            raise NotImplementedError(
                f"Operation {op_name} is not implemented for left type {node.left.typ} and right type {node.right.typ}"
            )
        return plt.Lambda(
            [STATEMONAD],
            op(
                plt.Apply(self.visit(node.left), plt.Var(STATEMONAD)),
                plt.Apply(self.visit(node.right), plt.Var(STATEMONAD)),
            ),
        )
182

183
    def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST:
        """
        Compile ``a and b and ...`` / ``a or b or ...``.

        The result is a left-nested chain of the binary combinator from
        ``BoolOpMap``: ``op(op(a, b), c)``. Every operand is evaluated in the
        current state.

        :raises NotImplementedError: if the boolean operator has no implementation
        """
        op = BoolOpMap.get(type(node.op))
        assert len(node.values) >= 2, "Need to compare at least two values"
        if op is None:
            raise NotImplementedError(
                f"Operation {type(node.op).__name__} is not implemented"
            )
        compiled_values = [
            plt.Apply(self.visit(v), plt.Var(STATEMONAD)) for v in node.values
        ]
        ops = op(compiled_values[0], compiled_values[1])
        for compiled in compiled_values[2:]:
            ops = op(ops, compiled)
        return plt.Lambda(
            [STATEMONAD],
            ops,
        )
196

197
    def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST:
        """
        Compile a unary operator via ``UnaryOpMap``.

        The implementation is looked up by operator class and operand type, and
        the operand is evaluated in the current state.
        """
        op_cls = type(node.op)
        impls_by_type = UnaryOpMap.get(op_cls)
        assert impls_by_type is not None, f"Operator {op_cls} is not supported"
        impl = impls_by_type.get(node.operand.typ)
        assert (
            impl is not None
        ), f"Operator {op_cls} is not supported for type {node.operand.typ}"
        operand = plt.Apply(self.visit(node.operand), plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], impl(operand))
208

209
    def visit_Compare(self, node: TypedCompare) -> plt.AST:
        """
        Compile a single comparison ``left <op> right``.

        The comparison function comes from the left operand's type (``cmp``),
        given the operator and the right operand's type. Chained comparisons
        such as ``a < b < c`` are rejected.
        """
        assert len(node.ops) == 1, "Only single comparisons are supported"
        assert len(node.comparators) == 1, "Only single comparisons are supported"
        (cmpop,) = node.ops
        (right,) = node.comparators
        cmp_impl = node.left.typ.cmp(cmpop, right.typ)
        left_value = plt.Apply(self.visit(node.left), plt.Var(STATEMONAD))
        right_value = plt.Apply(self.visit(right), plt.Var(STATEMONAD))
        return plt.Lambda(
            [STATEMONAD],
            plt.Apply(cmp_impl, left_value, right_value),
        )
223

224
    def visit_Module(self, node: TypedModule) -> plt.AST:
        """
        Compile a module into a Plutus program.

        The entry point is the function whose original name is
        ``self.validator_function_name``.

        The program runs the whole module body on ``INITIAL_STATE``, looks up
        the validator in the resulting state and applies it to the script
        arguments, converted by ``transform_ext_params_map``. The result is
        converted by ``transform_output_map``. When ``force_three_params`` is
        set, the program is additionally wrapped with
        ``wrap_validator_double_function``.

        :raises AssertionError: if no such function exists or it is not of function type
        :raises RuntimeError: if ``force_three_params`` is set but the validator
            cannot double function as minting and spending script
        """
        # TODO can use more sophisiticated procedure here i.e. functions marked by comment
        main_fun = self._find_validator_function(node)
        main_fun_typ: FunctionType = main_fun.typ.typ
        assert isinstance(
            main_fun_typ, FunctionType
        ), f"Variable named {self.validator_function_name} is not of type function"

        enable_double_func_mint_spend = self._can_double_function(main_fun_typ)
        validator = self._build_validator(node, main_fun, main_fun_typ)
        if enable_double_func_mint_spend:
            validator = wrap_validator_double_function(
                validator, pass_through=len(main_fun_typ.argtyps) - 3
            )
        elif self.force_three_params:
            # Error if the double function is enforced but not possible
            raise RuntimeError(
                "The contract can not always detect if it was passed three or two parameters on-chain."
            )
        return plt.Program((1, 0, 0), validator)

    def _find_validator_function(self, node: TypedModule) -> FunctionDef:
        """Return the last top-level function whose original name is the validator name."""
        main_fun: typing.Optional[FunctionDef] = None
        for s in node.body:
            if (
                isinstance(s, FunctionDef)
                and s.orig_name == self.validator_function_name
            ):
                main_fun = s
        assert (
            main_fun is not None
        ), f"Could not find function named {self.validator_function_name}"
        return main_fun

    def _can_double_function(self, main_fun_typ: FunctionType) -> bool:
        """
        Decide whether the validator can be wrapped to double function as
        minting and spending script, logging warnings about the redeemer type.

        Wrapping is only attempted when ``force_three_params`` is set and the
        validator takes at least three parameters. It is possible only if no
        possible type of the second-to-last parameter can have constructor id 0.
        """
        if not (len(main_fun_typ.argtyps) >= 3 and self.force_three_params):
            return False
        second_last_arg = main_fun_typ.argtyps[-2]
        assert isinstance(
            second_last_arg, InstanceType
        ), "Can not pass Class into validator"
        if isinstance(second_last_arg.typ, UnionType):
            possible_types = second_last_arg.typ.typs
        else:
            possible_types = [second_last_arg.typ]
        if any(isinstance(t, UnitType) for t in possible_types):
            _LOGGER.warning(
                "The redeemer is annotated to be 'None'. This value is usually encoded in PlutusData with constructor id 0 and no fields. If you want the script to double function as minting and spending script, annotate the second argument with 'NoRedeemer'."
            )
        enable_double_func_mint_spend = not any(
            (isinstance(t, RecordType) and t.record.constructor == 0)
            or isinstance(t, UnitType)
            for t in possible_types
        )
        if not enable_double_func_mint_spend:
            _LOGGER.warning(
                "The second argument to the validator function potentially has constructor id 0. The validator will not be able to double function as minting script and spending script."
            )
        return enable_double_func_mint_spend

    def _build_validator(
        self, node: TypedModule, main_fun: FunctionDef, main_fun_typ: FunctionType
    ) -> plt.AST:
        """Build the lambda that runs the module body and calls the validator with the mapped script arguments."""
        return plt.Lambda(
            [f"p{i}" for i, _ in enumerate(main_fun_typ.argtyps)],
            transform_output_map(main_fun_typ.rettyp)(
                plt.Let(
                    [
                        (
                            "s",
                            plt.Apply(self.visit_sequence(node.body), INITIAL_STATE),
                        ),
                        (
                            "g",
                            plt.FunctionalMapAccess(
                                plt.Var("s"),
                                plt.ByteString(main_fun.name),
                                plt.TraceError(
                                    f"NameError: {self.validator_function_name}"
                                ),
                            ),
                        ),
                    ],
                    plt.Apply(
                        plt.Var("g"),
                        *[
                            transform_ext_params_map(a)(plt.Var(f"p{i}"))
                            for i, a in enumerate(main_fun_typ.argtyps)
                        ],
                        plt.Var("s"),
                    ),
                ),
            ),
        )
310

311
    def visit_Constant(self, node: TypedConstant) -> plt.AST:
        """
        Compile a literal via ``ConstantMap``, keyed on the exact Python type of its value.

        :raises NotImplementedError: for literal types without a mapping
        """
        value_type = type(node.value)
        to_plt = ConstantMap.get(value_type)
        if to_plt is not None:
            return plt.Lambda([STATEMONAD], to_plt(node.value))
        raise NotImplementedError(f"Constants of type {value_type} are not supported")
318

319
    def visit_NoneType(self, _: typing.Optional[typing.Any]) -> plt.AST:
        """Compile ``None`` to a lambda that ignores the state and yields Unit."""
        unit_value = plt.Unit()
        return plt.Lambda([STATEMONAD], unit_value)
321

322
    def visit_Assign(self, node: TypedAssign) -> plt.AST:
        """
        Compile ``name = value``.

        The value is evaluated in the current state and bound to ``name`` in
        the returned state.
        """
        targets = node.targets
        assert (
            len(targets) == 1
        ), "Assignments to more than one variable not supported yet"
        (target,) = targets
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        value = plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))
        new_state = extend_statemonad([target.id], [value], plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], new_state)
340

341
    def visit_AnnAssign(self, node: AnnAssign) -> plt.AST:
        """
        Compile an annotated assignment ``name: T = value``.

        Where either side has type Any, the value is converted between its
        builtin representation and PlutusData before being bound to ``name``.
        """
        target = node.target
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        assert isinstance(
            target.typ, InstanceType
        ), "Can only assign instances to instances"
        val = plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))
        value_typ = node.value.typ
        value_is_any = isinstance(value_typ, InstanceType) and isinstance(
            value_typ.typ, AnyType
        )
        if value_is_any:
            # an Any value originates from PlutusData; AnyType is the only type
            # besides the builtin itself that can be cast to builtin values
            val = transform_ext_params_map(target.typ)(val)
        target_is_any = isinstance(target.typ, InstanceType) and isinstance(
            target.typ.typ, AnyType
        )
        if target_is_any:
            # an Any target is treated as PlutusData, so map the builtin value back
            val = transform_output_map(value_typ)(val)
        new_state = extend_statemonad([target.id], [val], plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], new_state)
371

372
    def visit_Name(self, node: TypedName) -> plt.AST:
        """
        Compile a name in load context.

        A class name compiles to its constructor. Any other name is looked up
        in the state map, and a traced NameError is raised if it is not bound.
        """
        if not isinstance(node.ctx, Load):
            raise NotImplementedError(f"Context {node.ctx} not supported")
        if isinstance(node.typ, ClassType):
            # the name refers to a class, not an instance: yield its constructor
            value = node.typ.constr()
        else:
            value = plt.FunctionalMapAccess(
                plt.Var(STATEMONAD),
                plt.ByteString(node.id),
                plt.TraceError(f"NameError: {node.orig_id}"),
            )
        return plt.Lambda([STATEMONAD], value)
390

391
    def visit_Expr(self, node: TypedExpr) -> plt.AST:
        """
        Compile an expression statement.

        The value is passed to a lambda that discards it and returns the
        unchanged state. Because UPLC evaluates eagerly, the expression is still
        computed; in practice this only matters for trace calls.
        """
        discard_value = plt.Lambda(["_"], plt.Var(STATEMONAD))
        evaluated = plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], plt.Apply(discard_value, evaluated))
402

403
    def visit_Call(self, node: TypedCall) -> plt.AST:
        """
        Compile a function call.

        How the callee is obtained:
        - A polymorphic builtin is instantiated for the call's argument types.
        - Any other callee is evaluated in the current state.

        Every argument is evaluated in the current state. An argument passed to
        a parameter of type Any is first mapped to PlutusData. The state monad
        itself is passed as the final argument.
        """
        func_typ = node.func.typ
        # TODO function is actually not of type polymorphic function type here anymore
        if isinstance(func_typ, PolymorphicFunctionInstanceType):
            # edge case for weird builtins that are polymorphic
            func_plt = func_typ.polymorphic_function.impl_from_args(
                func_typ.typ.argtyps
            )
        else:
            func_plt = plt.Apply(self.visit(node.func), plt.Var(STATEMONAD))
        compiled_args = []
        for arg, param_typ in zip(node.args, func_typ.typ.argtyps):
            assert isinstance(param_typ, InstanceType)
            compiled = plt.Apply(self.visit(arg), plt.Var(STATEMONAD))
            if isinstance(param_typ.typ, AnyType):
                # the callee expects generic data, so wrap the value first
                compiled = transform_output_map(arg.typ)(compiled)
            compiled_args.append(compiled)
        call = plt.Apply(func_plt, *compiled_args, plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], call)
432

433
    def visit_FunctionDef(self, node: TypedFunctionDef) -> plt.AST:
        """
        Compile a function definition into a binding in the state monad.

        The bound value is a lambda over the positional parameters ``p0``,
        ``p1``, ... followed by the caller's state monad. When called, it:
        - binds each parameter under its argument name in that state;
        - runs every statement except the final ``return``;
        - evaluates the returned expression in the resulting state.

        A body that does not end in ``return`` is treated as ending in
        ``return None``. The returned value is wrapped as data when the
        declared return type is Any.
        """
        statements = list(node.body)
        if not statements or not isinstance(statements[-1], Return):
            implicit_return = Return(TypedConstant(None, typ=NoneInstanceType))
            implicit_return.typ = NoneInstanceType
            statements.append(implicit_return)
        *body_statements, final_return = statements
        param_names = [f"p{i}" for i in range(len(node.args.args))]
        compiled_body = self.visit_sequence(body_statements)
        # inside the function, each argument is visible under its own name
        state_with_args = extend_statemonad(
            [a.arg for a in node.args.args],
            [plt.Var(name) for name in param_names],
            plt.Var(STATEMONAD),
        )
        compiled_return = plt.Apply(
            self.visit(final_return.value),
            plt.Apply(compiled_body, state_with_args),
        )
        if isinstance(node.typ.typ.rettyp.typ, AnyType):
            compiled_return = transform_output_map(final_return.value.typ)(
                compiled_return
            )
        # the trailing state parameter is the basis for internally available values
        function_value = plt.Lambda(param_names + [STATEMONAD], compiled_return)
        return plt.Lambda(
            [STATEMONAD],
            extend_statemonad([node.name], [function_value], plt.Var(STATEMONAD)),
        )
470

471
    def visit_While(self, node: TypedWhile) -> plt.AST:
        """
        Compile a ``while`` loop into a self-applied recursive state transformer.

        A helper ``g`` takes the current state and itself. While the test holds
        in the current state, ``g`` calls itself on the state produced by the
        body; otherwise it returns the state. The result applies ``g`` to the
        incoming state and to ``g``.

        A loop with an ``else`` clause is compiled as a copy of the loop without
        it, followed by the ``else`` statements.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct).
            # Checked before compiling test/body, like visit_For, so they are not
            # compiled twice (the copy below is compiled again on its own).
            cn = copy(node)
            cn.orelse = []
            return self.visit_sequence([cn] + node.orelse)
        compiled_c = self.visit(node.test)
        compiled_s = self.visit_sequence(node.body)
        return plt.Lambda(
            [STATEMONAD],
            plt.Let(
                bindings=[
                    (
                        "g",
                        plt.Lambda(
                            ["s", "f"],
                            plt.Ite(
                                plt.Apply(compiled_c, plt.Var("s")),
                                plt.Apply(
                                    plt.Var("f"),
                                    plt.Apply(compiled_s, plt.Var("s")),
                                    plt.Var("f"),
                                ),
                                plt.Var("s"),
                            ),
                        ),
                    ),
                ],
                term=plt.Apply(plt.Var("g"), plt.Var(STATEMONAD), plt.Var("g")),
            ),
        )
503

504
    def visit_For(self, node: TypedFor) -> plt.AST:
        """
        Compile a ``for`` loop over a list into a fold over its elements.

        For every element the loop variable is bound in the state, and the body
        is run on that state. The resulting state is carried on to the next element.
        A loop with an ``else`` clause is compiled as a copy of the loop without
        it, followed by the ``else`` statements.

        :raises NotImplementedError: when iterating over anything but a list
        """
        if node.orelse:
            # TODO check if this is correct
            loop_without_else = copy(node)
            loop_without_else.orelse = []
            return self.visit_sequence([loop_without_else] + node.orelse)
        assert isinstance(node.iter.typ, InstanceType)
        if not isinstance(node.iter.typ.typ, ListType):
            raise NotImplementedError(
                "Compilation of for statements for anything but lists not implemented yet"
            )
        assert isinstance(
            node.target, Name
        ), "Can only assign value to singleton element"
        iterated_list = plt.Apply(self.visit(node.iter), plt.Var(STATEMONAD))
        step = plt.Lambda(
            [STATEMONAD, "e"],
            plt.Apply(
                self.visit_sequence(node.body),
                extend_statemonad(
                    [node.target.id], [plt.Var("e")], plt.Var(STATEMONAD)
                ),
            ),
        )
        return plt.Lambda(
            [STATEMONAD],
            plt.FoldList(iterated_list, step, plt.Var(STATEMONAD)),
        )
536

537
    def visit_If(self, node: TypedIf) -> plt.AST:
        """
        Compile ``if``/``else``.

        The test is evaluated in the current state and selects which branch's
        statement sequence is applied to that state.
        """
        condition = plt.Apply(self.visit(node.test), plt.Var(STATEMONAD))
        then_state = plt.Apply(self.visit_sequence(node.body), plt.Var(STATEMONAD))
        else_state = plt.Apply(self.visit_sequence(node.orelse), plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], plt.Ite(condition, then_state, else_state))
546

547
    def visit_Return(self, node: TypedReturn) -> plt.AST:
        """
        Reject a ``return`` reached through ordinary statement compilation.

        ``visit_FunctionDef`` compiles a function's trailing ``return`` itself.
        Any ``return`` visited here is therefore in an unsupported position.

        :raises NotImplementedError: always
        """
        message = "Compilation of return statements except for last statement in function is not supported."
        raise NotImplementedError(message)
551

552
    def visit_Pass(self, node: TypedPass) -> plt.AST:
        """Compile ``pass`` as an empty statement sequence, which leaves the state unchanged."""
        no_statements = []
        return self.visit_sequence(no_statements)
554

555
    def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
4✔
556
        assert isinstance(
4✔
557
            node.value.typ, InstanceType
558
        ), "Can only access elements of instances, not classes"
559
        if isinstance(node.value.typ.typ, TupleType):
4✔
560
            assert isinstance(
4✔
561
                node.slice, Constant
562
            ), "Only constant index access for tuples is supported"
563
            assert isinstance(
4✔
564
                node.slice.value, int
565
            ), "Only constant index integer access for tuples is supported"
566
            index = node.slice.value
4✔
567
            if index < 0:
4✔
568
                index += len(node.value.typ.typ.typs)
×
569
            assert isinstance(node.ctx, Load), "Tuples are read-only"
4✔
570
            return plt.Lambda(
4✔
571
                [STATEMONAD],
572
                plt.FunctionalTupleAccess(
573
                    plt.Apply(self.visit(node.value), plt.Var(STATEMONAD)),
574
                    index,
575
                    len(node.value.typ.typ.typs),
576
                ),
577
            )
578
        if isinstance(node.value.typ.typ, PairType):
4✔
579
            assert isinstance(
4✔
580
                node.slice, Constant
581
            ), "Only constant index access for pairs is supported"
582
            assert isinstance(
4✔
583
                node.slice.value, int
584
            ), "Only constant index integer access for pairs is supported"
585
            index = node.slice.value
4✔
586
            if index < 0:
4✔
587
                index += 2
×
588
            assert isinstance(node.ctx, Load), "Pairs are read-only"
4✔
589
            assert (
4✔
590
                0 <= index < 2
591
            ), f"Pairs only have 2 elements, index should be 0 or 1, is {node.slice.value}"
592
            member_func = plt.FstPair if index == 0 else plt.SndPair
4✔
593
            # the content of pairs is always Data, so we need to unwrap
594
            member_typ = node.typ
4✔
595
            return plt.Lambda(
4✔
596
                [STATEMONAD],
597
                transform_ext_params_map(member_typ)(
598
                    member_func(
599
                        plt.Apply(self.visit(node.value), plt.Var(STATEMONAD)),
600
                    ),
601
                ),
602
            )
603
        if isinstance(node.value.typ.typ, ListType):
4✔
604
            assert (
4✔
605
                node.slice.typ == IntegerInstanceType
606
            ), "Only single element list index access supported"
607
            return plt.Lambda(
4✔
608
                [STATEMONAD],
609
                plt.Let(
610
                    [
611
                        ("l", plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))),
612
                        (
613
                            "raw_i",
614
                            plt.Apply(self.visit(node.slice), plt.Var(STATEMONAD)),
615
                        ),
616
                        (
617
                            "i",
618
                            plt.Ite(
619
                                plt.LessThanInteger(plt.Var("raw_i"), plt.Integer(0)),
620
                                plt.AddInteger(
621
                                    plt.Var("raw_i"), plt.LengthList(plt.Var("l"))
622
                                ),
623
                                plt.Var("raw_i"),
624
                            ),
625
                        ),
626
                    ],
627
                    plt.IndexAccessList(plt.Var("l"), plt.Var("i")),
628
                ),
629
            )
630
        elif isinstance(node.value.typ.typ, DictType):
4✔
631
            dict_typ = node.value.typ.typ
4✔
632
            if not isinstance(node.slice, Slice):
4✔
633
                return plt.Lambda(
4✔
634
                    [STATEMONAD],
635
                    plt.Let(
636
                        [
637
                            (
638
                                "key",
639
                                plt.Apply(self.visit(node.slice), plt.Var(STATEMONAD)),
640
                            )
641
                        ],
642
                        transform_ext_params_map(dict_typ.value_typ)(
643
                            plt.SndPair(
644
                                plt.FindList(
645
                                    plt.Apply(
646
                                        self.visit(node.value), plt.Var(STATEMONAD)
647
                                    ),
648
                                    plt.Lambda(
649
                                        ["x"],
650
                                        plt.EqualsData(
651
                                            transform_output_map(dict_typ.key_typ)(
652
                                                plt.Var("key")
653
                                            ),
654
                                            plt.FstPair(plt.Var("x")),
655
                                        ),
656
                                    ),
657
                                    plt.TraceError("KeyError"),
658
                                ),
659
                            ),
660
                        ),
661
                    ),
662
                )
663
        elif isinstance(node.value.typ.typ, ByteStringType):
4✔
664
            if not isinstance(node.slice, Slice):
4✔
665
                return plt.Lambda(
4✔
666
                    [STATEMONAD],
667
                    plt.Let(
668
                        [
669
                            (
670
                                "bs",
671
                                plt.Apply(self.visit(node.value), plt.Var(STATEMONAD)),
672
                            ),
673
                            (
674
                                "raw_ix",
675
                                plt.Apply(self.visit(node.slice), plt.Var(STATEMONAD)),
676
                            ),
677
                            (
678
                                "ix",
679
                                plt.Ite(
680
                                    plt.LessThanInteger(
681
                                        plt.Var("raw_ix"), plt.Integer(0)
682
                                    ),
683
                                    plt.AddInteger(
684
                                        plt.Var("raw_ix"),
685
                                        plt.LengthOfByteString(plt.Var("bs")),
686
                                    ),
687
                                    plt.Var("raw_ix"),
688
                                ),
689
                            ),
690
                        ],
691
                        plt.IndexByteString(plt.Var("bs"), plt.Var("ix")),
692
                    ),
693
                )
694
            elif isinstance(node.slice, Slice):
4✔
695
                return plt.Lambda(
4✔
696
                    [STATEMONAD],
697
                    plt.Let(
698
                        [
699
                            (
700
                                "bs",
701
                                plt.Apply(self.visit(node.value), plt.Var(STATEMONAD)),
702
                            ),
703
                            (
704
                                "raw_i",
705
                                plt.Apply(
706
                                    self.visit(node.slice.lower), plt.Var(STATEMONAD)
707
                                ),
708
                            ),
709
                            (
710
                                "i",
711
                                plt.Ite(
712
                                    plt.LessThanInteger(
713
                                        plt.Var("raw_i"), plt.Integer(0)
714
                                    ),
715
                                    plt.AddInteger(
716
                                        plt.Var("raw_i"),
717
                                        plt.LengthOfByteString(plt.Var("bs")),
718
                                    ),
719
                                    plt.Var("raw_i"),
720
                                ),
721
                            ),
722
                            (
723
                                "raw_j",
724
                                plt.Apply(
725
                                    self.visit(node.slice.upper), plt.Var(STATEMONAD)
726
                                ),
727
                            ),
728
                            (
729
                                "j",
730
                                plt.Ite(
731
                                    plt.LessThanInteger(
732
                                        plt.Var("raw_j"), plt.Integer(0)
733
                                    ),
734
                                    plt.AddInteger(
735
                                        plt.Var("raw_j"),
736
                                        plt.LengthOfByteString(plt.Var("bs")),
737
                                    ),
738
                                    plt.Var("raw_j"),
739
                                ),
740
                            ),
741
                            (
742
                                "drop",
743
                                plt.Ite(
744
                                    plt.LessThanEqualsInteger(
745
                                        plt.Var("i"), plt.Integer(0)
746
                                    ),
747
                                    plt.Integer(0),
748
                                    plt.Var("i"),
749
                                ),
750
                            ),
751
                            (
752
                                "take",
753
                                plt.SubtractInteger(plt.Var("j"), plt.Var("drop")),
754
                            ),
755
                        ],
756
                        plt.Ite(
757
                            plt.LessThanEqualsInteger(plt.Var("j"), plt.Var("i")),
758
                            plt.ByteString(b""),
759
                            plt.SliceByteString(
760
                                plt.Var("drop"),
761
                                plt.Var("take"),
762
                                plt.Var("bs"),
763
                            ),
764
                        ),
765
                    ),
766
                )
767
        raise NotImplementedError(
×
768
            f'Could not implement subscript "{node.slice}" of "{node.value}"'
769
        )
770

771
    def visit_Tuple(self, node: TypedTuple) -> plt.AST:
        """Compile a tuple literal into a pluthon functional tuple.

        Each element is compiled, applied to the current state monad and the
        results are packed, in source order, with ``plt.FunctionalTuple``.
        """
        members = [
            plt.Apply(self.visit(member), plt.Var(STATEMONAD)) for member in node.elts
        ]
        return plt.Lambda([STATEMONAD], plt.FunctionalTuple(*members))
778

779
    def visit_ClassDef(self, node: TypedClassDef) -> plt.AST:
        """Compile a class definition as a statement.

        The returned function takes the state monad and returns it extended
        with the class name bound to the class constructor
        (``node.class_typ.constr()``).
        """
        constructor = node.class_typ.constr()
        extended_state = extend_statemonad(
            [node.name], [constructor], plt.Var(STATEMONAD)
        )
        return plt.Lambda([STATEMONAD], extended_state)
788

789
    def visit_Attribute(self, node: TypedAttribute) -> plt.AST:
        """Compile ``obj.attr`` by applying the type's attribute accessor.

        The accessor is obtained from the static type of ``node.value`` via
        ``.attribute(node.attr)`` and applied to the compiled object evaluated
        in the current state monad.
        """
        # NOTE(review): this checks the type of the *result* (node.typ) while the
        # message refers to the accessed object; confirm whether
        # node.value.typ was intended.
        assert isinstance(
            node.typ, InstanceType
        ), "Can only access attributes of instances"
        compiled_object = self.visit(node.value)
        accessor = node.value.typ.attribute(node.attr)
        object_value = plt.Apply(compiled_object, plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], plt.Apply(accessor, object_value))
798

799
    def visit_Assert(self, node: TypedAssert) -> plt.AST:
        """Compile an ``assert`` statement.

        If the test evaluates to true the state monad is returned unchanged.
        Otherwise ``plt.Error()`` is applied. When a message is present, that
        application's argument traces the evaluated message first.
        """
        condition = plt.Apply(self.visit(node.test), plt.Var(STATEMONAD))
        if node.msg is None:
            error_argument = plt.Unit()
        else:
            message = plt.Apply(self.visit(node.msg), plt.Var(STATEMONAD))
            error_argument = plt.Trace(message, plt.Unit())
        failure = plt.Apply(plt.Error(), error_argument)
        return plt.Lambda(
            [STATEMONAD],
            plt.Ite(condition, plt.Var(STATEMONAD), failure),
        )
815

816
    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> plt.AST:
        """Pass an embedded pluthon expression through untouched.

        Unlike the other visitors, no state-monad lambda is wrapped around it.
        """
        embedded = node.expr
        return embedded
4✔
818

819
    def visit_List(self, node: TypedList) -> plt.AST:
        """Compile a list literal into a builtin list.

        Elements are consed onto an empty list of the element type, starting
        from the last element, so the result keeps source order.
        """
        assert isinstance(node.typ, InstanceType)
        assert isinstance(node.typ.typ, ListType)
        built = empty_list(node.typ.typ.typ)
        for item in node.elts[::-1]:
            item_value = plt.Apply(self.visit(item), plt.Var(STATEMONAD))
            built = plt.MkCons(item_value, built)
        return plt.Lambda([STATEMONAD], built)
4✔
826

827
    def visit_Dict(self, node: TypedDict) -> plt.AST:
        """Compile a dict literal into a builtin list of data pairs.

        Each key and value is evaluated in the current state monad and converted
        to data with ``transform_output_map`` for the dict's declared key/value
        type, then paired with ``plt.MkPairData``.

        NOTE(review): pairs are consed onto the front while iterating in source
        order, so the resulting list is in *reverse* source order. With
        duplicate literal keys, the later entry therefore comes first. Confirm
        what downstream lookups/equality rely on before changing the order.
        """
        assert isinstance(node.typ, InstanceType)
        assert isinstance(node.typ.typ, DictType)
        key_type = node.typ.typ.key_typ
        value_type = node.typ.typ.value_typ
        pairs = plt.EmptyDataPairList()
        for key_node, value_node in zip(node.keys, node.values):
            key_data = transform_output_map(key_type)(
                plt.Apply(self.visit(key_node), plt.Var(STATEMONAD))
            )
            value_data = transform_output_map(value_type)(
                plt.Apply(self.visit(value_node), plt.Var(STATEMONAD))
            )
            pairs = plt.MkCons(plt.MkPairData(key_data, value_data), pairs)
        return plt.Lambda([STATEMONAD], pairs)
4✔
846

847
    def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
        """Compile ``body if test else orelse`` into a pluthon ``Ite``.

        The test and both branches are each applied to the current state monad.
        """
        test, then_branch, else_branch = [
            plt.Apply(self.visit(part), plt.Var(STATEMONAD))
            for part in (node.test, node.body, node.orelse)
        ]
        return plt.Lambda([STATEMONAD], plt.Ite(test, then_branch, else_branch))
856

857
    def visit_ListComp(self, node: TypedListComp) -> plt.AST:
        """Compile a list comprehension ``[elt for target in iter if ...]``.

        Only a single generator over a list-typed iterable with a plain name as
        target is supported (enforced by the assertions below). The element
        expression and every ``if`` clause are evaluated in the state monad
        extended with ``target`` bound to the current list element.

        Without ``if`` clauses the result is built with ``plt.MapList``.
        Otherwise it uses ``plt.MapFilterList``, whose filter is the ``plt.And``
        conjunction of all clauses.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        gen = node.generators[0]
        assert isinstance(gen.iter.typ, InstanceType), "Only lists are valid generators"
        assert isinstance(gen.iter.typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            gen.target, Name
        ), "Can only assign value to singleton element"
        lst = plt.Apply(self.visit(gen.iter), plt.Var(STATEMONAD))
        # Each compiled clause is a function of the state monad (it is applied
        # to the extended state monad in filter_fun below), not a boolean.
        if_funs = [self.visit(ifexpr) for ifexpr in gen.ifs]
        if len(if_funs) > 1:
            # Combining the clause *functions* with plt.And directly would build
            # an ill-typed term. Instead, build one function of the state monad
            # that applies every clause to that state and conjoins the booleans.
            conjunction = plt.Apply(if_funs[0], plt.Var(STATEMONAD))
            for if_fun in if_funs[1:]:
                conjunction = plt.And(
                    conjunction, plt.Apply(if_fun, plt.Var(STATEMONAD))
                )
            if_funs = [plt.Lambda([STATEMONAD], conjunction)]
        # With zero or one clause this is exactly the previous filter term.
        ifs = if_funs[0] if if_funs else None
        map_fun = plt.Lambda(
            ["x"],
            plt.Apply(
                self.visit(node.elt),
                extend_statemonad([gen.target.id], [plt.Var("x")], plt.Var(STATEMONAD)),
            ),
        )
        empty_list_con = empty_list(node.elt.typ)
        if ifs is not None:
            filter_fun = plt.Lambda(
                ["x"],
                plt.Apply(
                    ifs,
                    extend_statemonad(
                        [gen.target.id], [plt.Var("x")], plt.Var(STATEMONAD)
                    ),
                ),
            )
            return plt.Lambda(
                [STATEMONAD],
                plt.MapFilterList(
                    lst,
                    filter_fun,
                    map_fun,
                    empty_list_con,
                ),
            )
        else:
            return plt.Lambda(
                [STATEMONAD],
                plt.MapList(
                    lst,
                    map_fun,
                    empty_list_con,
                ),
            )
908

909
    def generic_visit(self, node: AST) -> plt.AST:
        """Reject any AST node that has no dedicated ``visit_*`` method.

        :raises NotImplementedError: always.
        """
        unsupported = f"Can not compile {node}"
        raise NotImplementedError(unsupported)
×
911

912

913
def compile(
    prog: AST,
    filename=None,
    force_three_params=False,
    validator_function_name="validator",
):
    """Run the full opshin pipeline on a parsed Python module.

    Phase 1 applies source-level rewrites and type inference. After every
    step, ``fix_missing_locations`` fills in positions on newly created nodes.
    Phase 2 renames and scopes variables, applies the optimizations and
    finally lowers the program with ``UPLCCompiler``. Locations are not fixed
    in this phase because raw UPLC may occur.

    :param prog: the parsed module AST.
    :param filename: forwarded to ``RewriteImport``.
    :param force_three_params: forwarded to ``UPLCCompiler``.
    :param validator_function_name: forwarded to ``UPLCCompiler``.
    :return: the result of ``UPLCCompiler.visit`` on the rewritten program.
    """
    source_rewrites = (
        # Must run first: it imports all further files.
        RewriteImport(filename=filename),
        # Rewrites that simplify the python code.
        RewriteSubscript38(),
        RewriteAugAssign(),
        RewriteTupleAssign(),
        RewriteImportPlutusData(),
        RewriteImportHashlib(),
        RewriteImportTyping(),
        RewriteForbiddenOverwrites(),
        RewriteImportDataclasses(),
        RewriteInjectBuiltins(),
        # Type inference runs after complex python operations were rewritten.
        AggressiveTypeInferencer(),
        # Rewrites that circumvent the type inference or use its results.
        RewriteZeroAry(),
        RewriteInjectBuiltinsConstr(),
        RewriteRemoveTypeStuff(),
    )
    for rewrite in source_rewrites:
        prog = fix_missing_locations(rewrite.visit(prog))

    # From here on raw UPLC may occur, so locations are no longer fixed.
    lowering_steps = (
        # Save the original names of variables.
        RewriteOrigName(),
        RewriteScoping(),
        # Optimizations.
        OptimizeRemoveDeadvars(),
        OptimizeVarlen(),
        OptimizeRemoveDeadconstants(),
        OptimizeRemovePass(),
        # The compiler runs last.
        UPLCCompiler(
            force_three_params=force_three_params,
            validator_function_name=validator_function_name,
        ),
    )
    for step in lowering_steps:
        prog = step.visit(prog)
    return prog
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc