• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ImperatorLang / eopsin / 308

pending completion
308

Pull #58

travis-ci-com

web-flow
Merge 67a87c1ef into cee2aa9bb
Pull Request #58: Feat/parameterization

12 of 12 new or added lines in 2 files covered. (100.0%)

2568 of 2770 relevant lines covered (92.71%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.59
/eopsin/compiler.py
1
import logging
1✔
2
from logging import getLogger
1✔
3
from ast import fix_missing_locations
1✔
4

5
from .rewrite.rewrite_augassign import RewriteAugAssign
1✔
6
from .rewrite.rewrite_forbidden_overwrites import RewriteForbiddenOverwrites
1✔
7
from .rewrite.rewrite_import import RewriteImport
1✔
8
from .rewrite.rewrite_import_dataclasses import RewriteImportDataclasses
1✔
9
from .rewrite.rewrite_import_hashlib import RewriteImportHashlib
1✔
10
from .rewrite.rewrite_import_plutusdata import RewriteImportPlutusData
1✔
11
from .rewrite.rewrite_import_typing import RewriteImportTyping
1✔
12
from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins
1✔
13
from .rewrite.rewrite_inject_builtin_constr import RewriteInjectBuiltinsConstr
1✔
14
from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff
1✔
15
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
1✔
16
from .optimize.optimize_remove_pass import OptimizeRemovePass
1✔
17
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars
1✔
18
from .optimize.optimize_varlen import OptimizeVarlen
1✔
19
from .type_inference import *
1✔
20
from .util import CompilingNodeTransformer, PowImpl
1✔
21
from .typed_ast import transform_ext_params_map, transform_output_map, RawPlutoExpr
1✔
22

23

24
_LOGGER = getLogger(__name__)

# Name of the lambda parameter that threads the state monad (the map from
# variable names to values) through every compiled statement and expression.
STATEMONAD = "s"


# Lookup table for binary operators:
#   BinOpMap[operator class][left operand type][right operand type](lhs, rhs)
# builds the expression implementing the operation.
BinOpMap = {
    Add: {
        IntegerInstanceType: {IntegerInstanceType: plt.AddInteger},
        ByteStringInstanceType: {ByteStringInstanceType: plt.AppendByteString},
        StringInstanceType: {StringInstanceType: plt.AppendString},
    },
    # every other supported operator is only defined for int <op> int
    **{
        op_class: {IntegerInstanceType: {IntegerInstanceType: builder}}
        for op_class, builder in (
            (Sub, plt.SubtractInteger),
            (Mult, plt.MultiplyInteger),
            (FloorDiv, plt.DivideInteger),
            (Mod, plt.ModInteger),
            (Pow, PowImpl),
        )
    },
}

# Boolean operator class -> builder combining two operands.
BoolOpMap = {
    And: plt.And,
    Or: plt.Or,
}

# Unary operator class -> {operand type: builder for the operation}.
UnaryOpMap = {
    Not: {BoolInstanceType: plt.Not},
    # -x is compiled as 0 - x
    USub: {
        IntegerInstanceType: lambda operand: plt.SubtractInteger(
            plt.Integer(0), operand
        )
    },
}

# Python type of a literal -> constructor for a constant of that type.
ConstantMap = {
    str: plt.Text,
    bytes: lambda value: plt.ByteString(value),
    int: lambda value: plt.Integer(value),
    bool: plt.Bool,
    # None is represented by the unit value
    type(None): lambda _: plt.Unit(),
}


def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
    """
    Wrap a compiled validator so that one script can serve two calling conventions.

    The returned lambda takes ``pass_through`` leading arguments (``v0`` ...),
    which are forwarded unchanged to ``x``, followed by two arguments ``a0`` and ``a1``:

    * if ``a1`` is constructor data with constructor id 0 (treated as the script
      context), ``x`` is called with ``Unit`` in place of the first (data) argument,
      followed by ``a0`` and ``a1``;
    * otherwise ``x`` is applied to ``a0`` and ``a1`` only, and the resulting
      partially applied function is returned.

    pass_through: number of leading arguments forwarded to ``x`` before ``a0``/``a1``.
    """
    forwarded = [f"v{i}" for i in range(pass_through)]
    partially_applied = plt.Apply(x, *(plt.Var(name) for name in forwarded))
    # constructor data: compare the constructor id with 0; any other kind of data: False
    a1_is_constr_0 = plt.DelayedChooseData(
        plt.Var("a1"),
        plt.EqualsInteger(plt.Constructor(plt.Var("a1")), plt.Integer(0)),
        *(plt.Bool(False) for _ in range(4)),
    )
    # plug in Unit for the missing data argument
    with_unit_first = plt.Apply(plt.Var("p"), plt.Unit(), plt.Var("a0"), plt.Var("a1"))
    # bind a0, a1 and hand back the (now partially bound) validator
    partially_bound = plt.Apply(plt.Var("p"), plt.Var("a0"), plt.Var("a1"))
    return plt.Lambda(
        forwarded + ["a0", "a1"],
        plt.Let(
            [("p", partially_applied)],
            plt.Ite(a1_is_constr_0, with_unit_first, partially_bound),
        ),
    )


def extend_statemonad(
    names: typing.List[str],
    values: typing.List[plt.AST],
    old_statemonad: plt.FunctionalMap,
):
    """
    Return ``old_statemonad`` extended so that ``names[i]`` maps to ``values[i]``.

    Ensures that the argument is fully evaluated before being passed into the monad
    (like in imperative languages): the values are passed as arguments to a lambda,
    and only that lambda's parameters ``a0``, ``a1``, ... are stored in the map.
    """
    assert len(names) == len(values), "Unequal amount of names and values passed in"
    lam_names = [f"a{i}" for i in range(len(names))]
    extended = plt.FunctionalMapExtend(
        old_statemonad, names, [plt.Var(name) for name in lam_names]
    )
    return plt.Apply(plt.Lambda(lam_names, extended), *values)


# State map constructed without any bindings; the module body is evaluated
# starting from it (see UPLCCompiler.visit_Module).
INITIAL_STATE = plt.FunctionalMap()


class UPLCCompiler(CompilingNodeTransformer):
1✔
139
    """
140
    Expects a TypedAST and returns UPLC/Pluto like code
141
    """
142

143
    step = "Compiling python statements to UPLC"
1✔
144

145
    def visit_sequence(self, node_seq: typing.List[typedstmt]) -> plt.AST:
        """Compile a list of statements into a single function of the state.

        Each compiled statement maps a state to a new state. They are chained left
        to right, so the result is ``lambda s: stmt_n(... stmt_1(s))``. An empty list
        yields the identity on the state.
        """
        state = plt.Var(STATEMONAD)
        for compiled_stmt in map(self.visit, node_seq):
            state = plt.Apply(compiled_stmt, state)
        return plt.Lambda([STATEMONAD], state)

    def visit_BinOp(self, node: TypedBinOp) -> plt.AST:
        """Compile a binary operation using the ``BinOpMap`` lookup table.

        Both operands are evaluated in the current state and passed to the builder
        registered for (operator, left type, right type).

        Raises:
            NotImplementedError: if the operator is unknown, or if there is no
                implementation for the left operand type or for the combination of
                left and right operand types.
        """
        opmap = BinOpMap.get(type(node.op))
        if opmap is None:
            raise NotImplementedError(f"Operation {node.op} is not implemented")
        opmap2 = opmap.get(node.left.typ)
        if opmap2 is None:
            raise NotImplementedError(
                f"Operation {node.op} is not implemented for left type {node.left.typ}"
            )
        op = opmap2.get(node.right.typ)
        # opmap2 was already validated above; here the lookup result itself must be
        # checked, otherwise an unsupported right operand type fails below with
        # "'NoneType' object is not callable".
        if op is None:
            raise NotImplementedError(
                f"Operation {node.op} is not implemented for left type {node.left.typ} and right type {node.right.typ}"
            )
        return plt.Lambda(
            [STATEMONAD],
            op(
                plt.Apply(self.visit(node.left), plt.Var(STATEMONAD)),
                plt.Apply(self.visit(node.right), plt.Var(STATEMONAD)),
            ),
        )

    def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST:
        """Compile ``a and b and ...`` / ``a or b or ...`` as a left-nested chain.

        All operands are evaluated in the current state, and the builder from
        ``BoolOpMap`` combines them as ``op(op(v0, v1), v2) ...``.
        """
        combine = BoolOpMap.get(type(node.op))
        assert len(node.values) >= 2, "Need to compare at least to values"
        operands = [
            plt.Apply(self.visit(value), plt.Var(STATEMONAD)) for value in node.values
        ]
        chained = combine(operands[0], operands[1])
        for operand in operands[2:]:
            chained = combine(chained, operand)
        return plt.Lambda([STATEMONAD], chained)

    def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST:
        """Compile a unary operation using the ``UnaryOpMap`` lookup table.

        The operator class selects a per-type table, and the operand's type selects
        the builder that is applied to the operand evaluated in the current state.
        """
        by_type = UnaryOpMap.get(type(node.op))
        assert by_type is not None, f"Operator {type(node.op)} is not supported"
        builder = by_type.get(node.operand.typ)
        assert (
            builder is not None
        ), f"Operator {type(node.op)} is not supported for type {node.operand.typ}"
        operand = plt.Apply(self.visit(node.operand), plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], builder(operand))

    def visit_Compare(self, node: TypedCompare) -> plt.AST:
        """Compile a single comparison ``left <op> right``.

        The comparison function is obtained from the left operand's type via
        ``cmp(op, right type)`` and applied to both operands evaluated in the
        current state. Chained comparisons are rejected.
        """
        assert len(node.ops) == 1, "Only single comparisons are supported"
        assert len(node.comparators) == 1, "Only single comparisons are supported"
        right = node.comparators[0]
        compare_fn = node.left.typ.cmp(node.ops[0], right.typ)
        lhs = plt.Apply(self.visit(node.left), plt.Var(STATEMONAD))
        rhs = plt.Apply(self.visit(right), plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], plt.Apply(compare_fn, lhs, rhs))

    def visit_Module(self, node: TypedModule) -> plt.AST:
        """Compile a whole module into a program whose entry point is ``validator``.

        The module body is evaluated starting from ``INITIAL_STATE``. The function
        stored under the validator's name is then looked up in the resulting state
        and applied to the program parameters ``p0 ... pn``. Each parameter is
        converted with ``transform_ext_params_map`` of its annotated type, and the
        result is converted with ``transform_output_map`` of the return type.

        The program is additionally wrapped with ``wrap_validator_double_function``
        when the validator takes at least three arguments and its second-to-last
        argument (the redeemer) can never have constructor id 0.
        """
        # find main function: the last top-level function whose original name is "validator"
        # TODO can use more sophisticated procedure here i.e. functions marked by comment
        main_fun: typing.Optional[FunctionDef] = None
        for s in node.body:
            if isinstance(s, FunctionDef) and s.orig_name == "validator":
                main_fun = s
        assert main_fun is not None, "Could not find function named validator"
        main_fun_typ: FunctionType = main_fun.typ.typ
        assert isinstance(
            main_fun_typ, FunctionType
        ), "Variable named validator is not of type function"

        # check if this is a contract written to double function
        enable_double_func_mint_spend = False
        if len(main_fun_typ.argtyps) >= 3:
            # The double-function wrapper distinguishes its two calling conventions
            # by whether the second argument has constructor id 0, so this only
            # works if the redeemer can never be encoded with constructor id 0.
            second_last_arg = main_fun_typ.argtyps[-2]
            assert isinstance(
                second_last_arg, InstanceType
            ), "Can not pass Class into validator"
            if isinstance(second_last_arg.typ, UnionType):
                possible_types = second_last_arg.typ.typs
            else:
                possible_types = [second_last_arg.typ]
            if any(isinstance(t, UnitType) for t in possible_types):
                _LOGGER.warning(
                    "The redeemer is annotated to be 'None'. This value is usually encoded in PlutusData with constructor id 0 and no fields. If you want the script to double function as minting and spending script, annotate the second argument with 'NoRedeemer'."
                )
            enable_double_func_mint_spend = not any(
                (isinstance(t, RecordType) and t.record.constructor == 0)
                or isinstance(t, UnitType)
                for t in possible_types
            )
            if not enable_double_func_mint_spend:
                _LOGGER.warning(
                    "The second argument to the validator function potentially has constructor id 0. The validator will not be able to double function as minting script and spending script."
                )

        validator = plt.Lambda(
            [f"p{i}" for i, _ in enumerate(main_fun_typ.argtyps)],
            transform_output_map(main_fun_typ.rettyp)(
                plt.Let(
                    [
                        # state after evaluating all module-level statements
                        (
                            "s",
                            plt.Apply(self.visit_sequence(node.body), INITIAL_STATE),
                        ),
                        # the compiled validator function stored in that state
                        (
                            "g",
                            plt.FunctionalMapAccess(
                                plt.Var("s"),
                                plt.ByteString(main_fun.name),
                                plt.TraceError("NameError: validator"),
                            ),
                        ),
                    ],
                    plt.Apply(
                        plt.Var("g"),
                        *[
                            transform_ext_params_map(a)(plt.Var(f"p{i}"))
                            for i, a in enumerate(main_fun_typ.argtyps)
                        ],
                        # compiled functions expect the state as final argument
                        plt.Var("s"),
                    ),
                ),
            ),
        )
        if enable_double_func_mint_spend:
            validator = wrap_validator_double_function(
                validator, pass_through=len(main_fun_typ.argtyps) - 3
            )
        cp = plt.Program("1.0.0", validator)
        return cp

    def visit_Constant(self, node: TypedConstant) -> plt.AST:
        """Compile a literal into a function of the state that ignores the state.

        Raises:
            NotImplementedError: if ``ConstantMap`` has no entry for the literal's type.
        """
        to_constant = ConstantMap.get(type(node.value))
        if to_constant is not None:
            return plt.Lambda([STATEMONAD], to_constant(node.value))
        raise NotImplementedError(
            f"Constants of type {type(node.value)} are not supported"
        )

    def visit_NoneType(self, _: typing.Optional[typing.Any]) -> plt.AST:
        """Compile a missing (``None``) node to a function of the state returning Unit."""
        unit_value = plt.Unit()
        return plt.Lambda([STATEMONAD], unit_value)

    def visit_Assign(self, node: TypedAssign) -> plt.AST:
        """Compile ``name = value``.

        The value is evaluated in the current state, and the state is extended
        with ``name`` bound to the result.
        """
        assert (
            len(node.targets) == 1
        ), "Assignments to more than one variable not supported yet"
        target = node.targets[0]
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        new_value = plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))
        return plt.Lambda(
            [STATEMONAD],
            extend_statemonad([target.id], [new_value], plt.Var(STATEMONAD)),
        )

    def visit_AnnAssign(self, node: AnnAssign) -> plt.AST:
        """Compile an annotated assignment ``name: T = value``.

        Unlike a plain assignment, the evaluated value is passed through
        ``transform_ext_params_map`` for the annotated type before being bound.
        The original implementation notes that such values originate from
        PlutusData and therefore need this mapping.
        """
        target = node.target
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        assert isinstance(
            target.typ, InstanceType
        ), "Can only assign instances to instances"
        compiled_value = self.visit(node.value)
        mapped_value = transform_ext_params_map(target.typ)(
            plt.Apply(compiled_value, plt.Var(STATEMONAD))
        )
        return plt.Lambda(
            [STATEMONAD],
            extend_statemonad([target.id], [mapped_value], plt.Var(STATEMONAD)),
        )

    def visit_Name(self, node: TypedName) -> plt.AST:
        """Compile a variable read.

        A name whose type is a class compiles to that class's constructor. Any
        other name is looked up in the state; a missing entry raises a traced
        ``NameError`` carrying the original identifier.

        Raises:
            NotImplementedError: for any context other than ``Load``.
        """
        if not isinstance(node.ctx, Load):
            raise NotImplementedError(f"Context {node.ctx} not supported")
        if isinstance(node.typ, ClassType):
            value = node.typ.constr()
        else:
            value = plt.FunctionalMapAccess(
                plt.Var(STATEMONAD),
                plt.ByteString(node.id),
                plt.TraceError(f"NameError: {node.orig_id}"),
            )
        return plt.Lambda([STATEMONAD], value)

    def visit_Expr(self, node: TypedExpr) -> plt.AST:
        """Compile an expression statement whose value is discarded.

        The expression result is passed to a lambda that ignores it and returns the
        unchanged state. This relies on UPLC evaluating arguments eagerly, so the
        expression is still computed; in practice this is mostly useful for traces.
        """
        evaluated = plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))
        ignore_result = plt.Lambda(["_"], plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], plt.Apply(ignore_result, evaluated))

    def visit_Call(self, node: TypedCall) -> plt.AST:
        """Compile a call ``f(a1, ..., an)``.

        Every argument is evaluated in the current state. An argument whose
        parameter type is ``AnyType`` is first converted with
        ``transform_output_map`` of the argument's own type. The current state is
        passed to the callee as an extra, final argument.
        """
        # TODO function is actually not of type polymorphic function type here anymore
        if isinstance(node.func.typ, PolymorphicFunctionInstanceType):
            # edge case for weird builtins that are polymorphic
            func_plt = node.func.typ.polymorphic_function.impl_from_args(
                node.func.typ.typ.argtyps
            )
        else:
            func_plt = plt.Apply(self.visit(node.func), plt.Var(STATEMONAD))

        def compile_arg(arg, param_typ):
            assert isinstance(param_typ, InstanceType)
            compiled = plt.Apply(self.visit(arg), plt.Var(STATEMONAD))
            if not isinstance(param_typ.typ, AnyType):
                return compiled
            # generic (data) parameter: wrap the argument before passing it inside
            return transform_output_map(arg.typ)(compiled)

        args = [
            compile_arg(arg, param_typ)
            for arg, param_typ in zip(node.args, node.func.typ.typ.argtyps)
        ]
        return plt.Lambda(
            [STATEMONAD],
            plt.Apply(func_plt, *args, plt.Var(STATEMONAD)),
        )

    def visit_FunctionDef(self, node: TypedFunctionDef) -> plt.AST:
        """Compile a function definition and bind it in the state under its name.

        The compiled function takes one parameter ``p{i}`` per declared argument
        plus the caller's state. Its body runs in that state extended with the
        arguments bound to their names. Only a trailing ``return`` is supported; if
        there is none, ``return None`` is appended, which requires the declared
        return type to be None.
        """
        statements = list(node.body)
        if not isinstance(statements[-1], Return):
            implicit_return = Return(None)
            implicit_return.typ = NoneInstanceType
            assert (
                node.typ.typ.rettyp == NoneInstanceType
            ), "Function has no return statement but is supposed to return not-None value"
            statements.append(implicit_return)
        compiled_body = self.visit_sequence(statements[:-1])
        compiled_return = self.visit(statements[-1].value)

        param_names = [f"p{i}" for i in range(len(node.args.args))]
        # the function can see its arguments under their declared names
        state_with_args = extend_statemonad(
            [a.arg for a in node.args.args],
            [plt.Var(name) for name in param_names],
            plt.Var(STATEMONAD),
        )
        # the state is expected again as last parameter -> basis for internally available values
        function_value = plt.Lambda(
            param_names + [STATEMONAD],
            plt.Apply(compiled_return, plt.Apply(compiled_body, state_with_args)),
        )
        return plt.Lambda(
            [STATEMONAD],
            extend_statemonad([node.name], [function_value], plt.Var(STATEMONAD)),
        )

    def visit_While(self, node: TypedWhile) -> plt.AST:
        """Compile a while loop into a self-applying recursive function of the state.

        The generated code is equivalent to
        ``g = lambda s, f: f(body(s), f) if test(s) else s`` evaluated as
        ``g(state, g)``: ``g`` receives itself as ``f`` in order to recurse.
        A loop with an ``else`` clause is compiled as the loop followed by the
        ``else`` statements.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence. This matches
            # Python semantics only as long as the loop is never left via ``break``
            # (TODO check if this is correct).
            # Handled before compiling test/body so they are compiled only once,
            # by the re-dispatched copy below (consistent with visit_For).
            cn = copy(node)
            cn.orelse = []
            return self.visit_sequence([cn] + node.orelse)
        compiled_c = self.visit(node.test)
        compiled_s = self.visit_sequence(node.body)
        return plt.Lambda(
            [STATEMONAD],
            plt.Let(
                bindings=[
                    (
                        "g",
                        plt.Lambda(
                            ["s", "f"],
                            plt.Ite(
                                plt.Apply(compiled_c, plt.Var("s")),
                                plt.Apply(
                                    plt.Var("f"),
                                    plt.Apply(compiled_s, plt.Var("s")),
                                    plt.Var("f"),
                                ),
                                plt.Var("s"),
                            ),
                        ),
                    ),
                ],
                term=plt.Apply(plt.Var("g"), plt.Var(STATEMONAD), plt.Var("g")),
            ),
        )

    def visit_For(self, node: TypedFor) -> plt.AST:
        """Compile ``for name in <list>: body`` as a fold over the list.

        For each element ``e``, the body runs in the accumulated state extended with
        ``name`` bound to ``e``. An ``else`` clause is compiled as the loop followed
        by the ``else`` statements.

        Raises:
            NotImplementedError: when iterating over anything other than a list.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
            cn = copy(node)
            cn.orelse = []
            return self.visit_sequence([cn] + node.orelse)
        assert isinstance(node.iter.typ, InstanceType)
        if not isinstance(node.iter.typ.typ, ListType):
            raise NotImplementedError(
                "Compilation of for statements for anything but lists not implemented yet"
            )
        assert isinstance(
            node.target, Name
        ), "Can only assign value to singleton element"
        iterated = plt.Apply(self.visit(node.iter), plt.Var(STATEMONAD))
        step = plt.Lambda(
            [STATEMONAD, "e"],
            plt.Apply(
                self.visit_sequence(node.body),
                extend_statemonad([node.target.id], [plt.Var("e")], plt.Var(STATEMONAD)),
            ),
        )
        return plt.Lambda(
            [STATEMONAD],
            plt.FoldList(iterated, step, plt.Var(STATEMONAD)),
        )

    def visit_If(self, node: TypedIf) -> plt.AST:
        """Compile ``if``/``else``: the chosen branch's statements transform the state."""

        def in_state(compiled):
            return plt.Apply(compiled, plt.Var(STATEMONAD))

        condition = in_state(self.visit(node.test))
        then_state = in_state(self.visit_sequence(node.body))
        else_state = in_state(self.visit_sequence(node.orelse))
        return plt.Lambda([STATEMONAD], plt.Ite(condition, then_state, else_state))

    def visit_Return(self, node: TypedReturn) -> plt.AST:
        """Reject a ``return`` reached through ordinary statement compilation.

        ``visit_FunctionDef`` compiles the final ``return`` of a function body
        itself, so any return statement that ends up here is not in that position.
        """
        message = "Compilation of return statements except for last statement in function is not supported."
        raise NotImplementedError(message)

    def visit_Pass(self, node: TypedPass) -> plt.AST:
        """Compile ``pass`` as an empty statement sequence, i.e. the identity on the state."""
        no_statements = list()
        return self.visit_sequence(no_statements)

    def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
        """Compile ``value[index]`` / ``value[lower:upper]``.

        Supported forms:
          * tuples: constant integer index (negative indices resolved at compile time);
          * lists: integer index, negative indices counted from the end;
          * bytestrings: integer index or ``lower:upper`` slice, with negative
            bounds counted from the end and an empty result when upper <= lower.

        Raises:
            NotImplementedError: for any other container/slice combination.
        """
        assert isinstance(
            node.value.typ, InstanceType
        ), "Can only access elements of instances, not classes"
        container_typ = node.value.typ.typ

        def in_state(expr):
            return plt.Apply(self.visit(expr), plt.Var(STATEMONAD))

        def from_end_if_negative(raw_name, length):
            # Python-style indexing: a negative index counts from the end
            return plt.Ite(
                plt.LessThanInteger(plt.Var(raw_name), plt.Integer(0)),
                plt.AddInteger(plt.Var(raw_name), length),
                plt.Var(raw_name),
            )

        if isinstance(container_typ, TupleType):
            assert isinstance(
                node.slice, Index
            ), "Only single index slices for tuples are currently supported"
            assert isinstance(
                node.slice.value, Constant
            ), "Only constant index access for tuples is supported"
            assert isinstance(
                node.slice.value.value, int
            ), "Only constant index integer access for tuples is supported"
            index = node.slice.value.value
            if index < 0:
                index += len(container_typ.typs)
            assert isinstance(node.ctx, Load), "Tuples are read-only"
            return plt.Lambda(
                [STATEMONAD],
                plt.FunctionalTupleAccess(
                    in_state(node.value), index, len(container_typ.typs)
                ),
            )

        if isinstance(container_typ, ListType):
            assert isinstance(
                node.slice, Index
            ), "Only single index slices for lists are currently supported"
            assert (
                node.slice.value.typ == IntegerInstanceType
            ), "Only single element list index access supported"
            list_bindings = [
                ("l", in_state(node.value)),
                ("raw_i", in_state(node.slice.value)),
                ("i", from_end_if_negative("raw_i", plt.LengthList(plt.Var("l")))),
            ]
            return plt.Lambda(
                [STATEMONAD],
                plt.Let(list_bindings, plt.IndexAccessList(plt.Var("l"), plt.Var("i"))),
            )

        if isinstance(container_typ, ByteStringType):
            if isinstance(node.slice, Index):
                index_bindings = [
                    ("bs", in_state(node.value)),
                    ("raw_ix", in_state(node.slice.value)),
                    (
                        "ix",
                        from_end_if_negative(
                            "raw_ix", plt.LengthOfByteString(plt.Var("bs"))
                        ),
                    ),
                ]
                return plt.Lambda(
                    [STATEMONAD],
                    plt.Let(
                        index_bindings,
                        plt.IndexByteString(plt.Var("bs"), plt.Var("ix")),
                    ),
                )
            if isinstance(node.slice, Slice):
                slice_bindings = [
                    ("bs", in_state(node.value)),
                    ("raw_i", in_state(node.slice.lower)),
                    (
                        "i",
                        from_end_if_negative(
                            "raw_i", plt.LengthOfByteString(plt.Var("bs"))
                        ),
                    ),
                    ("raw_j", in_state(node.slice.upper)),
                    (
                        "j",
                        from_end_if_negative(
                            "raw_j", plt.LengthOfByteString(plt.Var("bs"))
                        ),
                    ),
                    # number of leading bytes to skip: max(i, 0)
                    (
                        "drop",
                        plt.Ite(
                            plt.LessThanEqualsInteger(plt.Var("i"), plt.Integer(0)),
                            plt.Integer(0),
                            plt.Var("i"),
                        ),
                    ),
                    ("take", plt.SubtractInteger(plt.Var("j"), plt.Var("drop"))),
                ]
                sliced = plt.Ite(
                    plt.LessThanEqualsInteger(plt.Var("j"), plt.Var("i")),
                    plt.ByteString(b""),
                    plt.SliceByteString(
                        plt.Var("drop"),
                        plt.Var("take"),
                        plt.Var("bs"),
                    ),
                )
                return plt.Lambda([STATEMONAD], plt.Let(slice_bindings, sliced))

        raise NotImplementedError(f"Could not implement subscript of {node}")

    def visit_Tuple(self, node: TypedTuple) -> plt.AST:
        """Compile a tuple literal; every element is evaluated in the current state."""
        elements = [plt.Apply(self.visit(elt), plt.Var(STATEMONAD)) for elt in node.elts]
        return plt.Lambda([STATEMONAD], plt.FunctionalTuple(*elements))

    def visit_ClassDef(self, node: TypedClassDef) -> plt.AST:
        """Bind the class name in the state to the constructor of its class type."""
        constructor = node.class_typ.constr()
        return plt.Lambda(
            [STATEMONAD],
            extend_statemonad([node.name], [constructor], plt.Var(STATEMONAD)),
        )

    def visit_Attribute(self, node: TypedAttribute) -> plt.AST:
        """Compile ``value.attr`` into an attribute accessor applied to the value.

        The accessor is obtained from the receiver's type via
        ``node.value.typ.attribute(node.attr)`` and is applied to the compiled
        receiver evaluated in the current state.
        """
        # NOTE(review): this checks the type of the whole attribute expression
        # (node.typ) while the message talks about the receiver
        # (node.value.typ) -- confirm which one was intended.
        assert isinstance(
            node.typ, InstanceType
        ), "Can only access attributes of instances"
        compiled_value = self.visit(node.value)
        accessor = node.value.typ.attribute(node.attr)
        evaluated_value = plt.Apply(compiled_value, plt.Var(STATEMONAD))
        return plt.Lambda([STATEMONAD], plt.Apply(accessor, evaluated_value))
719

720
    def visit_Assert(self, node: TypedAssert) -> plt.AST:
        """Compile an ``assert`` statement into a state transformer.

        When the test holds, the state is returned unchanged. Otherwise
        ``plt.Error`` is applied, either to a ``plt.Trace`` of the message
        (if one was given) or to a plain ``plt.Unit``.
        """
        # the test is compiled before the message, as in the original layout
        condition = plt.Apply(self.visit(node.test), plt.Var(STATEMONAD))
        if node.msg is None:
            error_argument = plt.Unit()
        else:
            message = plt.Apply(self.visit(node.msg), plt.Var(STATEMONAD))
            error_argument = plt.Trace(message, plt.Unit())
        failure = plt.Apply(plt.Error(), error_argument)
        return plt.Lambda(
            [STATEMONAD],
            plt.Ite(condition, plt.Var(STATEMONAD), failure),
        )
736

737
    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> plt.AST:
        """Return the already-compiled pluto expression wrapped by ``node``."""
        compiled = node.expr
        return compiled
739

740
    def visit_List(self, node: TypedList) -> plt.AST:
        """Compile a list literal into a cons-list over the state monad.

        Elements are consed onto an empty list of the element type from the
        last one to the first, so the built list keeps source order.
        """
        assert isinstance(node.typ, InstanceType)
        list_typ = node.typ.typ
        assert isinstance(list_typ, ListType)
        acc = empty_list(list_typ.typ)
        for element in node.elts[::-1]:
            head = plt.Apply(self.visit(element), plt.Var(STATEMONAD))
            acc = plt.MkCons(head, acc)
        return plt.Lambda([STATEMONAD], acc)
747

748
    def visit_Dict(self, node: TypedDict) -> plt.AST:
        """Compile a dict literal into a list of data pairs.

        Keys and values are converted to data with ``transform_output_map``
        for the dict's key and value types. Pairs are consed onto the front
        of the accumulator while walking the literal left to right, so the
        resulting list holds the entries in reverse source order (the last
        written entry comes first).
        """
        assert isinstance(node.typ, InstanceType)
        dict_typ = node.typ.typ
        assert isinstance(dict_typ, DictType)
        key_type = dict_typ.key_typ
        value_type = dict_typ.value_typ
        pairs = plt.EmptyDataPairList()
        for key_node, value_node in zip(node.keys, node.values):
            key_data = transform_output_map(key_type)(
                plt.Apply(self.visit(key_node), plt.Var(STATEMONAD))
            )
            value_data = transform_output_map(value_type)(
                plt.Apply(self.visit(value_node), plt.Var(STATEMONAD))
            )
            pairs = plt.MkCons(plt.MkPairData(key_data, value_data), pairs)
        return plt.Lambda([STATEMONAD], pairs)
767

768
    def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
        """Compile ``body if test else orelse`` into a pluto if-then-else.

        Test, body and orelse are each compiled and applied to the current
        state before being combined with ``plt.Ite``.
        """
        condition, then_branch, else_branch = (
            plt.Apply(self.visit(part), plt.Var(STATEMONAD))
            for part in (node.test, node.body, node.orelse)
        )
        return plt.Lambda(
            [STATEMONAD], plt.Ite(condition, then_branch, else_branch)
        )
777

778
    def visit_ListComp(self, node: TypedListComp) -> plt.AST:
        """Compile a list comprehension with exactly one generator.

        The iterated list is narrowed by one ``plt.FilterList`` per ``if``
        clause and then transformed with ``plt.MapList`` using the element
        expression. In both cases the loop variable is made visible by
        extending the state monad with the target name bound to the lambda
        parameter ``x``.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        gen = node.generators[0]
        iter_typ = gen.iter.typ
        assert isinstance(iter_typ, InstanceType), "Only lists are valid generators"
        assert isinstance(iter_typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            gen.target, Name
        ), "Can only assign value to singleton element"
        target_name = gen.target.id

        def bind_target(compiled_expr: plt.AST) -> plt.AST:
            # wrap compiled_expr in a one-argument lambda whose argument is
            # bound to the comprehension target inside the state monad
            return plt.Lambda(
                ["x"],
                plt.Apply(
                    compiled_expr,
                    extend_statemonad(
                        [target_name], [plt.Var("x")], plt.Var(STATEMONAD)
                    ),
                ),
            )

        source = plt.Apply(self.visit(gen.iter), plt.Var(STATEMONAD))
        for condition in gen.ifs:
            source = plt.FilterList(
                source,
                bind_target(self.visit(condition)),
                empty_list(iter_typ.typ.typ),
            )
        mapped = plt.MapList(
            source,
            bind_target(self.visit(node.elt)),
            empty_list(node.elt.typ),
        )
        return plt.Lambda([STATEMONAD], mapped)
817

818
    def generic_visit(self, node: AST) -> plt.AST:
        """Reject any node kind that has no dedicated ``visit_*`` method.

        :raises NotImplementedError: always, naming the offending node
        """
        message = f"Can not compile {node}"
        raise NotImplementedError(message)
820

821

822
def compile(prog: AST):
    """Run the full eopsin pipeline over a parsed Python program.

    The program is first passed through the Python-level rewrite steps
    (with missing source locations fixed after each one), then through the
    optimizers, and finally through ``UPLCCompiler``.

    :param prog: the parsed Python AST of the program
    :return: the result of the final ``UPLCCompiler`` pass
    """
    rewriters = (
        # RewriteImport must run first: it pulls in all further imported files
        RewriteImport,
        # rewrites that simplify the Python code
        RewriteAugAssign,
        RewriteTupleAssign,
        RewriteImportPlutusData,
        RewriteImportHashlib,
        RewriteImportTyping,
        RewriteForbiddenOverwrites,
        RewriteImportDataclasses,
        RewriteInjectBuiltins,
        # type inference only after the complex constructs have been rewritten
        AggressiveTypeInferencer,
        # rewrites that circumvent the type inference or use its results
        RewriteInjectBuiltinsConstr,
        RewriteRemoveTypeStuff,
    )
    for rewriter in rewriters:
        prog = fix_missing_locations(rewriter().visit(prog))

    # Raw UPLC may appear from here on, so source locations are no longer fixed.
    passes = (
        # optimizations
        OptimizeRemoveDeadvars,
        OptimizeVarlen,
        OptimizeRemovePass,
        # the actual compiler always runs last
        UPLCCompiler,
    )
    for compiler_pass in passes:
        prog = compiler_pass().visit(prog)

    return prog
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc