• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 18375768062

09 Oct 2025 12:07PM UTC coverage: 92.132% (-0.7%) from 92.835%
18375768062

Pull #549

github

nielstron
Fixes for library exports
Pull Request #549: Plutus V3 support

1242 of 1458 branches covered (85.19%)

Branch coverage included in aggregate %.

38 of 47 new or added lines in 8 files covered. (80.85%)

27 existing lines in 4 files now uncovered.

4566 of 4846 relevant lines covered (94.22%)

4.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.19
/opshin/compiler.py
1
import ast
5✔
2
import copy
5✔
3
import typing
5✔
4
from ast import Load, Name, Constant, Slice
5✔
5

6
import pluthon as plt
5✔
7
import uplc.ast as uplc
5✔
8
from pycardano import PlutusData
5✔
9
from uplc.ast import data_from_cbor
5✔
10

11
from .bridge import to_uplc_builtin
5✔
12
from .optimize.optimize_remove_trace import OptimizeRemoveTrace
5✔
13
from .prelude import Nothing
5✔
14
from .type_impls import (
5✔
15
    InstanceType,
16
    UnionType,
17
    UnitType,
18
    RecordType,
19
    transform_ext_params_map,
20
    AnyType,
21
    transform_output_map,
22
    ClassType,
23
    PolymorphicFunctionInstanceType,
24
    ListType,
25
    TupleType,
26
    PairType,
27
    IntegerInstanceType,
28
    empty_list,
29
    DictType,
30
    ByteStringType,
31
    FunctionType,
32
    OUnit,
33
    UnitInstanceType,
34
)
35
from .type_inference import map_to_orig_name, AggressiveTypeInferencer
5✔
36
from .typed_ast import *
5✔
37

38
from .compiler_config import DEFAULT_CONFIG
5✔
39
from .optimize.optimize_const_folding import OptimizeConstantFolding
5✔
40
from .optimize.optimize_remove_deadconstants import OptimizeRemoveDeadconstants
5✔
41
from .optimize.optimize_union_expansion import OptimizeUnionExpansion
5✔
42

43
from .rewrite.rewrite_assert_none import RewriteAssertNone
5✔
44
from .rewrite.rewrite_augassign import RewriteAugAssign
5✔
45
from .rewrite.rewrite_cast_condition import RewriteConditions
5✔
46
from .rewrite.rewrite_comparison_chaining import RewriteComparisonChaining
5✔
47
from .rewrite.rewrite_empty_dicts import RewriteEmptyDicts
5✔
48
from .rewrite.rewrite_empty_lists import RewriteEmptyLists
5✔
49
from .rewrite.rewrite_forbidden_overwrites import RewriteForbiddenOverwrites
5✔
50
from .rewrite.rewrite_forbidden_return import RewriteForbiddenReturn
5✔
51
from .rewrite.rewrite_import import RewriteImport
5✔
52
from .rewrite.rewrite_import_dataclasses import RewriteImportDataclasses
5✔
53
from .rewrite.rewrite_import_hashlib import RewriteImportHashlib
5✔
54
from .rewrite.rewrite_import_integrity_check import RewriteImportIntegrityCheck
5✔
55
from .rewrite.rewrite_import_plutusdata import RewriteImportPlutusData
5✔
56
from .rewrite.rewrite_import_typing import RewriteImportTyping
5✔
57
from .rewrite.rewrite_import_uplc_builtins import RewriteImportUPLCBuiltins
5✔
58
from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins
5✔
59
from .rewrite.rewrite_orig_name import RewriteOrigName
5✔
60
from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff
5✔
61
from .rewrite.rewrite_scoping import RewriteScoping
5✔
62
from .rewrite.rewrite_subscript38 import RewriteSubscript38
5✔
63
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
5✔
64
from .optimize.optimize_remove_pass import OptimizeRemovePass
5✔
65
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars, NameLoadCollector
5✔
66
from .util import (
5✔
67
    CompilingNodeTransformer,
68
    NoOp,
69
    OVar,
70
    OLambda,
71
    OLet,
72
    OPSHIN_LOGGER,
73
    all_vars,
74
    SafeOLambda,
75
    opshin_name_scheme_compatible_varname,
76
    force_params,
77
    SafeApply,
78
    SafeLambda,
79
    written_vars,
80
    custom_fix_missing_locations,
81
)
82

83

84
# Maps the Python boolean-operator node type to the pluthon combinator that
# joins two compiled operands (looked up by PlutoCompiler.visit_BoolOp).
BoolOpMap = {ast.And: plt.And, ast.Or: plt.Or}
88

89

90
def rec_constant_map_data(c):
    """
    Recursively convert a Python constant into its UPLC data representation.

    bool -> PlutusInteger (0/1), int -> PlutusInteger, None -> PlutusConstr(0, []),
    bytes -> PlutusByteString, str -> PlutusByteString of ``c.encode()``,
    list -> PlutusList and dict -> PlutusMap (elements converted recursively).
    PlutusData instances are round-tripped through their CBOR encoding.

    :raises NotImplementedError: for any other constant type
    """
    # bool must be tested before int because bool is a subclass of int
    if isinstance(c, bool):
        return uplc.PlutusInteger(int(c))
    if isinstance(c, int):
        return uplc.PlutusInteger(c)
    if c is None:
        return uplc.PlutusConstr(0, [])
    if isinstance(c, (bytes, str)):
        raw = c.encode() if isinstance(c, str) else c
        return uplc.PlutusByteString(raw)
    if isinstance(c, list):
        return uplc.PlutusList(list(map(rec_constant_map_data, c)))
    if isinstance(c, dict):
        return uplc.PlutusMap(
            {
                rec_constant_map_data(key): rec_constant_map_data(value)
                for key, value in c.items()
            }
        )
    # This can occur when PlutusData is generated during constant folding
    if isinstance(c, PlutusData):
        return data_from_cbor(c.to_cbor())
    raise NotImplementedError(f"Unsupported constant type {type(c)}")
116

117

118
def rec_constant_map(c):
    """
    Recursively convert a Python constant into a UPLC builtin constant.

    bool -> BuiltinBool, int -> BuiltinInteger, None -> BuiltinUnit,
    bytes -> BuiltinByteString, str -> BuiltinString, list -> BuiltinList of
    recursively converted builtins. A dict becomes a BuiltinList of BuiltinPairs
    whose keys and values are converted to *data* via rec_constant_map_data.
    PlutusData instances are round-tripped through their CBOR encoding.

    :raises NotImplementedError: for any other constant type
    """
    # bool must be tested before int because bool is a subclass of int
    if isinstance(c, bool):
        return uplc.BuiltinBool(c)
    if isinstance(c, int):
        return uplc.BuiltinInteger(c)
    if c is None:
        return uplc.BuiltinUnit()
    if isinstance(c, bytes):
        return uplc.BuiltinByteString(c)
    if isinstance(c, str):
        return uplc.BuiltinString(c)
    if isinstance(c, list):
        return uplc.BuiltinList(list(map(rec_constant_map, c)))
    if isinstance(c, dict):
        entries = [
            uplc.BuiltinPair(rec_constant_map_data(key), rec_constant_map_data(value))
            for key, value in c.items()
        ]
        return uplc.BuiltinList(entries)
    # This can occur when PlutusData is generated during constant folding
    if isinstance(c, PlutusData):
        return data_from_cbor(c.to_cbor())
    raise NotImplementedError(f"Unsupported constant type {type(c)}")
145

146

147
def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
    """
    Wraps the validator function to enable a double function as minting script

    The result is a lambda over ``v0 .. v{pass_through-1}`` followed by ``a0`` and
    ``a1``. The ``v`` parameters are passed straight through to ``x`` and the
    partial application is bound as ``p``. If ``a1`` is a constructor-encoded
    datum with constructor index 0 (taken to be the script context), ``p`` is
    called with the encoding of ``Nothing()``, ``a0`` and ``a1``. Otherwise ``p``
    is applied to ``a0`` and ``a1`` only, and the (partially bound) result is
    returned.

    pass_through defines how many parameters x would normally take and should be passed through to x
    """
    passed_params = [f"v{i}" for i in range(pass_through)]
    bound_validator = plt.Apply(x, *(OVar(name) for name in passed_params))
    # only a constructor with index 0 counts; every other kind of data is rejected
    second_arg_is_context = plt.DelayedChooseData(
        OVar("a1"),
        plt.EqualsInteger(plt.Constructor(OVar("a1")), plt.Integer(0)),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
    )
    # plug in "Nothing" for the datum, then pass a0 and a1
    call_with_nothing = plt.Apply(
        OVar("p"),
        plt.UPLCConstant(to_uplc_builtin(Nothing())),
        OVar("a0"),
        OVar("a1"),
    )
    call_partially = plt.Apply(OVar("p"), OVar("a0"), OVar("a1"))
    return OLambda(
        passed_params + ["a0", "a1"],
        OLet(
            [("p", bound_validator)],
            plt.Ite(second_arg_is_context, call_with_nothing, call_partially),
        ),
    )
179

180

181
# A compiled statement: takes the pluthon term for everything that follows the
# statement and returns that term with the statement's effect wrapped around it.
CallAST = typing.Callable[[plt.AST], plt.AST]
182

183

184
class PlutoCompiler(CompilingNodeTransformer):
5✔
185
    """
186
    Expects a TypedAST and returns UPLC/Pluto like code
187
    """
188

189
    step = "Compiling python statements to UPLC"
5✔
190

191
    def __init__(
        self,
        validator_function_name="validator",
        config=DEFAULT_CONFIG,
    ):
        """
        :param validator_function_name: name of the function used as the program
            entry point; ``None`` compiles the module body without a validator
            wrapper
        :param config: compiler configuration; its ``fast_access_skip`` must be
            ``None`` or greater than 1
        """
        self.validator_function_name = validator_function_name
        self.config = config
        assert (
            self.config.fast_access_skip is None or self.config.fast_access_skip > 1
        ), "Parameter fast-access-skip needs to be greater than 1 or omitted"
        # stack of types of the functions currently being compiled;
        # visit_Return reads the innermost entry
        self.current_function_typ: typing.List[FunctionType] = []
204

205
    def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST:
        """
        Compile a list of statements into a single CallAST.

        The returned function wraps the given continuation with the compiled
        statements from last to first, so the first statement ends up outermost.
        The statements are compiled each time the returned function is called.
        """

        def wrap_all(continuation: plt.AST):
            wrapped = continuation
            for stmt in reversed(node_seq):
                stmt_fn = self.visit(stmt)
                wrapped = stmt_fn(wrapped)
            return wrapped

        return wrap_all
213

214
    def visit_BinOp(self, node: TypedBinOp) -> plt.AST:
        """Compile ``left <op> right`` using the left operand type's binop implementation."""
        binop_impl = node.left.typ.binop(node.op, node.right)
        compiled_left = self.visit(node.left)
        compiled_right = self.visit(node.right)
        return plt.Apply(binop_impl, compiled_left, compiled_right)
221

222
    def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST:
        """
        Compile ``a and b and ...`` / ``a or b or ...``.

        The operands are combined left to right with the pluthon combinator from
        BoolOpMap, producing a left-nested chain such as ``Or(Or(a, b), c)``.
        """
        op = BoolOpMap.get(type(node.op))
        assert len(node.values) >= 2, "Need to compare at least two values"
        combined = self.visit(node.values[0])
        for value in node.values[1:]:
            combined = op(combined, self.visit(value))
        return combined
232

233
    def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST:
        """Compile ``<op> operand`` using the operand type's unop implementation."""
        unop_impl = node.operand.typ.unop(node.op)
        compiled_operand = self.visit(node.operand)
        return plt.Apply(unop_impl, compiled_operand)
239

240
    def visit_Compare(self, node: TypedCompare) -> plt.AST:
        """
        Compile a single comparison ``left <cmpop> right``.

        Exactly one operator and one comparator are accepted; the implementation
        comes from the left operand type's ``cmp``.
        """
        assert len(node.ops) == 1, "Only single comparisons are supported"
        assert len(node.comparators) == 1, "Only single comparisons are supported"
        cmpop, right = node.ops[0], node.comparators[0]
        cmp_impl = node.left.typ.cmp(cmpop, right.typ)
        compiled_left = self.visit(node.left)
        compiled_right = self.visit(right)
        return plt.Apply(cmp_impl, compiled_left, compiled_right)
251

252
    def visit_Module(self, node: TypedModule) -> plt.AST:
        """
        Compile a whole module into a pluthon Program with version (1, 0, 0).

        With a validator function name configured, a return of a call to that
        function is appended to the module body. The call receives the program
        parameters ``val_param0 .. val_paramN`` (one per validator argument,
        unwrapped via transform_ext_params_map when ``config.unwrap_input`` is
        set), and the result is wrapped in a SafeOLambda over those parameters.
        Without a validator name, the module body is compiled on its own.

        In both cases every collected variable is first bound to a delayed
        ``NameError`` trace-error. Every variable can therefore always be
        accessed, at worst failing with a NameError at runtime.
        """

        def collect_vars():
            # variables found by all_vars plus every name that is loaded anywhere
            loads = NameLoadCollector()
            loads.visit(node)
            return sorted(set(all_vars(node)) | set(loads.loaded.keys()))

        def name_error_placeholders(names):
            return [
                (
                    name,
                    plt.Delay(plt.TraceError(f"NameError: {map_to_orig_name(name)}")),
                )
                for name in names
            ]

        if self.validator_function_name is None:
            all_vs = collect_vars()
            body = node.body
            validator = plt.Let(
                name_error_placeholders(all_vs),
                self.visit_sequence(body)(OUnit),
            )
        else:
            # for validators find main function
            # TODO can use more sophisiticated procedure here i.e. functions marked by comment
            matches = [
                stmt
                for stmt in node.body
                if isinstance(stmt, ast.FunctionDef)
                and stmt.orig_name == self.validator_function_name
            ]
            # the last matching definition wins
            main_fun: typing.Optional[InstanceType] = matches[-1] if matches else None
            assert (
                main_fun is not None
            ), f"Could not find function named {self.validator_function_name}"
            main_fun_typ: FunctionType = main_fun.typ.typ
            assert isinstance(
                main_fun_typ, FunctionType
            ), f"Variable named {self.validator_function_name} is not of type function"

            def program_param(index, argtyp):
                if self.config.unwrap_input:
                    return transform_ext_params_map(argtyp)(OVar(f"val_param{index}"))
                return OVar(f"val_param{index}")

            call_validator = TypedReturn(
                TypedCall(
                    func=ast.Name(
                        id=main_fun.name,
                        typ=InstanceType(main_fun_typ),
                        ctx=ast.Load(),
                    ),
                    typ=main_fun_typ.rettyp,
                    args=[
                        RawPlutoExpr(expr=program_param(index, argtyp), typ=argtyp)
                        for index, argtyp in enumerate(main_fun_typ.argtyps)
                    ],
                )
            )
            body = node.body + [call_validator]
            # TODO probably need to handle here when user wants to return something specific
            self.current_function_typ.append(
                FunctionType(
                    [],
                    InstanceType(AnyType() if self.config.wrap_output else UnitType()),
                )
            )
            all_vs = collect_vars()
            validator = SafeOLambda(
                [f"val_param{index}" for index, _ in enumerate(main_fun_typ.argtyps)],
                plt.Let(
                    name_error_placeholders(all_vs),
                    self.visit_sequence(body)(plt.Unit()),
                ),
            )
            self.current_function_typ.pop()

        return plt.Program((1, 0, 0), validator)
351

352
    def visit_Constant(self, node: Constant) -> plt.AST:
        """
        Compile a literal into a UPLC constant via rec_constant_map.

        Logs a warning for a non-empty bytes literal whose text is valid hex,
        because the author most likely meant ``bytes.fromhex(...)``.
        """

        def looks_like_hex(raw: bytes) -> bool:
            # UnicodeDecodeError is a ValueError, so undecodable bytes count as "not hex"
            try:
                bytes.fromhex(raw.decode())
            except ValueError:
                return False
            return True

        value = node.value
        if isinstance(value, bytes) and value != b"" and looks_like_hex(value):
            OPSHIN_LOGGER.warning(
                f"The string {node.value} looks like it is supposed to be a hex-encoded bytestring but is actually utf8-encoded. Try using `bytes.fromhex('{node.value.decode()}')` instead."
            )
        return plt.UPLCConstant(rec_constant_map(value))
364

365
    def visit_NoneType(self, _: typing.Optional[typing.Any]) -> plt.AST:
        """Handler for ``None`` visited in place of a node; compiles to unit."""
        unit = plt.Unit()
        return unit
367

368
    def visit_Assign(self, node: TypedAssign) -> CallAST:
        """
        Compile ``name = value``.

        The compiled value is bound under the name-scheme-compatible variable
        name. ``name`` is then bound to a Delay reading it, since visit_Name
        forces variables. For wrapped targets and targets of Any/Union type, the
        value is first mapped with transform_output_map.
        """
        assert (
            len(node.targets) == 1
        ), "Assignments to more than one variable not supported yet"
        target = node.targets[0]
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        compiled_value = self.visit(node.value)
        varname = target.id
        needs_output_map = getattr(target, "is_wrapped", False) or (
            isinstance(target.typ, InstanceType)
            and isinstance(target.typ.typ, (AnyType, UnionType))
        )
        if needs_output_map:
            # if this is a wrapped variable or Union/Any, we need to map it back to the external parameter type
            # TODO this is terribly inefficient. we would rather want to cast once when entering the body and cast back when leaving
            compiled_value = transform_output_map(node.value.typ)(compiled_value)

        def bind(rest: plt.AST) -> plt.AST:
            # first evaluate the term, then wrap in a delay
            return plt.Let(
                [
                    (opshin_name_scheme_compatible_varname(varname), compiled_value),
                    (varname, plt.Delay(OVar(varname))),
                ],
                rest,
            )

        return bind
395

396
    def visit_AnnAssign(self, node: TypedAnnAssign) -> CallAST:
        """
        Compile ``name: T = value``.

        A value of Any/Union type originates from PlutusData, so it is first
        unwrapped to the target type. A target of Any/Union type is treated as
        PlutusData, so the value is then mapped back with transform_output_map.
        The result is bound under the name-scheme-compatible name, and ``name``
        is bound to a Delay reading it.
        """
        assert isinstance(
            node.target, Name
        ), "Assignments to other things then names are not supported"
        assert isinstance(
            node.target.typ, InstanceType
        ), "Can only assign instances to instances"
        compiled_value = self.visit(node.value)

        def is_data_backed(typ) -> bool:
            return isinstance(typ, InstanceType) and isinstance(
                typ.typ, (AnyType, UnionType)
            )

        if is_data_backed(node.value.typ):
            compiled_value = transform_ext_params_map(node.target.typ)(compiled_value)
        if is_data_backed(node.target.typ):
            compiled_value = transform_output_map(node.value.typ)(compiled_value)

        def bind(rest: plt.AST) -> plt.AST:
            return plt.Let(
                [
                    (
                        opshin_name_scheme_compatible_varname(node.target.id),
                        compiled_value,
                    ),
                    (node.target.id, plt.Delay(OVar(node.target.id))),
                ],
                rest,
            )

        return bind
425

426
    def visit_Name(self, node: Name) -> plt.AST:
        """
        Compile a name in load context.

        A class name compiles to its constructor. A variable is stored as a Delay
        and is forced here; wrapped variables are additionally mapped with
        transform_ext_params_map.

        :raises NotImplementedError: for non-Load contexts
        """
        if not isinstance(node.ctx, Load):
            raise NotImplementedError(f"Context {node.ctx} not supported")
        if isinstance(node.typ, ClassType):
            return node.typ.constr()
        forced = plt.Force(plt.Var(node.id))
        if getattr(node, "is_wrapped", False):
            return transform_ext_params_map(node.typ)(forced)
        return forced
436

437
    def visit_Expr(self, node: TypedExpr) -> CallAST:
        """
        Compile an expression statement whose value is discarded.

        UPLC evaluates arguments eagerly, so applying a lambda that ignores its
        argument still computes the expression (useful mainly for traces). The
        parameter is named "0", which is not a valid identifier and therefore
        cannot clash with user variables.
        """

        def evaluate_then(rest: plt.AST) -> plt.AST:
            ignore_arg = OLambda(["0"], rest)
            return plt.Apply(ignore_arg, self.visit(node.value))

        return evaluate_then
443

444
    def visit_Call(self, node: TypedCall) -> plt.AST:
        """
        Compile a function call.

        Polymorphic builtins are instantiated for the argument types via
        ``impl_from_args``; any other callee is compiled by visiting
        ``node.func``. Arguments are first bound to ``p0 .. pN`` so they are
        evaluated before the call. The function is then applied to the bound
        ``self`` variable (if any), the sorted bound variables of its type and
        the delayed arguments. For ``isinstance`` the second (type) argument is
        dropped. Arguments for Any/Union parameters are mapped with
        transform_output_map.
        """
        # TODO function is actually not of type polymorphic function type here anymore
        func_typ = node.func.typ
        if isinstance(func_typ, PolymorphicFunctionInstanceType):
            # edge case for weird builtins that are polymorphic
            func_plt = force_params(
                func_typ.polymorphic_function.impl_from_args(func_typ.typ.argtyps)
            )
            bind_self = None
        else:
            assert isinstance(func_typ, InstanceType) and isinstance(
                func_typ.typ, FunctionType
            ), "Can only call instances of functions"
            func_plt = self.visit(node.func)
            bind_self = func_typ.typ.bind_self
        signature = func_typ.typ
        bound_vs = sorted(signature.bound_vars.keys())
        drops_type_arg = getattr(node.func, "orig_id", None) == "isinstance"

        args = []
        for index, (arg, param_typ) in enumerate(zip(node.args, signature.argtyps)):
            if drops_type_arg and index == 1:
                # impl_from_args has already been chosen, the type argument is not passed
                continue
            assert isinstance(param_typ, InstanceType)
            compiled_arg = self.visit(arg)
            if isinstance(param_typ.typ, (AnyType, UnionType)):
                # the callee expects generic data, so wrap the value before passing it
                compiled_arg = transform_output_map(arg.typ)(compiled_arg)
            args.append(compiled_arg)

        # Bind the arguments in a let so they are evaluated before the call; they are
        # passed delayed because parameters are treated like variable assignments.
        # Bound variables are passed with their current values to simulate the
        # state being handed into the function.
        self_args = [] if bind_self is None else [plt.Var(bind_self)]
        return OLet(
            [(f"p{index}", compiled) for index, compiled in enumerate(args)],
            SafeApply(
                func_plt,
                *self_args,
                *[plt.Var(name) for name in bound_vs],
                *[plt.Delay(OVar(f"p{index}")) for index in range(len(args))],
            ),
        )
492

493
    def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST:
        """
        Compile a function definition into a delayed SafeLambda bound to its name.

        The lambda's parameters are the bound ``self`` variable (if any), the
        sorted bound variables of the function type, then the declared
        arguments. The body is compiled with the function type pushed onto
        ``current_function_typ``. If execution reaches the end of the body
        without a return, the result is OUnit for an Any return type and
        plt.Unit() otherwise.
        """
        body = node.body.copy()
        fun_typ = node.typ.typ
        # defaults to returning None if there is no return statement
        fallthrough = OUnit if fun_typ.rettyp.typ == AnyType() else plt.Unit()
        self_param = [] if fun_typ.bind_self is None else [fun_typ.bind_self]
        read_vs = self_param + sorted(fun_typ.bound_vars.keys())

        self.current_function_typ.append(fun_typ)
        compiled_body = self.visit_sequence(body)(fallthrough)
        self.current_function_typ.pop()

        def bind(rest: plt.AST) -> plt.AST:
            closure = SafeLambda(
                read_vs + [arg.arg for arg in node.args.args],
                compiled_body,
            )
            return plt.Let([(node.name, plt.Delay(closure))], rest)

        return bind
520

521
    def visit_While(self, node: TypedWhile) -> CallAST:
        """
        Compile a while loop into a self-applying recursive function.

        The loop function receives itself (as "while") and the current values of
        the variables written in the loop. While the test holds, it runs the body
        and recurses with the updated values. Otherwise it continues through
        "adjusted_next", which rebinds those variables around the code following
        the loop. An ``else`` clause is compiled as a sequence appended after the
        loop.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
            bare_loop = copy.copy(node)
            bare_loop.orelse = []
            return self.visit_sequence([bare_loop] + node.orelse)
        compiled_test = self.visit(node.test)
        compiled_body = self.visit_sequence(node.body)
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(name) for name in written_vs]

        def loop_function(on_exit: plt.AST) -> plt.AST:
            # fresh copies of the variable nodes for every use site
            recurse = plt.Apply(
                OVar("while"),
                OVar("while"),
                *copy.deepcopy(pwritten_vs),
            )
            return plt.Lambda(
                [opshin_name_scheme_compatible_varname("while")] + written_vs,
                plt.Ite(compiled_test, compiled_body(recurse), on_exit),
            )

        def bind(rest: plt.AST) -> plt.AST:
            exit_loop = SafeApply(OVar("adjusted_next"), *copy.deepcopy(pwritten_vs))
            return OLet(
                [
                    ("adjusted_next", SafeLambda(written_vs, rest)),
                    ("while", loop_function(exit_loop)),
                ],
                plt.Apply(OVar("while"), OVar("while"), *copy.deepcopy(pwritten_vs)),
            )

        return bind
560

561
    def visit_For(self, node: TypedFor) -> CallAST:
        """
        Compile a for loop over a list into a self-applying recursive function.

        The loop function receives itself (as "for"), the remaining list (as
        "iter") and the current values of the variables written in the loop. On
        an empty list it continues through "adjusted_next", which rebinds the
        loop target and the written variables around the following code.
        Otherwise it binds the head to the loop target, runs the body and
        recurses on the tail. An ``else`` clause is compiled as a sequence
        appended after the loop.

        :raises NotImplementedError: if the iterable is not a list
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
            bare_loop = copy.copy(node)
            bare_loop.orelse = []
            return self.visit_sequence([bare_loop] + node.orelse)
        assert isinstance(node.iter.typ, InstanceType)
        if not isinstance(node.iter.typ.typ, ListType):
            raise NotImplementedError(
                "Compilation of for statements for anything but lists not implemented yet"
            )
        assert isinstance(
            node.target, Name
        ), "Can only assign value to singleton element"
        compiled_body = self.visit_sequence(node.body)
        compiled_iter = self.visit(node.iter)
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(name) for name in written_vs]

        def loop_function(on_exhausted: plt.AST) -> plt.AST:
            params = [
                opshin_name_scheme_compatible_varname("for"),
                opshin_name_scheme_compatible_varname("iter"),
            ] + written_vs
            next_iteration = plt.Let(
                [(node.target.id, plt.Delay(plt.HeadList(OVar("iter"))))],
                compiled_body(
                    plt.Apply(
                        OVar("for"),
                        OVar("for"),
                        plt.TailList(OVar("iter")),
                        *copy.deepcopy(pwritten_vs),
                    )
                ),
            )
            return plt.Lambda(
                params,
                plt.IteNullList(OVar("iter"), on_exhausted, next_iteration),
            )

        def bind(rest: plt.AST) -> plt.AST:
            after_loop = plt.Apply(
                OVar("adjusted_next"),
                plt.Var(node.target.id),
                *copy.deepcopy(pwritten_vs),
            )
            return OLet(
                [
                    ("adjusted_next", plt.Lambda([node.target.id] + written_vs, rest)),
                    ("for", loop_function(after_loop)),
                ],
                plt.Apply(
                    OVar("for"),
                    OVar("for"),
                    compiled_iter,
                    *copy.deepcopy(pwritten_vs),
                ),
            )

        return bind
622

623
    def visit_If(self, node: TypedIf) -> CallAST:
        """
        Compile an if/else statement.

        The code following the statement is bound once as "adjusted_next", a
        function of the variables reported by written_vars for this statement.
        Both branches end by calling it with those variables' current values.
        """
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(name) for name in written_vs]

        def join_point() -> plt.AST:
            # fresh copies of the variable nodes for every branch
            return SafeApply(OVar("adjusted_next"), *copy.deepcopy(pwritten_vs))

        def bind(rest: plt.AST) -> plt.AST:
            return OLet(
                [("adjusted_next", SafeLambda(written_vs, rest))],
                plt.Ite(
                    self.visit(node.test),
                    self.visit_sequence(node.body)(join_point()),
                    self.visit_sequence(node.orelse)(join_point()),
                ),
            )

        return bind
638

639
    def visit_Return(self, node: TypedReturn) -> CallAST:
        """
        Compile a return statement.

        The resulting CallAST ignores the continuation, so statements after the
        return are dropped. If the innermost enclosing function returns Any or a
        Union, the value is mapped with transform_output_map first.
        """
        value_plt = self.visit(node.value)
        assert self.current_function_typ, "Can not handle Return outside of a function"
        declared_ret = self.current_function_typ[-1].rettyp.typ
        if isinstance(declared_ret, (AnyType, UnionType)):
            value_plt = transform_output_map(node.value.typ)(value_plt)

        def discard_rest(_rest: plt.AST) -> plt.AST:
            return value_plt

        return discard_rest
647

648
    def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
5✔
649
        assert isinstance(
5✔
650
            node.value.typ, InstanceType
651
        ), "Can only access elements of instances, not classes"
652
        if isinstance(node.value.typ.typ, TupleType):
5✔
653
            assert isinstance(
5✔
654
                node.slice, Constant
655
            ), "Only constant index access for tuples is supported"
656
            assert isinstance(
5✔
657
                node.slice.value, int
658
            ), "Only constant index integer access for tuples is supported"
659
            index = node.slice.value
5✔
660
            if index < 0:
5✔
661
                index += len(node.value.typ.typ.typs)
5✔
662
            assert isinstance(node.ctx, Load), "Tuples are read-only"
5✔
663
            return plt.FunctionalTupleAccess(
5✔
664
                self.visit(node.value),
665
                index,
666
                len(node.value.typ.typ.typs),
667
            )
668
        if isinstance(node.value.typ.typ, PairType):
5✔
669
            assert isinstance(
5✔
670
                node.slice, Constant
671
            ), "Only constant index access for pairs is supported"
672
            assert isinstance(
5✔
673
                node.slice.value, int
674
            ), "Only constant index integer access for pairs is supported"
675
            index = node.slice.value
5✔
676
            if index < 0:
5✔
677
                index += 2
5✔
678
            assert isinstance(node.ctx, Load), "Pairs are read-only"
5✔
679
            assert (
5✔
680
                0 <= index < 2
681
            ), f"Pairs only have 2 elements, index should be -2, -1, 0 or 1, found {node.slice.value}"
682
            member_func = plt.FstPair if index == 0 else plt.SndPair
5✔
683
            # the content of pairs is always Data, so we need to unwrap
684
            member_typ = node.typ
5✔
685
            return transform_ext_params_map(member_typ)(
5✔
686
                member_func(
687
                    self.visit(node.value),
688
                ),
689
            )
690
        if isinstance(node.value.typ.typ, ListType):
5✔
691
            if not isinstance(node.slice, Slice):
5✔
692
                assert (
5✔
693
                    node.slice.typ == IntegerInstanceType
694
                ), "Only single element list index access supported"
695
                if isinstance(node.slice, Constant) and node.slice.value >= 0:
5✔
696
                    index = node.slice.value
5✔
697
                    return plt.ConstantIndexAccessListFast(
5✔
698
                        self.visit(node.value),
699
                        index,
700
                    )
701
                return OLet(
5✔
702
                    [
703
                        (
704
                            "l",
705
                            self.visit(node.value),
706
                        ),
707
                        (
708
                            "raw_i",
709
                            self.visit(node.slice),
710
                        ),
711
                        (
712
                            "i",
713
                            plt.Ite(
714
                                plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
715
                                plt.AddInteger(
716
                                    OVar("raw_i"), plt.LengthList(OVar("l"))
717
                                ),
718
                                OVar("raw_i"),
719
                            ),
720
                        ),
721
                    ],
722
                    (
723
                        plt.IndexAccessListFast(self.config.fast_access_skip)(
724
                            OVar("l"), OVar("i")
725
                        )
726
                        if self.config.fast_access_skip is not None
727
                        else plt.IndexAccessList(OVar("l"), OVar("i"))
728
                    ),
729
                )
730
            else:
731
                assert (
5✔
732
                    node.slice.upper is not None
733
                ), "Only slices with upper bound supported"
734
                assert (
5✔
735
                    node.slice.lower is not None
736
                ), "Only slices with lower bound supported"
737
                return OLet(
5✔
738
                    [
739
                        (
740
                            "xs",
741
                            self.visit(node.value),
742
                        ),
743
                        (
744
                            "raw_i",
745
                            self.visit(node.slice.lower),
746
                        ),
747
                        (
748
                            "i",
749
                            plt.Ite(
750
                                plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
751
                                plt.AddInteger(
752
                                    OVar("raw_i"),
753
                                    plt.LengthList(OVar("xs")),
754
                                ),
755
                                OVar("raw_i"),
756
                            ),
757
                        ),
758
                        (
759
                            "raw_j",
760
                            self.visit(node.slice.upper),
761
                        ),
762
                        (
763
                            "j",
764
                            plt.Ite(
765
                                plt.LessThanInteger(OVar("raw_j"), plt.Integer(0)),
766
                                plt.AddInteger(
767
                                    OVar("raw_j"),
768
                                    plt.LengthList(OVar("xs")),
769
                                ),
770
                                OVar("raw_j"),
771
                            ),
772
                        ),
773
                        (
774
                            "drop",
775
                            plt.Ite(
776
                                plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0)),
777
                                plt.Integer(0),
778
                                OVar("i"),
779
                            ),
780
                        ),
781
                        (
782
                            "take",
783
                            plt.SubtractInteger(OVar("j"), OVar("drop")),
784
                        ),
785
                    ],
786
                    plt.Ite(
787
                        plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
788
                        empty_list(node.value.typ.typ.typ),
789
                        plt.SliceList(
790
                            OVar("drop"),
791
                            OVar("take"),
792
                            OVar("xs"),
793
                            empty_list(node.value.typ.typ.typ),
794
                        ),
795
                    ),
796
                )
797
        elif isinstance(node.value.typ.typ, DictType):
5✔
798
            dict_typ = node.value.typ.typ
5✔
799
            if not isinstance(node.slice, Slice):
5✔
800
                return OLet(
5✔
801
                    [
802
                        (
803
                            "key",
804
                            transform_output_map(node.slice.typ)(
805
                                self.visit(node.slice),
806
                            ),
807
                        )
808
                    ],
809
                    transform_ext_params_map(dict_typ.value_typ)(
810
                        plt.SndPair(
811
                            plt.FindList(
812
                                self.visit(node.value),
813
                                OLambda(
814
                                    ["x"],
815
                                    plt.EqualsData(
816
                                        OVar("key"),
817
                                        plt.FstPair(OVar("x")),
818
                                    ),
819
                                ),
820
                                plt.TraceError("KeyError"),
821
                            )
822
                        ),
823
                    ),
824
                )
825
        elif isinstance(node.value.typ.typ, ByteStringType):
5✔
826
            if not isinstance(node.slice, Slice):
5✔
827
                if isinstance(node.slice, Constant) and node.slice.value >= 0:
5✔
828
                    return plt.IndexByteString(
5✔
829
                        self.visit(node.value),
830
                        self.visit(node.slice),
831
                    )
832
                elif isinstance(node.slice, Constant) and node.slice.value < 0:
5✔
833
                    return plt.IndexByteString(
5✔
834
                        self.visit(node.value),
835
                        plt.AddInteger(
836
                            self.visit(node.slice),
837
                            plt.LengthOfByteString(self.visit(node.value)),
838
                        ),
839
                    )
840
                return OLet(
5✔
841
                    [
842
                        (
843
                            "bs",
844
                            self.visit(node.value),
845
                        ),
846
                        (
847
                            "raw_ix",
848
                            self.visit(node.slice),
849
                        ),
850
                        (
851
                            "ix",
852
                            plt.Ite(
853
                                plt.LessThanInteger(OVar("raw_ix"), plt.Integer(0)),
854
                                plt.AddInteger(
855
                                    OVar("raw_ix"),
856
                                    plt.LengthOfByteString(OVar("bs")),
857
                                ),
858
                                OVar("raw_ix"),
859
                            ),
860
                        ),
861
                    ],
862
                    plt.IndexByteString(OVar("bs"), OVar("ix")),
863
                )
864
            elif isinstance(node.slice, Slice):
5✔
865
                return OLet(
5✔
866
                    [
867
                        (
868
                            "bs",
869
                            self.visit(node.value),
870
                        ),
871
                        (
872
                            "raw_i",
873
                            self.visit(node.slice.lower),
874
                        ),
875
                        (
876
                            "i",
877
                            plt.Ite(
878
                                plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
879
                                plt.AddInteger(
880
                                    OVar("raw_i"),
881
                                    plt.LengthOfByteString(OVar("bs")),
882
                                ),
883
                                OVar("raw_i"),
884
                            ),
885
                        ),
886
                        (
887
                            "raw_j",
888
                            self.visit(node.slice.upper),
889
                        ),
890
                        (
891
                            "j",
892
                            plt.Ite(
893
                                plt.LessThanInteger(OVar("raw_j"), plt.Integer(0)),
894
                                plt.AddInteger(
895
                                    OVar("raw_j"),
896
                                    plt.LengthOfByteString(OVar("bs")),
897
                                ),
898
                                OVar("raw_j"),
899
                            ),
900
                        ),
901
                        (
902
                            "drop",
903
                            plt.Ite(
904
                                plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0)),
905
                                plt.Integer(0),
906
                                OVar("i"),
907
                            ),
908
                        ),
909
                        (
910
                            "take",
911
                            plt.SubtractInteger(OVar("j"), OVar("drop")),
912
                        ),
913
                    ],
914
                    plt.Ite(
915
                        plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
916
                        plt.ByteString(b""),
917
                        plt.SliceByteString(
918
                            OVar("drop"),
919
                            OVar("take"),
920
                            OVar("bs"),
921
                        ),
922
                    ),
923
                )
924
        raise NotImplementedError(
925
            f'Could not implement subscript "{node.slice}" of "{node.value}"'
926
        )
927

928
    def visit_Tuple(self, node: TypedTuple) -> plt.AST:
        """Compile a tuple literal into a ``plt.FunctionalTuple`` of its compiled elements (in source order)."""
        compiled_elements = [self.visit(element) for element in node.elts]
        return plt.FunctionalTuple(*compiled_elements)
5✔
930

931
    def visit_ClassDef(self, node: TypedClassDef) -> CallAST:
        """
        Compile a class definition into a statement continuation.

        The returned callable takes the compiled remainder of the program and
        wraps it in a ``plt.Let`` that binds the class name to a delayed
        instance of the class type's constructor.
        """

        def bind_class(x):
            constructor = plt.Delay(node.class_typ.constr())
            return plt.Let([(node.name, constructor)], x)

        return bind_class
5✔
933

934
    def visit_Attribute(self, node: TypedAttribute) -> plt.AST:
        """
        Compile attribute access ``value.attr``.

        The accessor for ``attr`` is obtained from the instance type of
        ``value`` and applied to the compiled ``value`` expression.
        """
        owner_typ = node.value.typ
        assert isinstance(owner_typ, InstanceType), "Can only access attributes of instances"
        compiled_owner = self.visit(node.value)
        accessor = owner_typ.attribute(node.attr)
        return plt.Apply(accessor, compiled_owner)
5✔
941

942
    def visit_Assert(self, node: TypedAssert) -> CallAST:
        """
        Compile an ``assert`` statement into a statement continuation.

        The returned callable takes the compiled remainder of the program and
        yields an if-then-else. If the test holds, evaluation continues with
        the remainder. Otherwise ``plt.Error()`` is applied to either a
        ``plt.Trace`` of the message (when one is given) or to ``plt.Unit()``.
        """

        def guard(x):
            # The test is compiled before the message, matching the original evaluation order.
            compiled_test = self.visit(node.test)
            if node.msg is None:
                failure_arg = plt.Unit()
            else:
                # NOTE(review): applying Error to the traced unit presumably makes the
                # trace run before failing - confirm against pluthon semantics.
                failure_arg = plt.Trace(self.visit(node.msg), plt.Unit())
            return plt.Ite(compiled_test, x, plt.Apply(plt.Error(), failure_arg))

        return guard
955

956
    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> plt.AST:
        """Return the already-compiled pluthon expression carried by the node, unchanged."""
        raw_expr = node.expr
        return raw_expr
5✔
958

959
    def visit_List(self, node: TypedList) -> plt.AST:
        """
        Compile a list literal into nested ``plt.MkCons`` calls.

        The chain ends in the empty list for the element type. When the
        list's element type is ``Any`` or a union, each element is first
        converted with ``transform_output_map`` of its own type, so elements
        of concrete types are stored in the list's data representation.
        """
        list_typ = node.typ
        assert isinstance(list_typ, InstanceType)
        assert isinstance(list_typ.typ, ListType)
        element_typ = list_typ.typ.typ
        compiled = empty_list(element_typ)
        # Build from the back so the first source element ends up at the head.
        for elt in node.elts[::-1]:
            compiled_elt = self.visit(elt)
            needs_data_wrapping = isinstance(element_typ.typ, (AnyType, UnionType))
            if needs_data_wrapping:
                compiled_elt = transform_output_map(elt.typ)(compiled_elt)
            compiled = plt.MkCons(compiled_elt, compiled)
        return compiled
5✔
971

972
    def visit_Dict(self, node: TypedDict) -> plt.AST:
        """
        Compile a dict literal into a list of data pairs.

        Every key and value is converted to its data representation with
        ``transform_output_map`` of its own type. Pairs are consed onto the
        front while the literal is walked in source order, so the resulting
        list holds the entries in reverse source order. For a literal with a
        duplicate key, the last-written entry is therefore the one closest to
        the head of the list.

        Dict unpacking (``{**other}``) is rejected explicitly. Python's AST
        represents it with a ``None`` key, which has no type to convert.
        """
        assert isinstance(node.typ, InstanceType)
        assert isinstance(node.typ.typ, DictType)
        pairs = plt.EmptyDataPairList()
        for k, v in zip(node.keys, node.values):
            assert k is not None, "Dict unpacking (**) is not supported in dict literals"
            pairs = plt.MkCons(
                plt.MkPairData(
                    transform_output_map(k.typ)(self.visit(k)),
                    transform_output_map(v.typ)(self.visit(v)),
                ),
                pairs,
            )
        return pairs
5✔
991

992
    def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
        """
        Compile a conditional expression ``body if test else orelse``.

        If the whole expression has a union type, a branch whose own type is
        not a union instance is converted to its data representation with
        ``transform_output_map``. Both branches then yield values in the same
        representation. Otherwise the three parts are compiled directly into
        a ``plt.Ite``.
        """

        def is_union_instance(typ) -> bool:
            # Expression types are InstanceTypes wrapping the class type, so the
            # union check has to look at the wrapped ``.typ`` (as done for node.typ).
            return isinstance(typ, InstanceType) and isinstance(typ.typ, UnionType)

        if isinstance(node.typ.typ, UnionType):
            body = self.visit(node.body)
            orelse = self.visit(node.orelse)
            if not is_union_instance(node.body.typ):
                body = transform_output_map(node.body.typ)(body)
            if not is_union_instance(node.orelse.typ):
                orelse = transform_output_map(node.orelse.typ)(orelse)
            return plt.Ite(self.visit(node.test), body, orelse)
        return plt.Ite(
            self.visit(node.test),
            self.visit(node.body),
            self.visit(node.orelse),
        )
1006

1007
    def visit_ListComp(self, node: TypedListComp) -> plt.AST:
        """
        Compile a list comprehension ``[elt for target in iter if cond ...]``.

        Only a single generator over a list, with a plain name as target, is
        supported. The element expression becomes a mapping lambda. The
        ``if`` clauses, left-folded with ``plt.And``, become a filter lambda.
        Both lambdas bind the target name to the current list element.
        ``plt.MapList`` is emitted when there are no ``if`` clauses, and
        ``plt.MapFilterList`` otherwise.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        (comprehension,) = node.generators
        iter_typ = comprehension.iter.typ
        assert isinstance(iter_typ, InstanceType), "Only lists are valid generators"
        assert isinstance(iter_typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            comprehension.target, Name
        ), "Can only assign value to singleton element"
        source_list = self.visit(comprehension.iter)

        compiled_conds = [self.visit(cond) for cond in comprehension.ifs]
        condition = None
        if compiled_conds:
            condition = compiled_conds[0]
            for further in compiled_conds[1:]:
                condition = plt.And(condition, further)

        target_name = comprehension.target.id

        def element_lambda(body):
            # "x" receives the current list element; it is re-exposed (delayed)
            # under the comprehension target's name for ``body``.
            return OLambda(
                ["x"],
                plt.Let([(target_name, plt.Delay(OVar("x")))], body),
            )

        map_fun = element_lambda(self.visit(node.elt))
        empty_list_con = empty_list(node.elt.typ)
        if condition is None:
            return plt.MapList(source_list, map_fun, empty_list_con)
        return plt.MapFilterList(
            source_list,
            element_lambda(condition),
            map_fun,
            empty_list_con,
        )
1050

1051
    def visit_DictComp(self, node: TypedDictComp) -> plt.AST:
        """
        Compile a dict comprehension ``{key: value for target in iter if cond ...}``.

        Only a single generator over a list, with a plain name as target, is
        supported. Each element is mapped to a data pair of the key and value,
        both converted with ``transform_output_map``. The ``if`` clauses,
        left-folded with ``plt.And``, become a filter lambda. Both lambdas bind
        the target name to the current list element. The result starts from
        ``plt.EmptyDataPairList()``.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        (comprehension,) = node.generators
        iter_typ = comprehension.iter.typ
        assert isinstance(iter_typ, InstanceType), "Only lists are valid generators"
        assert isinstance(iter_typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            comprehension.target, Name
        ), "Can only assign value to singleton element"
        source_list = self.visit(comprehension.iter)

        compiled_conds = [self.visit(cond) for cond in comprehension.ifs]
        condition = None
        if compiled_conds:
            condition = compiled_conds[0]
            for further in compiled_conds[1:]:
                condition = plt.And(condition, further)

        target_name = comprehension.target.id

        def element_lambda(body):
            # "x" receives the current list element; it is re-exposed (delayed)
            # under the comprehension target's name for ``body``.
            return OLambda(
                ["x"],
                plt.Let([(target_name, plt.Delay(OVar("x")))], body),
            )

        key_data = transform_output_map(node.key.typ)(self.visit(node.key))
        value_data = transform_output_map(node.value.typ)(self.visit(node.value))
        map_fun = element_lambda(plt.MkPairData(key_data, value_data))
        empty_list_con = plt.EmptyDataPairList()
        if condition is None:
            return plt.MapList(source_list, map_fun, empty_list_con)
        return plt.MapFilterList(
            source_list,
            element_lambda(condition),
            map_fun,
            empty_list_con,
        )
1101

1102
    def visit_FormattedValue(self, node: TypedFormattedValue) -> plt.AST:
        """Compile an f-string placeholder by applying the value type's ``stringify()`` function to the compiled value."""
        to_string = node.value.typ.stringify()
        compiled_value = self.visit(node.value)
        return plt.Apply(to_string, compiled_value)
1107

1108
    def visit_JoinedStr(self, node: TypedJoinedStr) -> plt.AST:
        """
        Compile an f-string into a right-nested chain of ``plt.AppendString``.

        The chain ends in the empty text. Parts are compiled from last to
        first, so the innermost append holds the final part.
        """
        pending_parts = list(node.values)
        accumulated = plt.Text("")
        while pending_parts:
            accumulated = plt.AppendString(self.visit(pending_parts.pop()), accumulated)
        return accumulated
5✔
1113

1114
    def generic_visit(self, node: TypedAST) -> plt.AST:
        """Fallback for nodes without a dedicated ``visit_*`` handler: always raises NotImplementedError."""
        message = f"Can not compile {node}"
        raise NotImplementedError(message)
1116

1117

1118
def parse(
    source: str,
    filename=None,
) -> ast.AST:
    """
    Parse source code into an AST.

    Currently passes everything through Python's ast module.

    :param source: Python source code to parse
    :param filename: name reported in syntax errors; ``None`` (the default)
        falls back to ``"<unknown>"``, which is also ``ast.parse``'s default
    :return: the parsed module
    :raises SyntaxError: if ``source`` is not valid Python
    """
    # ast.parse hands the filename to the builtin compile(), which rejects
    # None with a TypeError, so substitute ast.parse's own placeholder.
    if filename is None:
        filename = "<unknown>"
    tree = ast.parse(source, filename=filename)
    return tree
1129

1130

1131
def compile(
    prog: ast.AST,
    filename=None,
    validator_function_name="validator",
    config=DEFAULT_CONFIG,
    wrap_output=False,
) -> plt.Program:
    """
    Compile a parsed Python module into a pluthon program.

    The module is pushed through a fixed sequence of passes: rewrites, type
    inference, type-dependent rewrites and optimizations. Optional passes are
    replaced by ``NoOp()`` when disabled in ``config``. After every pass,
    ``custom_fix_missing_locations`` is applied. Finally ``PlutoCompiler``
    lowers the resulting typed AST.

    :param prog: the parsed module, e.g. as returned by ``parse``
    :param filename: forwarded to ``RewriteImport``
    :param validator_function_name: name of the entry-point function; passed
        to ``OptimizeRemoveDeadvars`` and ``PlutoCompiler``
    :param config: compiler configuration selecting the optional passes
    :param wrap_output: accepted for interface compatibility; not read here
    :return: the result of ``PlutoCompiler.visit``
    """

    def when(flag, build):
        # Instantiate an optional pass only if enabled, otherwise a no-op pass.
        return build() if flag else NoOp()

    # Passes that work on the untyped Python AST. RewriteImport must come
    # first because it pulls in all further imported files.
    untyped_passes = [
        RewriteImport(filename=filename),
        RewriteForbiddenReturn(),
        when(config.constant_folding, OptimizeConstantFolding),
        when(config.expand_union_types, OptimizeUnionExpansion),
        RewriteSubscript38(),
        RewriteAugAssign(),
        RewriteComparisonChaining(),
        RewriteTupleAssign(),
        RewriteImportIntegrityCheck(),
        RewriteImportPlutusData(),
        RewriteImportHashlib(),
        RewriteImportTyping(),
        RewriteForbiddenOverwrites(),
        RewriteImportDataclasses(),
        RewriteInjectBuiltins(),
        RewriteConditions(),
        # Records the original variable names before scoping renames them.
        RewriteOrigName(),
        RewriteScoping(),
    ]
    # Type inference runs only after the complex Python constructs above
    # have been rewritten.
    typing_passes = [
        AggressiveTypeInferencer(config.allow_isinstance_anything),
    ]
    # Rewrites that rely on (or circumvent) the inferred types.
    typed_passes = [
        RewriteAssertNone(),
        RewriteEmptyLists(),
        RewriteEmptyDicts(),
        RewriteImportUPLCBuiltins(),
        RewriteRemoveTypeStuff(),
    ]
    optimization_passes = [
        when(config.remove_trace, OptimizeRemoveTrace),
        when(
            config.remove_dead_code,
            lambda: OptimizeRemoveDeadvars(
                validator_function_name=validator_function_name
            ),
        ),
        when(config.remove_dead_code, OptimizeRemoveDeadconstants),
        OptimizeRemovePass(),
    ]

    for compiler_pass in untyped_passes + typing_passes + typed_passes + optimization_passes:
        prog = custom_fix_missing_locations(compiler_pass.visit(prog))

    # The pluthon compiler always runs last.
    backend = PlutoCompiler(
        validator_function_name=validator_function_name,
        config=config,
    )
    return backend.visit(prog)
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc