• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 18375768062

09 Oct 2025 12:07PM UTC coverage: 92.132% (-0.7%) from 92.835%
18375768062

Pull #549

github

nielstron
Fixes for library exports
Pull Request #549: Plutus V3 support

1242 of 1458 branches covered (85.19%)

Branch coverage included in aggregate %.

38 of 47 new or added lines in 8 files covered. (80.85%)

27 existing lines in 4 files now uncovered.

4566 of 4846 relevant lines covered (94.22%)

4.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.14
/opshin/type_impls.py
1
import typing
5✔
2
from dataclasses import dataclass, field
5✔
3
from typing import Callable
5✔
4

5
import itertools
5✔
6
import ast
5✔
7

8
from frozendict import frozendict
5✔
9
from frozenlist2 import frozenlist
5✔
10
from ordered_set import OrderedSet
5✔
11

12
import uplc.ast as uplc
5✔
13
import pluthon as plt
5✔
14

15
from .util import patternize, OVar, OLet, OLambda, OPSHIN_LOGGER, SafeOLambda, distinct
5✔
16

17
if typing.TYPE_CHECKING:
18
    from .typed_ast import TypedAST
19

20

21
class TypeInferenceError(AssertionError):
    """Error raised by the type implementations, e.g. when a constructor or attribute does not exist for a type."""
5✔
23

24

25
class Type:
    """
    Base class of all types implemented in this module.

    Nearly every method defined here is a default that raises; subclasses override
    the operations they actually support.
    """

    def __new__(meta, *args, **kwargs):
        """
        Creates the instance and replaces a fixed set of its methods by the result
        of passing them through `patternize`.
        """
        # NOTE: despite the names, `meta` is the class being instantiated and `klass` is the new instance
        klass = super().__new__(meta)

        # object.__setattr__ is used so the assignment is not blocked by the __setattr__
        # of the frozen dataclass subclasses defined in this file
        for key in ["constr", "attribute", "cmp", "stringify", "copy_only_attributes"]:
            # getattr on the instance yields the bound method
            value = getattr(klass, key)
            wrapped = patternize(value)
            object.__setattr__(klass, key, wrapped)

        return klass

    def __ge__(self, other: "Type"):
        """
        Returns whether other can be substituted for this type.
        In other words this returns whether the interface of this type is a subset of the interface of other.
        Note that this is usually <= and not >=, but this needs to be fixed later.
        Produces a partial order on types.
        The top element is the most generic type and can not substitute for anything.
        The bottom element is the most specific type and can be substituted for anything.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError("Comparison between raw types impossible")

    def constr_type(self) -> "InstanceType":
        """
        The type of the constructor for this class.

        The base implementation always raises TypeInferenceError.
        """
        raise TypeInferenceError(
            f"Object of type {self.__class__} does not have a constructor"
        )

    def constr(self) -> plt.AST:
        """
        The constructor for this class.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"Constructor of {self.python_type()} not implemented"
        )

    def attribute_type(self, attr) -> "Type":
        """
        The type of the named attribute `attr` of this class.

        The base implementation always raises TypeInferenceError.
        """
        raise TypeInferenceError(
            f"Object of type {self.python_type()} does not have attribute '{attr}'"
        )

    def attribute(self, attr) -> plt.AST:
        """
        The implementation of the named attribute `attr` of this class.
        Needs to be a lambda that expects as first argument the object itself.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(f"Attribute {attr} not implemented for type {self}")

    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        The implementation of comparing this type to type o via operator op.
        Returns a lambda that expects as first argument the object itself and as second the comparison.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"Comparison {type(op).__name__} for {self.python_type()} and {o.python_type()} is not implemented. This is likely intended because it would always evaluate to False."
        )

    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Returns a stringified version of the object

        The recursive parameter informs the method whether it was invoked recursively from another invocation

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(f"{self.python_type()} can not be stringified")

    def copy_only_attributes(self) -> plt.AST:
        """
        Pluthon function that returns a copy of only the attributes of the object
        This can only be called for UnionType and RecordType, as such the input data is always in PlutusData format and the output should be as well.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(f"{self.python_type()} can not be copied")

    def binop_type(self, binop: ast.operator, other: "Type") -> "Type":
        """
        Type of a binary operation between self and other.

        This is a FunctionType taking instances of self and other and returning
        an instance of whatever `_binop_return_type` reports.
        """
        return FunctionType(
            [InstanceType(self), InstanceType(other)],
            InstanceType(self._binop_return_type(binop, other)),
        )

    def _binop_return_type(self, binop: ast.operator, other: "Type") -> "Type":
        """
        Return the type of the result of a binary operation between self and other

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"{self.python_type()} does not implement {binop.__class__.__name__} with {other.python_type()}"
        )

    def binop(self, binop: ast.operator, other: "TypedAST") -> plt.AST:
        """
        Implements a binary operation between self and other

        Wraps the function returned by `_binop_bin_fun` into a lambda over "self" and "other".
        Note that `other` is a TypedAST here, not a Type as in `binop_type`.
        """
        return OLambda(
            ["self", "other"],
            self._binop_bin_fun(binop, other)(OVar("self"), OVar("other")),
        )

    def _binop_bin_fun(
        self, binop: ast.operator, other: "TypedAST"
    ) -> Callable[[plt.AST, plt.AST], plt.AST]:
        """
        Returns a binary function that implements the binary operation between self and other.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"{self.python_type()} can not be used with operation {binop.__class__.__name__} with {other.type.python_type()}"
        )

    def unop_type(self, unop: ast.unaryop) -> "Type":
        """
        Type of a unary operation on self.

        This is a FunctionType taking an instance of self and returning
        an instance of whatever `_unop_return_type` reports.
        """
        return FunctionType(
            [InstanceType(self)],
            InstanceType(self._unop_return_type(unop)),
        )

    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        """
        Return the type of the result of a unary operation on self

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"{self.python_type()} does not implement {unop.__class__.__name__}"
        )

    def unop(self, unop: ast.unaryop) -> plt.AST:
        """
        Implements a unary operation on self

        Wraps the function returned by `_unop_fun` into a lambda over "self".
        """
        return OLambda(
            ["self"],
            self._unop_fun(unop)(OVar("self")),
        )

    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        """
        Returns a unary function that implements the unary operation on self.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"{self.python_type()} can not be used with operation {unop.__class__.__name__}"
        )

    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """
        Returns a representation of the type in pluthon.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"Type {self.python_type()} does not have a pluthon representation"
        )

    def python_type(self):
        """
        Returns a representation of the type in python.

        The base implementation always raises NotImplementedError. Most error messages of
        this class call this method, so subclasses that do not override it will
        surface this error instead of the more specific one.
        """
        raise NotImplementedError(
            f"Type {type(self).__name__} does not have a python type representation"
        )
175

176

177
@dataclass(frozen=True, unsafe_hash=True)
class Record:
    """
    Description of a record: its names, its constructor id and its ordered (name, type) fields.
    """

    name: str
    orig_name: str
    constructor: int
    fields: typing.Union[typing.List[typing.Tuple[str, Type]], frozenlist]

    def __post_init__(self):
        # The dataclass is frozen, so the normalized field list has to be
        # written through object.__setattr__
        frozen_fields = frozenlist(self.fields)
        object.__setattr__(self, "fields", frozen_fields)

    def __ge__(self, other):
        """
        Compares two records: constructor ids and field counts must match and
        every field entry of self must be >= the corresponding entry of other.
        """
        assert isinstance(other, Record), "Can only compare Records to Records"
        if self.constructor != other.constructor:
            return False
        if len(self.fields) != len(other.fields):
            return False
        # NOTE(review): the entries are (name, type) tuples, so this is a tuple comparison
        # that also involves the field names - confirm this is intended
        return all(mine >= theirs for mine, theirs in zip(self.fields, other.fields))
194

195

196
@dataclass(frozen=True, unsafe_hash=True)
class ClassType(Type):
    def __ge__(self, other):
        """
        Substitutability check: tells whether other may be used where this type is expected,
        i.e. whether the interface of this type is a subset of the interface of other.
        (Conventionally this would be spelled <= rather than >=; this is to be fixed later.)
        The relation is a partial order on types.
        Its top element is the most generic type, which can not substitute for anything.
        Its bottom element is the most specific type, which can be substituted for anything.

        A raw ClassType can not be compared, so this always raises NotImplementedError.
        """
        message = "Comparison between raw classtypes impossible"
        raise NotImplementedError(message)
208

209

210
@dataclass(frozen=True, unsafe_hash=True)
5✔
211
class AnyType(ClassType):
5✔
212
    """The top element in the partial order on types (excluding FunctionTypes, which do not compare to anything)"""
213

214
    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """Pluthon representation of this type; skip_constructor has no influence on it"""
        representation = "any"
        return representation
5✔
216

217
    def python_type(self):
        """Python representation of this type"""
        representation = "Any"
        return representation
5✔
219

220
    def attribute_type(self, attr: str) -> Type:
        """Type of the named attribute; nothing is added here, the lookup is left to the parent class"""
        parent = super()
        return parent.attribute_type(attr)
5✔
223

224
    def attribute(self, attr: str) -> plt.AST:
        """
        Implementation of the named attribute (a lambda that expects the object itself as first argument).
        Nothing is added here, the lookup is left to the parent class.
        """
        parent = super()
        return parent.attribute(attr)
×
227

228
    def __ge__(self, other):
        """Any class type can be substituted for Any, with the exception of the function types"""
        if not isinstance(other, ClassType):
            return False
        return not isinstance(other, (FunctionType, PolymorphicFunctionType))
234

235
    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        Builds the implementation of `self <op> o`.

        The result takes the object itself as first and the value compared against as second argument.
        Supported are ==/!= against records, unions and Any, and in/not in against lists
        whose element type is comparable with this type. Everything else is delegated
        to the parent class.
        """
        # this will reject comparisons that will always be false - most likely due to faults during programming
        if isinstance(o, (RecordType, UnionType, AnyType)):
            # Note that comparison with Record and UnionType is actually fine because both are Data
            if isinstance(op, ast.Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsData)
            if isinstance(op, ast.NotEq):
                data_equal = plt.Apply(
                    plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                    OVar("x"),
                    OVar("y"),
                )
                return OLambda(["x", "y"], plt.Not(data_equal))

        is_instance_list = isinstance(o, ListType) and isinstance(o.typ, InstanceType)
        if is_instance_list and (o.typ.typ >= self or self >= o.typ.typ):
            if isinstance(op, (ast.In, ast.NotIn)):
                # x is the element searched for, y the list; EqualsData is partially applied to x
                contained = plt.AnyList(
                    OVar("y"),
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                        OVar("x"),
                    ),
                )
                if isinstance(op, ast.In):
                    return OLambda(["x", "y"], contained)
                return OLambda(["x", "y"], plt.Not(contained))

        return super().cmp(op, o)
×
287

288
    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Returns a lambda that renders the raw data behind an Any value as
        `RawPlutusData(data=...)`.

        A warning is logged every time this method is called (i.e. when the code is generated),
        because no type information beyond the raw data is available for printing.
        The `recursive` parameter is not used by this implementation.
        """
        OPSHIN_LOGGER.warning(
            "Serializing AnyType will result in RawPlutusData (CBOR representation) to be printed without additional type information. Annotate types where possible to avoid this warning."
        )
        return OLambda(
            ["self"],
            OLet(
                [
                    (
                        # joinMapList(m, l, start, end): start + m(l[0]) + ", " + m(l[1]) + ... + end
                        "joinMapList",
                        OLambda(
                            ["m", "l", "start", "end"],
                            OLet(
                                [
                                    (
                                        # g renders a non-empty list: the mapped head followed
                                        # either by `end` or by ", " and the rendered tail
                                        "g",
                                        plt.RecFun(
                                            OLambda(
                                                ["f", "l"],
                                                plt.AppendString(
                                                    plt.Apply(
                                                        OVar("m"),
                                                        plt.HeadList(OVar("l")),
                                                    ),
                                                    OLet(
                                                        [
                                                            (
                                                                "t",
                                                                plt.TailList(OVar("l")),
                                                            )
                                                        ],
                                                        plt.IteNullList(
                                                            OVar("t"),
                                                            OVar("end"),
                                                            plt.AppendString(
                                                                plt.Text(", "),
                                                                plt.Apply(
                                                                    OVar("f"),
                                                                    OVar("f"),
                                                                    OVar("t"),
                                                                ),
                                                            ),
                                                        ),
                                                    ),
                                                ),
                                            )
                                        ),
                                    )
                                ],
                                # the empty list is handled here, as g always accesses the head
                                plt.AppendString(
                                    OVar("start"),
                                    plt.IteNullList(
                                        OVar("l"),
                                        OVar("end"),
                                        plt.Apply(
                                            OVar("g"),
                                            OVar("l"),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                    ),
                    (
                        # recursive function over the data `d`; `f` is the function itself
                        "stringifyPlutusData",
                        plt.RecFun(
                            OLambda(
                                ["f", "d"],
                                plt.DelayedChooseData(
                                    OVar("d"),
                                    # branch 1: handles `d` via UnConstrData (constructor id + fields)
                                    OLet(
                                        [
                                            (
                                                "constructor",
                                                plt.FstPair(
                                                    plt.UnConstrData(OVar("d"))
                                                ),
                                            )
                                        ],
                                        plt.Ite(
                                            plt.LessThanInteger(
                                                OVar("constructor"),
                                                plt.Integer(128),
                                            ),
                                            # constructor < 128: "CBORTag(<tag>, <fields>)" where the tag is
                                            # constructor + 121 below 7 and constructor + (1280 - 7) otherwise
                                            # NOTE(review): presumably mirrors the CBOR tags used for constructors - confirm
                                            plt.ConcatString(
                                                plt.Text("CBORTag("),
                                                plt.Apply(
                                                    OVar("f"),
                                                    OVar("f"),
                                                    plt.IData(
                                                        plt.AddInteger(
                                                            OVar("constructor"),
                                                            plt.Ite(
                                                                plt.LessThanInteger(
                                                                    OVar("constructor"),
                                                                    plt.Integer(7),
                                                                ),
                                                                plt.Integer(121),
                                                                plt.Integer(1280 - 7),
                                                            ),
                                                        )
                                                    ),
                                                ),
                                                plt.Text(", "),
                                                plt.Apply(
                                                    OVar("f"),
                                                    OVar("f"),
                                                    plt.ListData(
                                                        plt.SndPair(
                                                            plt.UnConstrData(OVar("d"))
                                                        )
                                                    ),
                                                ),
                                                plt.Text(")"),
                                            ),
                                            # constructor >= 128: "CBORTag(102, [<constructor>, <fields>])"
                                            plt.ConcatString(
                                                plt.Text("CBORTag(102, "),
                                                plt.Apply(
                                                    OVar("f"),
                                                    OVar("f"),
                                                    plt.ListData(
                                                        plt.MkCons(
                                                            plt.IData(
                                                                OVar("constructor")
                                                            ),
                                                            plt.MkCons(
                                                                plt.ListData(
                                                                    plt.SndPair(
                                                                        plt.UnConstrData(
                                                                            OVar("d")
                                                                        )
                                                                    )
                                                                ),
                                                                plt.EmptyDataList(),
                                                            ),
                                                        )
                                                    ),
                                                ),
                                                plt.Text(")"),
                                            ),
                                        ),
                                    ),
                                    # branch 2: handles `d` via UnMapData, rendered as "{key: value, ...}"
                                    plt.Apply(
                                        OVar("joinMapList"),
                                        OLambda(
                                            ["x"],
                                            plt.ConcatString(
                                                plt.Apply(
                                                    OVar("f"),
                                                    OVar("f"),
                                                    plt.FstPair(OVar("x")),
                                                ),
                                                plt.Text(": "),
                                                plt.Apply(
                                                    OVar("f"),
                                                    OVar("f"),
                                                    plt.SndPair(OVar("x")),
                                                ),
                                            ),
                                        ),
                                        plt.UnMapData(OVar("d")),
                                        plt.Text("{"),
                                        plt.Text("}"),
                                    ),
                                    # branch 3: handles `d` via UnListData, rendered as "[element, ...]"
                                    plt.Apply(
                                        OVar("joinMapList"),
                                        OLambda(
                                            ["x"],
                                            plt.Apply(
                                                OVar("f"),
                                                OVar("f"),
                                                OVar("x"),
                                            ),
                                        ),
                                        plt.UnListData(OVar("d")),
                                        plt.Text("["),
                                        plt.Text("]"),
                                    ),
                                    # branch 4: handles `d` via UnIData, rendered by the integer stringifier
                                    plt.Apply(
                                        IntegerInstanceType.stringify(recursive=True),
                                        plt.UnIData(OVar("d")),
                                    ),
                                    # branch 5: handles `d` via UnBData, rendered by the bytestring stringifier
                                    plt.Apply(
                                        ByteStringInstanceType.stringify(
                                            recursive=True
                                        ),
                                        plt.UnBData(OVar("d")),
                                    ),
                                ),
                            )
                        ),
                    ),
                ],
                plt.ConcatString(
                    plt.Text("RawPlutusData(data="),
                    plt.Apply(OVar("stringifyPlutusData"), OVar("self")),
                    plt.Text(")"),
                ),
            ),
        )
488

489
    def copy_only_attributes(self) -> plt.AST:
        """Every value is a valid Any, so the copy is simply the identity function"""
        identity = OLambda(["self"], OVar("self"))
        return identity
5✔
492

493

494
@dataclass(frozen=True, unsafe_hash=True)
class AtomicType(ClassType):
    def __ge__(self, other):
        """Only instances of the class of self (which includes its subclasses) can substitute for it"""
        own_class = self.__class__
        return isinstance(other, own_class)
5✔
499

500

501
@dataclass(frozen=True, unsafe_hash=True)
5✔
502
class RecordType(ClassType):
5✔
503
    record: Record
5✔
504

505
    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """
        Pluthon representation of the record: `cons[<name>](<constructor>;<field>:<type>,...)`.

        With skip_constructor set, the constructor id is replaced by `_`.
        The flag is not passed on to the field types.
        """
        record = self.record
        constructor_repr = "_" if skip_constructor else str(record.constructor)
        field_reprs = ",".join(
            f"{field_name}:{field_type.pluthon_type()}"
            for field_name, field_type in record.fields
        )
        return f"cons[{record.orig_name}]({constructor_repr};{field_reprs})"
517

518
    def python_type(self):
        """Python representation of the record: its original name and constructor id (the result ends with a space)"""
        record = self.record
        return f"{record.orig_name}(CONSTR_ID={record.constructor}) "
5✔
520

521
    def constr_type(self) -> "InstanceType":
        """The constructor is a function from the field types (in field order) to an instance of this record"""
        argument_types = frozenlist(
            [field_type for _, field_type in self.record.fields]
        )
        signature = FunctionType(argument_types, InstanceType(self))
        return InstanceType(signature)
527

528
    def constr(self) -> plt.AST:
        """
        Builds the constructor: a lambda over the field names that packs the
        arguments into ConstrData with the constructor id of the record.
        """
        record_fields = self.record.fields
        # The data list is built back to front, each field is prepended to the fields after it.
        # Every argument is forced and converted via transform_output_map of its type.
        data_params = plt.EmptyDataList()
        for field_name, field_type in reversed(record_fields):
            to_data = transform_output_map(field_type)
            data_params = plt.MkCons(
                to_data(plt.Force(OVar(field_name))),
                data_params,
            )
        parameter_names = [field_name for field_name, _ in record_fields]
        constructor_id = plt.Integer(self.record.constructor)
        return SafeOLambda(
            parameter_names,
            plt.ConstrData(constructor_id, data_params),
        )
540

541
    def attribute_type(self, attr: str) -> Type:
5✔
542
        """The types of the named attributes of this class"""
543
        if attr == "CONSTR_ID":
5✔
544
            return IntegerInstanceType
5✔
545
        for n, t in self.record.fields:
5✔
546
            if n == attr:
5✔
547
                return t
5✔
548
        if attr == "to_cbor":
5✔
549
            return InstanceType(FunctionType(frozenlist([]), ByteStringInstanceType))
5✔
550
        super().attribute_type(attr)
5✔
551

552
    def attribute(self, attr: str) -> plt.AST:
5✔
553
        """The attributes of this class. Need to be a lambda that expects as first argument the object itself"""
554
        if attr == "CONSTR_ID":
5✔
555
            # access to constructor
556
            return OLambda(
5✔
557
                ["self"],
558
                plt.Constructor(OVar("self")),
559
            )
560
        if attr in (n for n, t in self.record.fields):
5✔
561
            attr_typ = self.attribute_type(attr)
5✔
562
            pos = next(i for i, (n, _) in enumerate(self.record.fields) if n == attr)
5✔
563
            # access to normal fields
564
            return OLambda(
5✔
565
                ["self"],
566
                transform_ext_params_map(attr_typ)(
567
                    plt.ConstantNthFieldFast(
568
                        OVar("self"),
569
                        pos,
570
                    ),
571
                ),
572
            )
573
        if attr == "to_cbor":
5✔
574
            return OLambda(
5✔
575
                ["self", "_"],
576
                plt.SerialiseData(
577
                    OVar("self"),
578
                ),
579
            )
580
        raise NotImplementedError(f"Attribute {attr} not implemented for type {self}")
581

582
    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        The implementation of comparing this type to type o via operator op.
        Returns a lambda that expects as first argument the object itself and as second the comparison.

        Supported are
        - ==/!= against a record that is comparable with this one (in either direction),
          a union with at least one member comparable with this type (in either direction), or Any
        - in/not in against a list whose element type is comparable with this type
        Everything else is delegated to the parent class.
        """
        # this will reject comparisons that will always be false - most likely due to faults during programming
        if (
            (
                isinstance(o, RecordType)
                and (self.record >= o.record or o.record >= self.record)
            )
            or (
                isinstance(o, UnionType)
                # Both directions are checked, like for plain records above. The check used to
                # test `self >= member` twice, so the reverse direction was never considered.
                and any(self >= member or member >= self for member in o.typs)
            )
            or isinstance(o, AnyType)
        ):
            # Note that comparison with AnyType is actually fine because both are Data
            if isinstance(op, ast.Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsData)
            if isinstance(op, ast.NotEq):
                return OLambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                            OVar("x"),
                            OVar("y"),
                        )
                    ),
                )
        if (
            isinstance(o, ListType)
            and isinstance(o.typ, InstanceType)
            and (o.typ.typ >= self or self >= o.typ.typ)
        ):
            # x is the element searched for, y the list; EqualsData is partially applied to x
            if isinstance(op, ast.In):
                return OLambda(
                    ["x", "y"],
                    plt.AnyList(
                        OVar("y"),
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                            OVar("x"),
                        ),
                    ),
                )
            if isinstance(op, ast.NotIn):
                return OLambda(
                    ["x", "y"],
                    plt.Not(
                        plt.AnyList(
                            OVar("y"),
                            plt.Apply(
                                plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                                OVar("x"),
                            ),
                        ),
                    ),
                )
        return super().cmp(op, o)
×
639

640
    def __ge__(self, other):
        """
        Only an instance of the class of self whose record compares as <= the record of self can substitute for it.
        (If someone wants to be funny, they can implement <= to be true if all fields match up to some point.)
        """
        if not isinstance(other, self.__class__):
            return False
        return self.record >= other.record
5✔
644

645
    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Returns a lambda rendering the object as `<orig_name>(<field>=<value>, ...)`.

        The `recursive` parameter is not used here; the fields are always stringified with recursive=True.
        """
        # Assembled back to front: start at the closing parenthesis and prepend one field at a time
        rendered_fields = plt.Text(")")
        # TODO access to fields is a bit inefficient but this is debugging stuff only anyways
        indexed_fields = list(enumerate(self.record.fields))
        for index, (field_name, field_type) in reversed(indexed_fields):
            # every field except the first one is preceded by a separator
            label = f"{field_name}=" if index == 0 else f", {field_name}="
            rendered_fields = plt.ConcatString(
                plt.Text(label),
                plt.Apply(
                    field_type.stringify(recursive=True),
                    transform_ext_params_map(field_type)(
                        plt.ConstantNthFieldFast(OVar("self"), index)
                    ),
                ),
                rendered_fields,
            )
        opening = plt.Text(f"{self.record.orig_name}(")
        return OLambda(["self"], plt.AppendString(opening, rendered_fields))
677

678
    def copy_only_attributes(self) -> plt.AST:
        """
        Returns a lambda that rebuilds the object from copies of exactly the fields of this record,
        using the constructor id of the record. Data that does not take the first
        branch of the DelayedChooseData below fails with an IntegrityError trace.
        """
        # Built back to front, so the code for the first field ends up outermost.
        # Each step pops the head of "fs" into "f" and rebinds "fs" to the rest.
        remaining = plt.EmptyDataList()
        for _, field_type in reversed(self.record.fields):
            pop_field = [
                ("f", plt.HeadList(OVar("fs"))),
                ("fs", plt.TailList(OVar("fs"))),
            ]
            copied_field = plt.Apply(field_type.copy_only_attributes(), OVar("f"))
            remaining = OLet(pop_field, plt.MkCons(copied_field, remaining))
        # "fs" initially holds all fields of the object
        all_fields = OLet([("fs", plt.Fields(OVar("self")))], remaining)
        rebuilt = plt.ConstrData(plt.Integer(self.record.constructor), all_fields)
        # NOTE(review): the messages say "Expected a PlutusMap" although the accepted
        # branch builds ConstrData - confirm the wording is intended
        return OLambda(
            ["self"],
            plt.DelayedChooseData(
                OVar("self"),
                rebuilt,
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusDict"
                ),
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusList"
                ),
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusInteger"
                ),
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusByteString"
                ),
            ),
        )
720

721

722
@dataclass(frozen=True, unsafe_hash=True)
5✔
723
class UnionType(ClassType):
5✔
724
    typs: typing.List[Type]
5✔
725

726
    def __post_init__(self):
        """Replace ``typs`` with a ``frozenlist`` copy of the given members."""
        # The dataclass is frozen, so the field has to be written through
        # object.__setattr__ instead of a plain assignment.
        frozen_members = frozenlist(self.typs)
        object.__setattr__(self, "typs", frozen_members)
5✔
728

729
    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """Return ``union<...>`` listing the pluthon type of every member, comma separated."""
        member_names = [member.pluthon_type() for member in self.typs]
        return f"union<{','.join(member_names)}>"
5✔
731

732
    def python_type(self):
        """Return ``Union[...]`` listing the Python type of every member."""
        member_names = [member.python_type() for member in self.typs]
        return f"Union[{', '.join(member_names)}]"
5✔
734

735
    def attribute_type(self, attr) -> "Type":
        """
        Return the type of attribute ``attr`` when accessed on this union.

        - ``CONSTR_ID`` is an integer, but only if every member is a record.
        - A field shared by name across all (record) members has the maximum
          of the member field types if one exists, a union of them if they
          are all record instances with distinct constructors, and ``Any``
          otherwise.
        - ``to_cbor`` is a zero-argument function returning a bytestring.

        Raises:
            TypeInferenceError: for any other attribute.
        """
        record_only = all(isinstance(member, RecordType) for member in self.typs)
        if record_only:
            # constructor is only guaranteed to be present if all types are record types
            if attr == "CONSTR_ID":
                return IntegerInstanceType
            # need to have a common field with the same name
            shared_by_all = all(
                attr in (name for name, _ in member.record.fields)
                for member in self.typs
            )
            if shared_by_all:
                candidates = OrderedSet(
                    field_type
                    for member in self.typs
                    for name, field_type in member.record.fields
                    if name == attr
                )
                # return the maximum element if there is one
                for candidate in candidates:
                    if all(candidate >= rival for rival in candidates):
                        return candidate
                # return the union type of all possible instantiations if all possible values are record types
                all_record_instances = all(
                    isinstance(candidate, InstanceType)
                    and isinstance(candidate.typ, RecordType)
                    for candidate in candidates
                )
                if all_record_instances and distinct(
                    [candidate.typ.record.constructor for candidate in candidates]
                ):
                    merged = UnionType(
                        frozenlist([candidate.typ for candidate in candidates])
                    )
                    return InstanceType(merged)
                # nothing more specific can be said
                return InstanceType(AnyType())
        if attr == "to_cbor":
            return InstanceType(FunctionType(frozenlist([]), ByteStringInstanceType))
        raise TypeInferenceError(
            f"Can not access attribute {attr} of Union type. Cast to desired type with an 'if isinstance(_, _):' branch."
        )
766

767
    def attribute(self, attr: str) -> plt.AST:
        """
        Build the lambda implementing access to attribute ``attr`` on a union value.

        - ``CONSTR_ID`` reads the constructor id of ``self``.
        - A record field is read by position. When the field sits at different
          positions in different members, the constructor id of ``self``
          selects the position.
        - ``to_cbor`` serialises ``self``; the lambda takes one extra, ignored
          argument.

        Raises:
            NotImplementedError: if ``attr`` is none of the above.
        """
        if attr == "CONSTR_ID":
            # access to constructor
            return OLambda(
                ["self"],
                plt.Constructor(OVar("self")),
            )
        # Only record members carry named fields. Non-record members must be
        # skipped here: reading ``.record`` on them raises AttributeError and
        # would make the ``to_cbor`` branch below unreachable for such unions.
        record_typs = [t for t in self.typs if isinstance(t, RecordType)]
        # iterate through all names/types of the unioned records by position
        if any(attr in (n for n, _ in r.record.fields) for r in record_typs):
            attr_typ = self.attribute_type(attr)
            pos_constrs = [
                (i, r.record.constructor)
                for r in record_typs
                for i, (n, _) in enumerate(r.record.fields)
                if n == attr
            ]
            # groupby only merges adjacent items, hence the sort by position first
            pos_constrs = sorted(pos_constrs, key=lambda x: x[0])
            pos_constrs = [
                (pos, [c[1] for c in constrs])
                for (pos, constrs) in itertools.groupby(pos_constrs, key=lambda x: x[0])
            ]
            # largest group last so we save the comparisons for that
            pos_constrs = sorted(pos_constrs, key=lambda x: len(x[1]))
            # access to normal fields
            if not pos_constrs:
                # defensive only: the any() above guarantees at least one entry
                pos_decisor = plt.TraceError("Invalid constructor")
            else:
                # the largest group becomes the default (final else) branch
                pos_decisor = plt.ConstantNthFieldFast(OVar("self"), pos_constrs[-1][0])
                pos_constrs = pos_constrs[:-1]
            # constr is not needed when there is only one position for all constructors
            if not pos_constrs:
                return OLambda(
                    ["self"],
                    transform_ext_params_map(attr_typ)(
                        pos_decisor,
                    ),
                )
            for pos, constrs in pos_constrs:
                assert constrs, "Found empty constructors for a position"
                constr_check = plt.EqualsInteger(
                    OVar("constr"), plt.Integer(constrs[0])
                )
                for constr in constrs[1:]:
                    constr_check = plt.Or(
                        plt.EqualsInteger(OVar("constr"), plt.Integer(constr)),
                        constr_check,
                    )
                pos_decisor = plt.Ite(
                    constr_check,
                    plt.ConstantNthFieldFast(OVar("self"), pos),
                    pos_decisor,
                )
            return OLambda(
                ["self"],
                transform_ext_params_map(attr_typ)(
                    OLet(
                        [("constr", plt.Constructor(OVar("self")))],
                        pos_decisor,
                    ),
                ),
            )
        if attr == "to_cbor":
            return OLambda(
                ["self", "_"],
                plt.SerialiseData(
                    OVar("self"),
                ),
            )
        raise NotImplementedError(f"Attribute {attr} not implemented for type {self}")
836

837
    def __ge__(self, other):
        """
        Return whether ``other`` can be substituted for this union.

        Another union qualifies when each of its members does; any other type
        qualifies when at least one member of this union accepts it.
        """
        if not isinstance(other, UnionType):
            return any(member >= other for member in self.typs)
        return all(self >= other_member for other_member in other.typs)
5✔
841

842
    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """The implementation of comparing this type to type o via operator op. Returns a lambda that expects as first argument the object itself and as second the comparison.

        Supported operations:
        - ``==`` / ``!=`` against a record or another union, provided at least
          one member of this union is related (in either direction) to the
          record, respectively to one member of the other union.
        - ``in`` / ``not in`` against a list whose element type is related to
          at least one member of this union.

        Raises:
            NotImplementedError: for every other combination of ``op`` and ``o``.
        """
        # this will reject comparisons that will always be false - most likely due to faults during programming
        # note we require that there is an overlap between the possible types for unions
        if (isinstance(o, RecordType) and any(t >= o or o >= t for t in self.typs)) or (
            isinstance(o, UnionType)
            # the relation has to be checked in both directions, exactly as
            # for the record case above
            and any(t >= ot or ot >= t for t in self.typs for ot in o.typs)
        ):
            if isinstance(op, ast.Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsData)
            if isinstance(op, ast.NotEq):
                return OLambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                            OVar("x"),
                            OVar("y"),
                        )
                    ),
                )
        if (
            isinstance(o, ListType)
            and isinstance(o.typ, InstanceType)
            and any(o.typ.typ >= t or t >= o.typ.typ for t in self.typs)
        ):
            if isinstance(op, ast.In):
                # x is the element, y the list that is searched
                return OLambda(
                    ["x", "y"],
                    plt.AnyList(
                        OVar("y"),
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                            OVar("x"),
                        ),
                    ),
                )
            if isinstance(op, ast.NotIn):
                return OLambda(
                    ["x", "y"],
                    plt.Not(
                        plt.AnyList(
                            OVar("y"),
                            plt.Apply(
                                plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                                OVar("x"),
                            ),
                        ),
                    ),
                )
        raise NotImplementedError(
            f"Can not compare {o.python_type()} and {self.python_type()} with operation {op.__class__.__name__}. Note that comparisons that always return false are also rejected."
        )
895

896
    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Build the lambda that renders a union value as a string.

        Record members are told apart by their constructor id. If the union
        also has non-record members, the shape of ``self`` is inspected first
        and maps, lists, integers and bytestrings are rendered with the
        respective generic stringifier.
        """
        record_members = [t for t in self.typs if isinstance(t, RecordType)]
        contains_non_record = len(record_members) != len(self.typs)

        # Chain of constructor-id checks; an unknown id ends in the error.
        pick_by_constr = plt.TraceError("Invalid constructor id in Union")
        for record_member in record_members:
            pick_by_constr = plt.Ite(
                plt.EqualsInteger(
                    OVar("constr"), plt.Integer(record_member.record.constructor)
                ),
                record_member.stringify(recursive=True),
                pick_by_constr,
            )
        rendered = OLet(
            [("constr", plt.Constructor(OVar("self")))],
            plt.Apply(pick_by_constr, OVar("self")),
        )
        if contains_non_record:
            any_instance_dict = DictType(
                InstanceType(AnyType()), InstanceType(AnyType())
            )
            rendered = plt.DelayedChooseData(
                OVar("self"),
                rendered,
                plt.Apply(
                    any_instance_dict.stringify(recursive=True),
                    plt.UnMapData(OVar("self")),
                ),
                plt.Apply(
                    ListType(InstanceType(AnyType())).stringify(recursive=True),
                    plt.UnListData(OVar("self")),
                ),
                plt.Apply(
                    IntegerType().stringify(recursive=True),
                    plt.UnIData(OVar("self")),
                ),
                plt.Apply(
                    ByteStringType().stringify(recursive=True),
                    plt.UnBData(OVar("self")),
                ),
            )
        return OLambda(["self"], rendered)
938

939
    def copy_only_attributes(self) -> plt.AST:
        """
        Build the lambda that copies a union value, keeping only declared attributes.

        Constructor values are dispatched on their constructor id to the
        matching record member. Maps, lists, integers and bytestrings are
        copied only if the union has a member of that kind; otherwise the
        value is rejected with an error trace.
        """
        # Chain of constructor-id checks over the record members.
        by_constr = plt.TraceError("Invalid CONSTR_ID (no matching type in Union)")
        for member in self.typs:
            if isinstance(member, RecordType):
                by_constr = plt.Ite(
                    plt.EqualsInteger(
                        OVar("constr"), plt.Integer(member.record.constructor)
                    ),
                    plt.Apply(member.copy_only_attributes(), OVar("self")),
                    by_constr,
                )
        record_copier = OLambda(
            ["self"],
            OLet(
                [("constr", plt.Constructor(OVar("self")))],
                by_constr,
            ),
        )

        def reject(kind_name: str):
            # copier stand-in for a kind of data that is not part of the union
            return OLambda(
                ["_"],
                plt.TraceError("Invalid datatype not in Union, got " + kind_name),
            )

        if any(isinstance(member, DictType) for member in self.typs):
            dict_copier = DictType(AnyType(), AnyType()).copy_only_attributes()
        else:
            dict_copier = reject("dict")

        if any(isinstance(member, ListType) for member in self.typs):
            list_copier = ListType(AnyType()).copy_only_attributes()
        else:
            list_copier = reject("list")

        if IntegerType() in self.typs:
            int_copier = IntegerType().copy_only_attributes()
        else:
            int_copier = reject("int")

        if ByteStringType() in self.typs:
            bytes_copier = ByteStringType().copy_only_attributes()
        else:
            bytes_copier = reject("bytes")

        # The shape of ``self`` selects a copier, which is then applied to ``self``.
        chosen_copier = plt.DelayedChooseData(
            OVar("self"),
            record_copier,
            dict_copier,
            list_copier,
            int_copier,
            bytes_copier,
        )
        return OLambda(
            ["self"],
            plt.Apply(chosen_copier, OVar("self")),
        )
995

996

997
@dataclass(frozen=True, unsafe_hash=True)
class TupleType(ClassType):
    """A tuple type with one element type per position."""

    # element types, in positional order
    typs: typing.List[Type]

    def __ge__(self, other):
        """
        Return whether ``other`` can be substituted for this tuple type.

        ``other`` must be a tuple type with at least as many elements, and
        every element type of ``self`` must accept the element type at the
        same position in ``other``.
        """
        if not isinstance(other, TupleType):
            return False
        if len(self.typs) > len(other.typs):
            return False
        return all(mine >= theirs for mine, theirs in zip(self.typs, other.typs))

    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Build the lambda rendering a tuple as ``(a, b, ...)``.

        The empty tuple renders as ``()`` and a one-element tuple gets a
        trailing comma.
        """
        if not self.typs:
            return OLambda(
                ["self"],
                plt.Text("()"),
            )

        size = len(self.typs)

        def render_element(position, element_type):
            # stringifier of the element type applied to the element at ``position``
            return plt.Apply(
                element_type.stringify(recursive=True),
                plt.FunctionalTupleAccess(OVar("self"), position, size),
            )

        if size == 1:
            tuple_content = plt.ConcatString(
                render_element(0, self.typs[0]),
                plt.Text(","),
            )
        else:
            tuple_content = plt.ConcatString(render_element(0, self.typs[0]))
            for position, element_type in enumerate(self.typs[1:], start=1):
                tuple_content = plt.ConcatString(
                    tuple_content,
                    plt.Text(", "),
                    render_element(position, element_type),
                )
        wrapped = plt.ConcatString(plt.Text("("), tuple_content, plt.Text(")"))
        return OLambda(["self"], wrapped)

    def _binop_return_type(self, binop: ast.operator, other: "Type") -> "Type":
        """Adding two tuple types yields the tuple type of the concatenated element types; everything else is delegated to the parent class."""
        if isinstance(binop, ast.Add) and isinstance(other, TupleType):
            return TupleType(self.typs + other.typs)
        return super()._binop_return_type(binop, other)

    def python_type(self) -> str:
        """Return ``Tuple[...]`` listing the Python type of every element."""
        element_names = ", ".join(t.python_type() for t in self.typs)
        return f"Tuple[{element_names}]"
5✔
1051

1052

1053
@dataclass(frozen=True, unsafe_hash=True)
class PairType(ClassType):
    """An internal type representing built-in PlutusData pairs"""

    # type of the first component
    l_typ: Type
    # type of the second component
    r_typ: Type

    def __ge__(self, other):
        """Return whether ``other`` is a pair type whose two component types are both accepted by the matching component type of ``self``."""
        if not isinstance(other, PairType):
            return False
        own_components = (self.l_typ, self.r_typ)
        other_components = (other.l_typ, other.r_typ)
        return all(
            mine >= theirs for mine, theirs in zip(own_components, other_components)
        )

    def stringify(self, recursive: bool = False) -> plt.AST:
        """Build the lambda rendering a pair as ``(left, right)``."""
        left_rendered = plt.Apply(
            self.l_typ.stringify(recursive=True),
            transform_ext_params_map(self.l_typ)(plt.FstPair(OVar("self"))),
        )
        separator = plt.Text(", ")
        right_rendered = plt.Apply(
            self.r_typ.stringify(recursive=True),
            transform_ext_params_map(self.r_typ)(plt.SndPair(OVar("self"))),
        )
        tuple_content = plt.ConcatString(left_rendered, separator, right_rendered)
        wrapped = plt.ConcatString(plt.Text("("), tuple_content, plt.Text(")"))
        return OLambda(["self"], wrapped)

    def python_type(self):
        """Return the Python spelling of the pair, a two-element ``Tuple[...]``."""
        left_name = self.l_typ.python_type()
        right_name = self.r_typ.python_type()
        return f"Tuple[{left_name}, {right_name}]"
5✔
1085

1086

1087
@dataclass(frozen=True, unsafe_hash=True)
5✔
1088
class ListType(ClassType):
5✔
1089
    typ: Type
5✔
1090

1091
    def __ge__(self, other):
        """Return whether ``other`` is a list type whose element type is accepted by this list's element type."""
        if not isinstance(other, ListType):
            return False
        return self.typ >= other.typ
5✔
1093

1094
    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """Return ``list<...>`` wrapping the pluthon type of the elements."""
        element_name = self.typ.pluthon_type()
        return f"list<{element_name}>"
5✔
1096

1097
    def python_type(self):
        """Return ``List[...]`` wrapping the Python type of the elements."""
        element_name = self.typ.python_type()
        return "List[" + element_name + "]"
5✔
1099

1100
    def attribute_type(self, attr) -> "Type":
        """
        Return the type of attribute ``attr`` of a list.

        ``index`` is a function taking one element and returning an integer.
        Every other attribute is resolved by the parent class.
        """
        if attr == "index":
            return InstanceType(
                FunctionType(frozenlist([self.typ]), IntegerInstanceType)
            )
        # hand the parent's result back instead of silently returning None
        return super().attribute_type(attr)
×
1106

1107
    def attribute(self, attr) -> plt.AST:
        """
        Build the lambda implementing attribute ``attr`` of a list.

        ``index`` walks the list from the front and returns the position of
        the first element equal to the argument; if the end of the list is
        reached, evaluation aborts with a ``ValueError`` trace. Every other
        attribute is resolved by the parent class.
        """
        if attr == "index":
            return OLambda(
                ["self", "x"],
                OLet(
                    # the argument is forced once, before the search starts
                    [("x", plt.Force(OVar("x")))],
                    plt.Apply(
                        plt.RecFun(
                            OLambda(
                                # "xs" is the remaining list, "a" the current position
                                ["index", "xs", "a"],
                                plt.IteNullList(
                                    OVar("xs"),
                                    plt.TraceError(
                                        "ValueError: Did not find element in list"
                                    ),
                                    plt.Ite(
                                        # the parameter x must have the same type as the list elements
                                        plt.Apply(
                                            self.typ.cmp(ast.Eq(), self.typ),
                                            OVar("x"),
                                            plt.HeadList(OVar("xs")),
                                        ),
                                        OVar("a"),
                                        plt.Apply(
                                            OVar("index"),
                                            OVar("index"),
                                            plt.TailList(OVar("xs")),
                                            plt.AddInteger(OVar("a"), plt.Integer(1)),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                        OVar("self"),
                        plt.Integer(0),
                    ),
                ),
            )
        # hand the parent's result back instead of silently returning None
        return super().attribute(attr)
×
1146

1147
    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Build the lambda rendering a list as ``[a, b, ...]``.

        The empty list renders as ``[]``. For a non-empty list a recursive
        helper renders the head element and then either the closing bracket
        or a ``", "`` separator followed by the rest of the list.
        """
        head_rendered = plt.Apply(
            self.typ.stringify(recursive=True),
            plt.HeadList(OVar("l")),
        )
        tail_binding = [("t", plt.TailList(OVar("l")))]
        close_or_continue = plt.IteNullList(
            OVar("t"),
            plt.Text("]"),
            plt.AppendString(
                plt.Text(", "),
                plt.Apply(OVar("f"), OVar("f"), OVar("t")),
            ),
        )
        # "g" renders a non-empty list including the closing bracket
        render_non_empty = plt.RecFun(
            OLambda(
                ["f", "l"],
                plt.AppendString(
                    head_rendered,
                    OLet(tail_binding, close_or_continue),
                ),
            )
        )
        whole_list = plt.AppendString(
            plt.Text("["),
            plt.IteNullList(
                OVar("self"),
                plt.Text("]"),
                plt.Apply(OVar("g"), OVar("self")),
            ),
        )
        return OLambda(
            ["self"],
            OLet([("g", render_non_empty)], whole_list),
        )
1195

1196
    def copy_only_attributes(self) -> plt.AST:
        """
        Build the lambda that copies a list, keeping only declared attributes.

        Every element is passed through the element type's own
        ``copy_only_attributes``. If ``self`` is not a list, evaluation aborts
        with an ``IntegrityError`` trace naming the shape that was found.
        """
        copy_element = OLambda(
            ["v"],
            plt.Apply(self.typ.copy_only_attributes(), OVar("v")),
        )
        mapped_attrs = plt.MapList(
            plt.UnListData(OVar("self")),
            copy_element,
            plt.EmptyDataList(),
        )

        def wrong_shape(found: str):
            # error branch for data that is not a list
            return plt.TraceError(
                "IntegrityError: Expected a PlutusList, but got " + found
            )

        # only the third branch (list) succeeds
        by_shape = plt.DelayedChooseData(
            OVar("self"),
            wrong_shape("PlutusData"),
            wrong_shape("PlutusMap"),
            plt.ListData(mapped_attrs),
            wrong_shape("PlutusInteger"),
            wrong_shape("PlutusByteString"),
        )
        return OLambda(["self"], by_shape)
1227

1228
    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """The implementation of comparing this type to type o via operator op. Returns a lambda that expects as first argument the object itself and as second the comparison.

        ``==`` and ``!=`` are implemented here for two lists whose element
        types are related in either direction; everything else is delegated
        to the parent class.
        """
        comparable_lists = isinstance(o, ListType) and (
            self.typ >= o.typ or o.typ >= self.typ
        )
        if not comparable_lists:
            return super().cmp(op, o)

        if isinstance(op, ast.Eq):
            # Element-wise equality; linear in the length of the lists.
            heads_equal = plt.Apply(
                self.typ.cmp(op, o.typ),
                plt.HeadList(OVar("xs")),
                plt.HeadList(OVar("ys")),
            )
            tails_equal = plt.Apply(
                OVar("f"),
                OVar("f"),
                plt.TailList(OVar("xs")),
                plt.TailList(OVar("ys")),
            )
            walk_both = OLambda(
                ["f", "xs", "ys"],
                plt.IteNullList(
                    OVar("xs"),
                    # first list exhausted: equal iff the second is exhausted too
                    plt.NullList(OVar("ys")),
                    plt.IteNullList(
                        OVar("ys"),
                        # second list exhausted while the first is not
                        plt.Bool(False),
                        # both non-empty: compare heads, then recurse on tails
                        plt.And(heads_equal, tails_equal),
                    ),
                ),
            )
            return OLambda(
                ["x", "y"],
                plt.Apply(plt.RecFun(walk_both), OVar("x"), OVar("y")),
            )

        if isinstance(op, ast.NotEq):
            # Inequality is the negation of the equality comparison.
            lists_equal = plt.Apply(
                self.cmp(ast.Eq(), o),
                OVar("x"),
                OVar("y"),
            )
            return OLambda(["x", "y"], plt.Not(lists_equal))

        return super().cmp(op, o)
×
1283

1284
    def _binop_return_type(self, binop: ast.operator, other: "Type") -> "Type":
        """
        Return the result type of ``self <binop> other``.

        Adding a list instance yields a list of whichever element type accepts
        the other one; incompatible element types fail the assertion.
        Everything else is delegated to the parent class.
        """
        adds_list = (
            isinstance(binop, ast.Add)
            and isinstance(other, InstanceType)
            and isinstance(other.typ, ListType)
        )
        if not adds_list:
            return super()._binop_return_type(binop, other)
        other_typ = other.typ
        self_accepts_other = self.typ >= other_typ.typ
        assert (
            self_accepts_other or other_typ.typ >= self.typ
        ), f"Types of lists {self.typ} and {other_typ.typ} are not compatible"
        element_type = self.typ if self_accepts_other else other_typ.typ
        return ListType(element_type)
×
1295

1296
    def _binop_bin_fun(self, binop: ast.operator, other: "TypedAST"):
        """Return ``plt.AppendList`` for adding a list instance; delegate every other operation to the parent class."""
        other_type = other.typ if isinstance(binop, ast.Add) else None
        if isinstance(other_type, InstanceType) and isinstance(
            other_type.typ, ListType
        ):
            return plt.AppendList
        return super()._binop_bin_fun(binop, other)
×
1303

1304
    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        """``not`` on a list yields a bool; other unary operators are delegated to the parent class."""
        if not isinstance(unop, ast.Not):
            return super()._unop_return_type(unop)
        return BoolType()
×
1308

1309
    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        """Return the AST builder for a unary operator: ``not`` is true exactly for the empty list; other operators are delegated to the parent class."""
        if not isinstance(unop, ast.Not):
            return super()._unop_fun(unop)

        def list_is_empty(x):
            return plt.IteNullList(x, plt.Bool(True), plt.Bool(False))

        return list_is_empty
×
1313

1314

1315
@dataclass(frozen=True, unsafe_hash=True)
class DictType(ClassType):
    """
    Class type of dictionaries with keys of type key_typ and values of type value_typ.

    The generated code treats a dictionary as a list of key/value pairs
    (see the FstPair/SndPair/HeadList/TailList usage below).
    """

    # Type of the dictionary keys
    key_typ: Type
    # Type of the dictionary values
    value_typ: Type

    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """Return the pluthon type name, ``map<KEY,VALUE>``. skip_constructor is not used here."""
        return f"map<{self.key_typ.pluthon_type()},{self.value_typ.pluthon_type()}>"

    def python_type(self):
        """Return the Python annotation string, ``Dict[KEY, VALUE]``."""
        return f"Dict[{self.key_typ.python_type()}, {self.value_typ.python_type()}]"

    def attribute_type(self, attr) -> "Type":
        """
        Return the type of attribute attr of a dictionary.

        Supported attributes are the methods ``get(key, default)``, ``keys()``,
        ``values()`` and ``items()``.
        Raises TypeInferenceError for any other attribute.
        """
        if attr == "get":
            signature = FunctionType(
                frozenlist([self.key_typ, self.value_typ]), self.value_typ
            )
        elif attr == "keys":
            signature = FunctionType(
                frozenlist([]), InstanceType(ListType(self.key_typ))
            )
        elif attr == "values":
            signature = FunctionType(
                frozenlist([]), InstanceType(ListType(self.value_typ))
            )
        elif attr == "items":
            pair_typ = InstanceType(PairType(self.key_typ, self.value_typ))
            signature = FunctionType(
                frozenlist([]), InstanceType(ListType(pair_typ))
            )
        else:
            raise TypeInferenceError(
                f"Type of attribute '{attr}' is unknown for type Dict."
            )
        return InstanceType(signature)

    def attribute(self, attr) -> plt.AST:
        """
        Return the implementation of attribute attr of a dictionary.

        Each implementation is a lambda that takes the dictionary as ``self``
        followed by the method arguments (a single ignored ``_`` for the
        argument-less methods).
        Raises NotImplementedError for unknown attributes.
        """
        if attr == "get":
            key_matches = OLambda(
                ["x"],
                plt.EqualsData(
                    OVar("key_mapped"),
                    plt.FstPair(OVar("x")),
                ),
            )
            # this is a bit ugly... we wrap - only to later unwrap again
            default_pair = plt.MkPairData(
                OVar("key_mapped"),
                transform_output_map(self.value_typ)(plt.Force(OVar("default"))),
            )
            lookup = OLet(
                [
                    (
                        "key_mapped",
                        transform_output_map(self.key_typ)(plt.Force(OVar("key"))),
                    )
                ],
                plt.SndPair(plt.FindList(OVar("self"), key_matches, default_pair)),
            )
            return OLambda(
                ["self", "key", "default"],
                transform_ext_params_map(self.value_typ)(lookup),
            )
        if attr == "keys":
            take_key = OLambda(
                ["x"],
                transform_ext_params_map(self.key_typ)(plt.FstPair(OVar("x"))),
            )
            return OLambda(
                ["self", "_"],
                plt.MapList(OVar("self"), take_key, empty_list(self.key_typ)),
            )
        if attr == "values":
            take_value = OLambda(
                ["x"],
                transform_ext_params_map(self.value_typ)(plt.SndPair(OVar("x"))),
            )
            return OLambda(
                ["self", "_"],
                plt.MapList(OVar("self"), take_value, empty_list(self.value_typ)),
            )
        if attr == "items":
            # the dictionary itself is returned unchanged
            return OLambda(
                ["self", "_"],
                OVar("self"),
            )
        raise NotImplementedError(f"Attribute '{attr}' of Dict is unknown.")

    def __ge__(self, other):
        """Return whether other is a DictType whose key and value types are both covered by ours."""
        if not isinstance(other, DictType):
            return False
        if not self.key_typ >= other.key_typ:
            return False
        return self.value_typ >= other.value_typ

    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Return a lambda rendering a dictionary as ``{key: value, ...}``.

        Keys and values are always rendered with ``recursive=True``;
        the recursive argument itself is not used here.
        """
        entry = OVar("h")
        key_text = plt.Apply(
            self.key_typ.stringify(recursive=True),
            transform_ext_params_map(self.key_typ)(plt.FstPair(entry)),
        )
        value_text = plt.Apply(
            self.value_typ.stringify(recursive=True),
            transform_ext_params_map(self.value_typ)(plt.SndPair(entry)),
        )
        # close the brace after the last entry, otherwise separate and recurse
        remainder_text = plt.IteNullList(
            OVar("t"),
            plt.Text("}"),
            plt.AppendString(
                plt.Text(", "),
                plt.Apply(
                    OVar("f"),
                    OVar("f"),
                    OVar("t"),
                ),
            ),
        )
        render_entries = plt.RecFun(
            OLambda(
                ["f", "l"],
                OLet(
                    [
                        ("h", plt.HeadList(OVar("l"))),
                        ("t", plt.TailList(OVar("l"))),
                    ],
                    plt.ConcatString(
                        key_text,
                        plt.Text(": "),
                        value_text,
                        remainder_text,
                    ),
                ),
            )
        )
        return OLambda(
            ["self"],
            OLet(
                [("g", render_entries)],
                plt.AppendString(
                    plt.Text("{"),
                    # the empty dictionary never enters the recursive renderer
                    plt.IteNullList(
                        OVar("self"),
                        plt.Text("}"),
                        plt.Apply(
                            OVar("g"),
                            OVar("self"),
                        ),
                    ),
                ),
            ),
        )

    def copy_only_attributes(self) -> plt.AST:
        """
        Return a lambda that checks that its argument is a PlutusMap and rebuilds it.

        Every kept entry is rebuilt from the results of the key and value types'
        own ``copy_only_attributes``. An entry is kept only if no later entry
        has an equal key. Any other kind of PlutusData results in a traced
        ``IntegrityError``.
        """

        def CustomMapFilterList(
            l: plt.AST,
            filter_op: plt.AST,
            map_op: plt.AST,
            empty_list=plt.EmptyDataList(),
        ):
            """
            Apply a filter and a map function on each element in a list (throws out all that evaluate to false)
            Performs only a single pass and is hence much more efficient than filter + map

            filter_op is applied to the current head and the remaining tail,
            map_op to the head only. empty_list is the result for an empty input.
            """
            from pluthon import (
                Apply,
                Lambda as PLambda,
                RecFun,
                IteNullList,
                Var as PVar,
                HeadList,
                Ite,
                TailList,
                PrependList,
                Let as PLet,
            )

            return Apply(
                PLambda(
                    ["filter", "map"],
                    RecFun(
                        PLambda(
                            ["filtermap", "xs"],
                            IteNullList(
                                PVar("xs"),
                                empty_list,
                                PLet(
                                    [
                                        ("head", HeadList(PVar("xs"))),
                                        ("tail", TailList(PVar("xs"))),
                                    ],
                                    Ite(
                                        Apply(
                                            PVar("filter"), PVar("head"), PVar("tail")
                                        ),
                                        PrependList(
                                            Apply(PVar("map"), PVar("head")),
                                            Apply(
                                                PVar("filtermap"),
                                                PVar("filtermap"),
                                                PVar("tail"),
                                            ),
                                        ),
                                        Apply(
                                            PVar("filtermap"),
                                            PVar("filtermap"),
                                            PVar("tail"),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                    ),
                ),
                filter_op,
                map_op,
                l,
            )

        # keep an entry only if its key does not occur again in the rest of the list
        key_not_repeated = OLambda(
            ["h", "t"],
            OLet(
                [
                    ("hfst", plt.FstPair(OVar("h"))),
                ],
                plt.Not(
                    plt.AnyList(
                        OVar("t"),
                        OLambda(
                            ["e"],
                            plt.EqualsData(OVar("hfst"), plt.FstPair(OVar("e"))),
                        ),
                    )
                ),
            ),
        )
        # rebuild the pair from the copied key and the copied value
        copy_entry = OLambda(
            ["v"],
            plt.MkPairData(
                plt.Apply(self.key_typ.copy_only_attributes(), plt.FstPair(OVar("v"))),
                plt.Apply(
                    self.value_typ.copy_only_attributes(), plt.SndPair(OVar("v"))
                ),
            ),
        )
        mapped_attrs = CustomMapFilterList(
            plt.UnMapData(OVar("self")),
            key_not_repeated,
            copy_entry,
            plt.EmptyDataPairList(),
        )
        return OLambda(
            ["self"],
            plt.DelayedChooseData(
                OVar("self"),
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusData"
                ),
                plt.MapData(mapped_attrs),
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusList"
                ),
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusInteger"
                ),
                plt.TraceError(
                    "IntegrityError: Expected a PlutusMap, but got PlutusByteString"
                ),
            ),
        )

    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        """Return the result type of unop: bool for ``not``, otherwise defer to the parent class."""
        if isinstance(unop, ast.Not):
            return BoolType()
        return super()._unop_return_type(unop)

    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        """
        Return the implementation of unop.

        ``not`` builds an ``IteNullList`` selecting ``True`` in its first branch
        and ``False`` in its second; other operators are deferred to the parent class.
        """
        if isinstance(unop, ast.Not):
            return lambda x: plt.IteNullList(x, plt.Bool(True), plt.Bool(False))
        return super()._unop_fun(unop)
×
1622

1623

1624
@dataclass(frozen=True, unsafe_hash=True)
class FunctionType(ClassType):
    """Class type of functions, described by argument types, return type and bound variables."""

    # Types of the arguments, in order (stored as a frozenlist after init)
    argtyps: typing.List[Type]
    # Type of the returned value
    rettyp: Type
    # A map from external variable names to their types when the function is defined
    bound_vars: typing.Dict[str, Type] = field(default_factory=frozendict)
    # Whether and under which name the function binds itself
    # The type of this variable is "self"
    bind_self: typing.Optional[str] = None

    def __post_init__(self):
        """Replace argtyps and bound_vars by frozen copies (frozenlist / frozendict)."""
        # the dataclass is frozen, hence plain attribute assignment is not possible
        frozen_args = frozenlist(self.argtyps)
        frozen_bound = frozendict(self.bound_vars)
        object.__setattr__(self, "argtyps", frozen_args)
        object.__setattr__(self, "bound_vars", frozen_bound)

    def __ge__(self, other):
        """
        Return whether this function type covers other.

        Requires other to be a FunctionType with the same number of arguments,
        the same bound variable names and the same bind_self; argument and
        bound variable types are compared as ``ours >= theirs``, the return
        type as ``theirs >= ours``.
        """
        if not isinstance(other, FunctionType):
            return False
        if len(self.argtyps) != len(other.argtyps):
            return False
        for own_arg, other_arg in zip(self.argtyps, other.argtyps):
            if not own_arg >= other_arg:
                return False
        if self.bound_vars.keys() != other.bound_vars.keys():
            return False
        for name, own_bound in self.bound_vars.items():
            if not own_bound >= other.bound_vars[name]:
                return False
        if self.bind_self != other.bind_self:
            return False
        return other.rettyp >= self.rettyp

    def stringify(self, recursive: bool = False) -> plt.AST:
        """Return a lambda that ignores its argument and yields the text ``<function>``."""
        placeholder = plt.Text("<function>")
        return OLambda(["x"], placeholder)

    def python_type(self):
        """Return the Python annotation string, ``Callable[[ARGS], RET]``."""
        rendered_args = [arg.python_type() for arg in self.argtyps]
        return f"Callable[[{', '.join(rendered_args)}], {self.rettyp.python_type()}]"
5✔
1655

1656

1657
@dataclass(frozen=True, unsafe_hash=True)
class InstanceType(Type):
    """
    Type of an instance of the class type typ.

    Almost every operation is forwarded to the wrapped class type;
    only construction is rejected.
    """

    # The class type this is an instance of
    typ: ClassType

    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """Return the pluthon type name of the wrapped class type."""
        wrapped = self.typ
        return wrapped.pluthon_type(skip_constructor=skip_constructor)

    def constr_type(self) -> FunctionType:
        """Always raises TypeInferenceError, instances can not be constructed."""
        raise TypeInferenceError(f"Can not construct an instance {self}")

    def constr(self) -> plt.AST:
        """Always raises NotImplementedError, instances can not be constructed."""
        raise NotImplementedError(f"Can not construct an instance {self}")

    def attribute_type(self, attr) -> Type:
        """Return the type of attribute attr as defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.attribute_type(attr)

    def attribute(self, attr) -> plt.AST:
        """Return the implementation of attribute attr as defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.attribute(attr)

    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        Return the implementation of comparing this type to type o via operator op.

        If o is an instance type as well, the comparison of the two wrapped
        class types is used; otherwise the parent implementation decides.
        """
        if not isinstance(o, InstanceType):
            return super().cmp(op, o)
        return self.typ.cmp(op, o.typ)

    def __ge__(self, other):
        """Return whether other is an instance type whose class type is covered by ours."""
        if not isinstance(other, InstanceType):
            return False
        return self.typ >= other.typ

    def stringify(self, recursive: bool = False) -> plt.AST:
        """Return the string conversion defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.stringify(recursive=recursive)

    def copy_only_attributes(self) -> plt.AST:
        """Return the copy/check implementation defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.copy_only_attributes()

    def binop_type(self, binop: ast.operator, other: "Type") -> "Type":
        """Return the result type of binop with other as defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.binop_type(binop, other)

    def binop(self, binop: ast.operator, other: "TypedAST") -> plt.AST:
        """Return the implementation of binop with other as defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.binop(binop, other)

    def unop_type(self, unop: ast.unaryop) -> "Type":
        """Return the result type of unop as defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.unop_type(unop)

    def unop(self, unop: ast.unaryop) -> plt.AST:
        """Return the implementation of unop as defined by the wrapped class type."""
        wrapped = self.typ
        return wrapped.unop(unop)

    def python_type(self):
        """Return the Python annotation string of the wrapped class type."""
        wrapped = self.typ
        return wrapped.python_type()
5✔
1705

1706

1707
@dataclass(frozen=True, unsafe_hash=True)
class IntegerType(AtomicType):
    """Class type of integers: comparison, arithmetic, unary operators and string conversion."""

    def pluthon_type(self, skip_constructor: bool = False) -> str:
        """Return the pluthon type name of integers. skip_constructor is not used here."""
        return "int"

    def python_type(self):
        """Return the Python annotation string of integers."""
        return "int"

    def constr_type(self) -> InstanceType:
        """Return the type of the ``int(...)`` constructor, a polymorphic function backed by IntImpl."""
        return InstanceType(PolymorphicFunctionType(IntImpl()))

    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        The implementation of comparing this type to type o via operator op.
        Returns a lambda that expects as first argument the object itself and as second the comparison.

        Handled here: ``==`` against bool, all six ordering/equality operators
        against integers and ``in`` / ``not in`` against lists of integer instances.
        Everything else is deferred to the parent class.
        """
        if isinstance(o, BoolType):
            if isinstance(op, ast.Eq):
                # 1 == True
                # 0 == False
                # all other comparisons are False
                return OLambda(
                    ["x", "y"],
                    plt.Ite(
                        OVar("y"),
                        plt.EqualsInteger(OVar("x"), plt.Integer(1)),
                        plt.EqualsInteger(OVar("x"), plt.Integer(0)),
                    ),
                )
        if isinstance(o, IntegerType):
            if isinstance(op, ast.Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsInteger)
            if isinstance(op, ast.NotEq):
                return OLambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsInteger),
                            OVar("y"),
                            OVar("x"),
                        )
                    ),
                )
            if isinstance(op, ast.LtE):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanEqualsInteger)
            if isinstance(op, ast.Lt):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanInteger)
            if isinstance(op, ast.Gt):
                # x > y is expressed as y < x
                return OLambda(
                    ["x", "y"],
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.LessThanInteger),
                        OVar("y"),
                        OVar("x"),
                    ),
                )
            if isinstance(op, ast.GtE):
                # x >= y is expressed as y <= x
                return OLambda(
                    ["x", "y"],
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.LessThanEqualsInteger),
                        OVar("y"),
                        OVar("x"),
                    ),
                )
        if (
            isinstance(o, ListType)
            and isinstance(o.typ, InstanceType)
            and isinstance(o.typ.typ, IntegerType)
        ):
            if isinstance(op, ast.In):
                return OLambda(
                    ["x", "y"],
                    plt.AnyList(
                        OVar("y"),
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsInteger), OVar("x")
                        ),
                    ),
                )
            if isinstance(op, ast.NotIn):
                return OLambda(
                    ["x", "y"],
                    plt.Not(
                        plt.AnyList(
                            OVar("y"),
                            plt.Apply(
                                plt.BuiltIn(uplc.BuiltInFun.EqualsInteger), OVar("x")
                            ),
                        ),
                    ),
                )
        return super().cmp(op, o)

    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Return a lambda rendering an integer in decimal notation.

        Zero is rendered as ``0``, negative numbers get a leading ``-``.
        The recursive argument is not used here.
        """
        return OLambda(
            ["x"],
            plt.DecodeUtf8(
                OLet(
                    [
                        (
                            # list of the character codes of the digits of a positive i,
                            # least significant digit first
                            "strlist",
                            plt.RecFun(
                                OLambda(
                                    ["f", "i"],
                                    plt.Ite(
                                        plt.LessThanEqualsInteger(
                                            OVar("i"), plt.Integer(0)
                                        ),
                                        plt.EmptyIntegerList(),
                                        plt.MkCons(
                                            plt.AddInteger(
                                                plt.ModInteger(
                                                    OVar("i"), plt.Integer(10)
                                                ),
                                                plt.Integer(ord("0")),
                                            ),
                                            plt.Apply(
                                                OVar("f"),
                                                OVar("f"),
                                                plt.DivideInteger(
                                                    OVar("i"), plt.Integer(10)
                                                ),
                                            ),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                        (
                            # folds the digit list into a bytestring,
                            # putting each digit in front of the previous ones
                            "mkstr",
                            OLambda(
                                ["i"],
                                plt.FoldList(
                                    plt.Apply(OVar("strlist"), OVar("i")),
                                    OLambda(
                                        ["b", "i"],
                                        plt.ConsByteString(OVar("i"), OVar("b")),
                                    ),
                                    plt.ByteString(b""),
                                ),
                            ),
                        ),
                    ],
                    plt.Ite(
                        plt.EqualsInteger(OVar("x"), plt.Integer(0)),
                        plt.ByteString(b"0"),
                        plt.Ite(
                            plt.LessThanInteger(OVar("x"), plt.Integer(0)),
                            plt.ConsByteString(
                                plt.Integer(ord("-")),
                                plt.Apply(OVar("mkstr"), plt.Negate(OVar("x"))),
                            ),
                            plt.Apply(OVar("mkstr"), OVar("x")),
                        ),
                    ),
                )
            ),
        )

    def _binop_return_type(self, binop: ast.operator, other: "Type") -> "Type":
        """
        Return the class type resulting from applying binop to an integer and other.

        ``+ - // % / **`` with an integer or bool instance yield an integer,
        ``*`` yields an integer, bytestring or string depending on other.
        Everything else is deferred to the parent class.
        """
        # NOTE(review): ast.Div (and every listed operator except + and - for bools)
        # is typed as integer here although _binop_bin_fun of this class has no
        # matching branch and defers to the parent class - confirm this is intended.
        if isinstance(
            binop, (ast.Add, ast.Sub, ast.FloorDiv, ast.Mod, ast.Div, ast.Pow)
        ):
            if other == IntegerInstanceType:
                return IntegerType()
            elif other == BoolInstanceType:
                # cast to integer
                return IntegerType()
        if isinstance(binop, ast.Mult):
            if other == IntegerInstanceType:
                return IntegerType()
            elif other == ByteStringInstanceType:
                return ByteStringType()
            elif other == StringInstanceType:
                return StringType()
        return super()._binop_return_type(binop, other)

    def _binop_bin_fun(self, binop: ast.operator, other: "TypedAST"):
        """
        Return the function implementing binop between an integer and other.

        The returned callable takes the two operand expressions (integer first)
        and builds the resulting expression.
        Combinations not handled here are deferred to the parent class.
        """
        if other.typ == IntegerInstanceType:
            if isinstance(binop, ast.Add):
                return plt.AddInteger
            elif isinstance(binop, ast.Sub):
                return plt.SubtractInteger
            elif isinstance(binop, ast.FloorDiv):
                return plt.DivideInteger
            elif isinstance(binop, ast.Mod):
                return plt.ModInteger
            elif isinstance(binop, ast.Pow):
                # bind the exponent once, it is used for the sign check and the power
                return lambda x, y: OLet(
                    [("y", y)],
                    plt.Ite(
                        plt.LessThanInteger(OVar("y"), plt.Integer(0)),
                        plt.TraceError("Negative exponentiation is not supported"),
                        PowImpl(x, OVar("y")),
                    ),
                )
        if other.typ == BoolInstanceType:
            if isinstance(binop, ast.Add):
                # adding a bool adds 1 if it is set and nothing otherwise
                return lambda x, y: OLet(
                    [("x", x), ("y", y)],
                    plt.Ite(
                        OVar("y"), plt.AddInteger(OVar("x"), plt.Integer(1)), OVar("x")
                    ),
                )
            elif isinstance(binop, ast.Sub):
                # subtracting a bool subtracts 1 if it is set and nothing otherwise
                return lambda x, y: OLet(
                    [("x", x), ("y", y)],
                    plt.Ite(
                        OVar("y"),
                        plt.SubtractInteger(OVar("x"), plt.Integer(1)),
                        OVar("x"),
                    ),
                )

        if isinstance(binop, ast.Mult):
            if other.typ == IntegerInstanceType:
                return plt.MultiplyInteger
            elif other.typ == ByteStringInstanceType:
                # the helpers are called with the operands swapped
                return lambda x, y: ByteStrIntMulImpl(y, x)
            elif other.typ == StringInstanceType:
                return lambda x, y: StrIntMulImpl(y, x)
        return super()._binop_bin_fun(binop, other)

    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        """Return the result type of unop: integer for unary ``-`` and ``+``, bool for ``not``."""
        if isinstance(unop, ast.USub):
            return IntegerType()
        elif isinstance(unop, ast.UAdd):
            return IntegerType()
        elif isinstance(unop, ast.Not):
            return BoolType()
        return super()._unop_return_type(unop)

    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        """
        Return the implementation of unop.

        Unary ``-`` subtracts the operand from 0, unary ``+`` returns it
        unchanged and ``not`` compares it to 0.
        """
        if isinstance(unop, ast.USub):
            return lambda x: plt.SubtractInteger(plt.Integer(0), x)
        if isinstance(unop, ast.UAdd):
            return lambda x: x
        if isinstance(unop, ast.Not):
            return lambda x: plt.EqualsInteger(x, plt.Integer(0))
        return super()._unop_fun(unop)

    def copy_only_attributes(self) -> plt.AST:
        """
        Return a lambda that checks that its argument is a PlutusInteger.

        The argument is returned unchanged in that case, any other kind of
        PlutusData results in a traced ``IntegrityError``.
        """
        return OLambda(
            ["self"],
            plt.DelayedChooseData(
                OVar("self"),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusData"
                ),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusMap"
                ),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusList"
                ),
                OVar("self"),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusByteString"
                ),
            ),
        )
1973

1974

1975
@dataclass(frozen=True, unsafe_hash=True)
class StringType(AtomicType):
    """Class type of strings: encoding, comparison, concatenation, repetition and truthiness."""

    def python_type(self):
        """Return the Python annotation string of strings."""
        return "str"

    def constr_type(self) -> InstanceType:
        """Return the type of the ``str(...)`` constructor, a polymorphic function backed by StrImpl."""
        constructor = PolymorphicFunctionType(StrImpl())
        return InstanceType(constructor)

    def attribute_type(self, attr) -> Type:
        """Return the type of attribute attr; only ``encode()`` is defined here."""
        if attr != "encode":
            return super().attribute_type(attr)
        return InstanceType(FunctionType(frozenlist([]), ByteStringInstanceType))

    def attribute(self, attr) -> plt.AST:
        """Return the implementation of attribute attr; only ``encode()`` is defined here."""
        if attr != "encode":
            return super().attribute(attr)
        # No codec -> only the default (utf8) is allowed
        return OLambda(["x", "_"], plt.EncodeUtf8(OVar("x")))

    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        Return the implementation of comparing a string to type o via operator op.

        Only ``==`` and ``!=`` against another string are defined here,
        everything else is deferred to the parent class.
        """
        if not isinstance(o, StringType):
            return super().cmp(op, o)
        if isinstance(op, ast.Eq):
            return plt.BuiltIn(uplc.BuiltInFun.EqualsString)
        if isinstance(op, ast.NotEq):
            strings_equal = plt.EqualsString(OVar("x"), OVar("y"))
            return OLambda(["x", "y"], plt.Not(strings_equal))
        return super().cmp(op, o)

    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Return a lambda converting a string to its string representation.

        With recursive set, the string is wrapped in single quotes,
        otherwise it is returned unchanged.
        """
        if not recursive:
            return OLambda(["self"], OVar("self"))
        # TODO this is not correct, as the string is not properly escaped
        quoted = plt.ConcatString(plt.Text("'"), OVar("self"), plt.Text("'"))
        return OLambda(
            ["self"],
            quoted,
        )

    def _binop_return_type(self, binop: ast.operator, other: "Type") -> "Type":
        """Return the result type of binop: a string for ``str + str`` and ``str * int``."""
        concatenates = isinstance(binop, ast.Add) and other == StringInstanceType
        repeats = isinstance(binop, ast.Mult) and other == IntegerInstanceType
        if concatenates or repeats:
            return StringType()
        return super()._binop_return_type(binop, other)

    def _binop_bin_fun(self, binop: ast.operator, other: "TypedAST"):
        """Return the function implementing binop: AppendString for ``+``, StrIntMulImpl for ``*``."""
        if isinstance(binop, ast.Add) and other.typ == StringInstanceType:
            return plt.AppendString
        if isinstance(binop, ast.Mult) and other.typ == IntegerInstanceType:
            return StrIntMulImpl
        return super()._binop_bin_fun(binop, other)

    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        """Return the result type of unop: bool for ``not``, otherwise defer to the parent class."""
        if not isinstance(unop, ast.Not):
            return super()._unop_return_type(unop)
        return BoolType()

    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        """
        Return the implementation of unop.

        ``not`` checks whether the UTF-8 encoding of the string has length 0.
        """
        if not isinstance(unop, ast.Not):
            return super()._unop_fun(unop)

        def string_is_empty(operand):
            encoded_length = plt.LengthOfByteString(plt.EncodeUtf8(operand))
            return plt.EqualsInteger(encoded_length, plt.Integer(0))

        return string_is_empty

    def copy_only_attributes(self) -> plt.AST:
        """
        Return a lambda that checks that its argument is a PlutusByteString.

        The argument is returned unchanged in that case, any other kind of
        PlutusData results in a traced ``IntegrityError``.
        """
        unexpected_kinds = ["PlutusData", "PlutusMap", "PlutusList", "PlutusInteger"]
        rejections = [
            plt.TraceError(f"IntegrityError: Expected PlutusByteString but got {kind}")
            for kind in unexpected_kinds
        ]
        return OLambda(
            ["self"],
            plt.DelayedChooseData(
                OVar("self"),
                *rejections,
                OVar("self"),
            ),
        )
2065

2066

2067
@dataclass(frozen=True, unsafe_hash=True)
class ByteStringType(AtomicType):
    """
    The type of ``bytes`` values.

    Provides the attributes ``decode``, ``hex``, ``fromhex``, ``ljust`` and
    ``rjust``, comparisons with other byte strings, membership tests in lists
    of byte strings, ``+`` with byte strings, ``*`` with integers and ``not``.
    """

    def pluthon_type(self, skip_constructor: bool = False) -> str:
        return "bytes"

    def python_type(self):
        return "bytes"

    def constr_type(self) -> InstanceType:
        return InstanceType(PolymorphicFunctionType(BytesImpl()))

    def attribute_type(self, attr) -> Type:
        """Return the type of the attribute ``attr``, delegating unknown names to the superclass."""
        if attr in ("decode", "hex"):
            return InstanceType(FunctionType(frozenlist([]), StringInstanceType))
        if attr == "fromhex":
            return InstanceType(
                FunctionType(
                    frozenlist([StringInstanceType]),
                    ByteStringInstanceType,
                )
            )
        if attr in ("ljust", "rjust"):
            return InstanceType(
                FunctionType(
                    frozenlist([IntegerInstanceType, ByteStringInstanceType]),
                    ByteStringInstanceType,
                )
            )
        return super().attribute_type(attr)

    # ------------------------------------------------------------------
    # shared building blocks for hex() and stringify()
    # ------------------------------------------------------------------

    @staticmethod
    def _reversed_byte_list_fun() -> plt.AST:
        """
        Recursive function ``i -> [x[i], x[i-1], ..., x[0]]``.

        ``x`` is a free variable and must be bound to the byte string by the caller.
        A negative ``i`` yields the empty list.
        """
        return plt.RecFun(
            OLambda(
                ["f", "i"],
                plt.Ite(
                    plt.LessThanInteger(OVar("i"), plt.Integer(0)),
                    plt.EmptyIntegerList(),
                    plt.MkCons(
                        plt.IndexByteString(OVar("x"), OVar("i")),
                        plt.Apply(
                            OVar("f"),
                            OVar("f"),
                            plt.SubtractInteger(OVar("i"), plt.Integer(1)),
                        ),
                    ),
                ),
            ),
        )

    @staticmethod
    def _hex_digit_fun() -> plt.AST:
        """Function mapping ``i`` to ``i + ord("0")`` if ``i < 10`` and to ``i + ord("a") - 10`` otherwise."""
        return OLambda(
            ["i"],
            plt.AddInteger(
                OVar("i"),
                plt.IfThenElse(
                    plt.LessThanInteger(OVar("i"), plt.Integer(10)),
                    plt.Integer(ord("0")),
                    plt.Integer(ord("a") - 10),
                ),
            ),
        )

    @staticmethod
    def _hex_digit_pair(rest: plt.AST) -> plt.AST:
        """
        Prepend the two hex digits of the byte ``i`` (high nibble first) to ``rest``.

        ``i`` and ``map_str`` are free variables, ``map_str`` must be bound to the
        function built by ``_hex_digit_fun``.
        """
        return plt.ConsByteString(
            plt.Apply(
                OVar("map_str"),
                plt.DivideInteger(OVar("i"), plt.Integer(16)),
            ),
            plt.ConsByteString(
                plt.Apply(
                    OVar("map_str"),
                    plt.ModInteger(OVar("i"), plt.Integer(16)),
                ),
                rest,
            ),
        )

    # ------------------------------------------------------------------
    # attribute implementations
    # ------------------------------------------------------------------

    def _hex_impl(self) -> plt.AST:
        """Implementation of ``hex``: two lowercase hex digits per byte of ``x``."""
        mkstr = OLambda(
            ["i"],
            plt.FoldList(
                plt.Apply(OVar("hexlist"), OVar("i")),
                # the list is in reverse byte order, so prepending restores the order
                OLambda(["b", "i"], self._hex_digit_pair(OVar("b"))),
                plt.ByteString(b""),
            ),
        )
        return OLambda(
            ["x", "_"],
            plt.DecodeUtf8(
                OLet(
                    [
                        ("hexlist", self._reversed_byte_list_fun()),
                        ("map_str", self._hex_digit_fun()),
                        ("mkstr", mkstr),
                    ],
                    plt.Apply(
                        OVar("mkstr"),
                        plt.SubtractInteger(
                            plt.LengthOfByteString(OVar("x")), plt.Integer(1)
                        ),
                    ),
                ),
            ),
        )

    @staticmethod
    def _fromhex_impl() -> plt.AST:
        """
        Implementation of ``fromhex``.

        Consumes the UTF-8 encoding of the argument two characters at a time.
        Traces "Invalid hex character" for characters outside ``0-9``, ``a-f`` and
        ``A-F`` and "Invalid hex string" if a single character is left over.
        """

        def in_range(low: str, high: str) -> plt.AST:
            return plt.And(
                plt.LessThanEqualsInteger(plt.Integer(ord(low)), OVar("c")),
                plt.LessThanEqualsInteger(OVar("c"), plt.Integer(ord(high))),
            )

        def offset_from(base: str) -> plt.AST:
            return plt.SubtractInteger(OVar("c"), plt.Integer(ord(base)))

        def i_plus(k: int) -> plt.AST:
            return plt.AddInteger(OVar("i"), plt.Integer(k))

        char_to_int = OLambda(
            ["c"],
            plt.Ite(
                in_range("a", "f"),
                plt.AddInteger(offset_from("a"), plt.Integer(10)),
                plt.Ite(
                    in_range("0", "9"),
                    offset_from("0"),
                    plt.Ite(
                        in_range("A", "F"),
                        plt.AddInteger(offset_from("A"), plt.Integer(10)),
                        plt.TraceError("Invalid hex character"),
                    ),
                ),
            ),
        )
        # decode the characters at i and i+1 into one byte and recurse at i+2
        decode_pair = OLet(
            [
                (
                    "char_at_i",
                    plt.IndexByteString(OVar("bytestr"), OVar("i")),
                ),
                (
                    "char_at_ip1",
                    plt.IndexByteString(OVar("bytestr"), i_plus(1)),
                ),
            ],
            plt.ConsByteString(
                plt.AddInteger(
                    plt.MultiplyInteger(
                        plt.Apply(OVar("char_to_int"), OVar("char_at_i")),
                        plt.Integer(16),
                    ),
                    plt.Apply(OVar("char_to_int"), OVar("char_at_ip1")),
                ),
                plt.Apply(OVar("f"), OVar("f"), i_plus(2)),
            ),
        )
        splitlist = plt.RecFun(
            OLambda(
                ["f", "i"],
                plt.Ite(
                    # no character left -> done
                    plt.LessThanInteger(OVar("bytestr_len"), i_plus(1)),
                    plt.ByteString(b""),
                    plt.Ite(
                        # exactly one character left -> odd number of digits
                        plt.LessThanInteger(OVar("bytestr_len"), i_plus(2)),
                        plt.TraceError("Invalid hex string"),
                        decode_pair,
                    ),
                ),
            )
        )
        return OLambda(
            ["_", "x"],
            OLet(
                [
                    ("bytestr", plt.EncodeUtf8(plt.Force(OVar("x")))),
                    ("bytestr_len", plt.LengthOfByteString(OVar("bytestr"))),
                    ("char_to_int", char_to_int),
                    ("splitlist", splitlist),
                ],
                plt.Apply(OVar("splitlist"), plt.Integer(0)),
            ),
        )

    @staticmethod
    def _just_impl(pad_right: bool) -> plt.AST:
        """
        Shared implementation of ``ljust`` (``pad_right=True``) and ``rjust``
        (``pad_right=False``).

        Traces "fillchar must be a single byte" unless ``fillchar`` has length 1.
        Returns ``x`` unchanged unless it is shorter than ``width``.
        """
        padding = ByteStrIntMulImpl(
            OVar("fillchar"),
            plt.SubtractInteger(OVar("width"), plt.LengthOfByteString(OVar("x"))),
        )
        if pad_right:
            padded = plt.AppendByteString(OVar("x"), padding)
        else:
            padded = plt.AppendByteString(padding, OVar("x"))
        return OLambda(
            ["x", "width", "fillchar"],
            OLet(
                [
                    ("fillchar", plt.Force(OVar("fillchar"))),
                    ("width", plt.Force(OVar("width"))),
                ],
                plt.Ite(
                    plt.NotEqualsInteger(
                        plt.LengthOfByteString(OVar("fillchar")), plt.Integer(1)
                    ),
                    plt.TraceError("fillchar must be a single byte"),
                    plt.Ite(
                        plt.LessThanInteger(
                            plt.LengthOfByteString(OVar("x")), OVar("width")
                        ),
                        padded,
                        OVar("x"),
                    ),
                ),
            ),
        )

    def attribute(self, attr) -> plt.AST:
        """Return the implementation of the attribute ``attr``, delegating unknown names to the superclass."""
        if attr == "decode":
            # No codec -> only the default (utf8) is allowed
            return OLambda(["x", "_"], plt.DecodeUtf8(OVar("x")))
        if attr == "hex":
            return self._hex_impl()
        if attr == "fromhex":
            return self._fromhex_impl()
        if attr == "ljust":
            return self._just_impl(pad_right=True)
        if attr == "rjust":
            return self._just_impl(pad_right=False)

        return super().attribute(attr)

    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        Return the implementation of ``self <op> o``.

        Handles ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=`` against byte strings
        and ``in`` / ``not in`` against lists of byte string instances.
        Everything else is delegated to the superclass.
        """
        if isinstance(o, ByteStringType):
            if isinstance(op, ast.Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsByteString)
            if isinstance(op, ast.NotEq):
                return OLambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsByteString),
                            OVar("y"),
                            OVar("x"),
                        )
                    ),
                )
            if isinstance(op, ast.Lt):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanByteString)
            if isinstance(op, ast.LtE):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanEqualsByteString)
            if isinstance(op, ast.Gt):
                # x > y is computed as y < x
                return OLambda(
                    ["x", "y"],
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.LessThanByteString),
                        OVar("y"),
                        OVar("x"),
                    ),
                )
            if isinstance(op, ast.GtE):
                # x >= y is computed as y <= x
                return OLambda(
                    ["x", "y"],
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.LessThanEqualsByteString),
                        OVar("y"),
                        OVar("x"),
                    ),
                )
        if (
            isinstance(o, ListType)
            and isinstance(o.typ, InstanceType)
            and isinstance(o.typ.typ, ByteStringType)
        ):

            def contains() -> plt.AST:
                # any element of the list y equals x
                return plt.AnyList(
                    OVar("y"),
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.EqualsByteString),
                        OVar("x"),
                    ),
                )

            if isinstance(op, ast.In):
                return OLambda(["x", "y"], contains())
            if isinstance(op, ast.NotIn):
                return OLambda(["x", "y"], plt.Not(contains()))
        return super().cmp(op, o)

    def stringify(self, recursive: bool = False) -> plt.AST:
        """
        Return a function rendering a byte string as ``b'...'``.

        Bytes from 0x20 to 0x7E are kept, except that backslash and single quote
        are prefixed with a backslash. Tab, newline and carriage return become
        ``\\t``, ``\\n`` and ``\\r``, all remaining bytes become ``\\x`` followed by
        two hex digits.
        """

        def is_char(char: str) -> plt.AST:
            return plt.EqualsInteger(OVar("i"), plt.Integer(ord(char)))

        def escaped(sequence: bytes) -> plt.AST:
            return plt.AppendByteString(plt.ByteString(sequence), OVar("b"))

        printable = plt.Ite(
            is_char("\\"),
            escaped(b"\\\\"),
            plt.Ite(
                is_char("'"),
                escaped(b"\\'"),
                plt.ConsByteString(OVar("i"), OVar("b")),
            ),
        )
        non_printable = plt.Ite(
            is_char("\t"),
            escaped(b"\\t"),
            plt.Ite(
                is_char("\n"),
                escaped(b"\\n"),
                plt.Ite(
                    is_char("\r"),
                    escaped(b"\\r"),
                    plt.AppendByteString(
                        plt.ByteString(b"\\x"),
                        self._hex_digit_pair(OVar("b")),
                    ),
                ),
            ),
        )
        render_byte = OLambda(
            ["b", "i"],
            plt.Ite(
                # ascii printable characters are kept unmodified
                plt.And(
                    plt.LessThanEqualsInteger(plt.Integer(0x20), OVar("i")),
                    plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0x7E)),
                ),
                printable,
                non_printable,
            ),
        )
        mkstr = OLambda(
            ["i"],
            plt.FoldList(
                plt.Apply(OVar("hexlist"), OVar("i")),
                render_byte,
                plt.ByteString(b""),
            ),
        )
        return OLambda(
            ["x"],
            plt.DecodeUtf8(
                OLet(
                    [
                        ("hexlist", self._reversed_byte_list_fun()),
                        ("map_str", self._hex_digit_fun()),
                        ("mkstr", mkstr),
                    ],
                    plt.ConcatByteString(
                        plt.ByteString(b"b'"),
                        plt.Apply(
                            OVar("mkstr"),
                            plt.SubtractInteger(
                                plt.LengthOfByteString(OVar("x")), plt.Integer(1)
                            ),
                        ),
                        plt.ByteString(b"'"),
                    ),
                ),
            ),
        )

    def _binop_return_type(self, binop: ast.operator, other: "Type") -> "Type":
        """``bytes + bytes`` and ``bytes * int`` yield bytes, everything else is delegated."""
        if isinstance(binop, ast.Add):
            if other == ByteStringInstanceType:
                return ByteStringType()
        if isinstance(binop, ast.Mult):
            if other == IntegerInstanceType:
                return ByteStringType()
        return super()._binop_return_type(binop, other)

    def _binop_bin_fun(self, binop: ast.operator, other: "TypedAST"):
        """Select the implementation matching ``_binop_return_type``."""
        if isinstance(binop, ast.Add):
            if other.typ == ByteStringInstanceType:
                return plt.AppendByteString
        if isinstance(binop, ast.Mult):
            if other.typ == IntegerInstanceType:
                return ByteStrIntMulImpl
        return super()._binop_bin_fun(binop, other)

    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        if isinstance(unop, ast.Not):
            return BoolType()
        return super()._unop_return_type(unop)

    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        if isinstance(unop, ast.Not):
            # a byte string is falsy exactly when its length is 0
            return lambda x: plt.EqualsInteger(
                plt.LengthOfByteString(x), plt.Integer(0)
            )
        return super()._unop_fun(unop)

    def copy_only_attributes(self) -> plt.AST:
        """
        Build a one-argument lambda that dispatches on ``self`` via
        ``plt.DelayedChooseData``: the first four branches trace an
        IntegrityError, the last branch returns ``self`` unchanged.
        """
        return OLambda(
            ["self"],
            plt.DelayedChooseData(
                OVar("self"),
                plt.TraceError(
                    "IntegrityError: Expected PlutusByteString but got PlutusData"
                ),
                plt.TraceError(
                    "IntegrityError: Expected PlutusByteString but got PlutusMap"
                ),
                plt.TraceError(
                    "IntegrityError: Expected PlutusByteString but got PlutusList"
                ),
                plt.TraceError(
                    "IntegrityError: Expected PlutusByteString but got PlutusInteger"
                ),
                OVar("self"),
            ),
        )
2668

2669

2670
@dataclass(frozen=True, unsafe_hash=True)
class BoolType(AtomicType):
    """Atomic type of Python ``bool`` values."""

    def python_type(self):
        """Return the Python-level name of this type."""
        return "bool"

    def constr_type(self) -> "InstanceType":
        """Return the type of the ``bool(...)`` constructor, a polymorphic function."""
        return InstanceType(PolymorphicFunctionType(BoolImpl()))

    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        Build the two-argument comparison function for ``bool <op> o``.

        Supported are ``==``/``!=`` against integers and ``==``, ``!=``, ``<``,
        ``>`` against bools. Anything else is delegated to the parent class.
        """
        if isinstance(o, IntegerType):
            if isinstance(op, ast.Eq):
                # 1 == True
                # 0 == False
                # all other comparisons are False
                return OLambda(
                    ["y", "x"],
                    plt.Ite(
                        OVar("y"),
                        plt.EqualsInteger(OVar("x"), plt.Integer(1)),
                        plt.EqualsInteger(OVar("x"), plt.Integer(0)),
                    ),
                )
            if isinstance(op, ast.NotEq):
                # exact negation of the equality case above
                return OLambda(
                    ["y", "x"],
                    plt.Ite(
                        OVar("y"),
                        plt.NotEqualsInteger(OVar("x"), plt.Integer(1)),
                        plt.NotEqualsInteger(OVar("x"), plt.Integer(0)),
                    ),
                )
        if isinstance(o, BoolType):
            if isinstance(op, ast.Eq):
                return OLambda(["x", "y"], plt.Iff(OVar("x"), OVar("y")))
            if isinstance(op, ast.NotEq):
                return OLambda(["x", "y"], plt.Not(plt.Iff(OVar("x"), OVar("y"))))
            if isinstance(op, ast.Lt):
                # x < y holds only for x == False and y == True
                return OLambda(["x", "y"], plt.And(plt.Not(OVar("x")), OVar("y")))
            if isinstance(op, ast.Gt):
                # x > y holds only for x == True and y == False
                return OLambda(["x", "y"], plt.And(OVar("x"), plt.Not(OVar("y"))))
        return super().cmp(op, o)

    def stringify(self, recursive: bool = False) -> plt.AST:
        """Build the function mapping a bool to the text ``True`` or ``False``."""
        return OLambda(
            ["self"],
            plt.Ite(
                OVar("self"),
                plt.Text("True"),
                plt.Text("False"),
            ),
        )

    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        """Return the result type of a unary operation; ``not`` yields a bool."""
        if isinstance(unop, ast.Not):
            return BoolType()
        return super()._unop_return_type(unop)

    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        """Return the builder implementing a unary operation; ``not`` is ``plt.Not``."""
        if isinstance(unop, ast.Not):
            return plt.Not
        return super()._unop_fun(unop)

    def copy_only_attributes(self) -> plt.AST:
        """
        Build the integrity check for a value that is expected to be a bool.

        The generated function dispatches on its argument with
        ``plt.DelayedChooseData``. The fourth branch returns the value
        unchanged; judging by the neighbouring trace messages this is the
        integer branch. All other branches abort with an ``IntegrityError``.
        """
        # The messages used to read "PlutusByteInteger", a copy/paste leftover.
        # The accepted kind is the integer one, so they now say "PlutusInteger".
        return OLambda(
            ["self"],
            plt.DelayedChooseData(
                OVar("self"),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusData"
                ),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusMap"
                ),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusList"
                ),
                OVar("self"),
                plt.TraceError(
                    "IntegrityError: Expected PlutusInteger but got PlutusByteString"
                ),
            ),
        )
2753

2754

2755
@dataclass(frozen=True, unsafe_hash=True)
class UnitType(AtomicType):
    """Atomic type of the single value ``None``."""

    def cmp(self, op: ast.cmpop, o: "Type") -> plt.AST:
        """
        Build the comparison function for ``None <op> o``.

        Against another unit value ``==`` is constantly True and ``!=``
        constantly False. Everything else is delegated to the parent class.
        """
        if isinstance(o, UnitType):
            for op_class, outcome in ((ast.Eq, True), (ast.NotEq, False)):
                if isinstance(op, op_class):
                    return OLambda(["x", "y"], plt.Bool(outcome))
        return super().cmp(op, o)

    def stringify(self, recursive: bool = False) -> plt.AST:
        """Build the function mapping any unit value to the text ``None``."""
        none_text = plt.Text("None")
        return OLambda(["self"], none_text)

    def _unop_return_type(self, unop: ast.unaryop) -> "Type":
        """Return the result type of a unary operation; ``not`` yields a bool."""
        if not isinstance(unop, ast.Not):
            return super()._unop_return_type(unop)
        return BoolType()

    def _unop_fun(self, unop: ast.unaryop) -> Callable[[plt.AST], plt.AST]:
        """Return the builder implementing a unary operation on ``None``."""
        if not isinstance(unop, ast.Not):
            return super()._unop_fun(unop)

        def always_true(x):
            # ``not None`` is True regardless of the operand
            return plt.Bool(True)

        return always_true

    def python_type(self):
        """Return the Python-level name of this type."""
        return "None"
×
2780

2781

2782
# Module-level instance types of the atomic types. Other code in this file
# compares against these values (e.g. ``other.typ == IntegerInstanceType``).
IntegerInstanceType = InstanceType(IntegerType())
StringInstanceType = InstanceType(StringType())
ByteStringInstanceType = InstanceType(ByteStringType())
BoolInstanceType = InstanceType(BoolType())
UnitInstanceType = InstanceType(UnitType())

# Maps the ``__name__`` of a Python builtin type to the atomic type that
# represents it. ``bytes`` and ``bytearray`` both map to a byte string type.
ATOMIC_TYPES = {
    int.__name__: IntegerType(),
    str.__name__: StringType(),
    bytes.__name__: ByteStringType(),
    bytearray.__name__: ByteStringType(),
    type(None).__name__: UnitType(),
    bool.__name__: BoolType(),
}


# Alias for the unit instance type under the name Python uses for it.
NoneInstanceType = UnitInstanceType
5✔
2799

2800

2801
class InaccessibleType(ClassType):
    """A type that blocks overwriting of a function"""

    def python_type(self):
        """Return the placeholder name shown for this type."""
        return "<forbidden>"
×
2808

2809

2810
def repeated_addition(zero, add):
5✔
2811
    # this is optimized for logarithmic complexity by exponentiation by squaring
2812
    # it follows the implementation described here: https://en.wikipedia.org/wiki/Exponentiation_by_squaring#With_constant_auxiliary_memory
2813
    def RepeatedAdd(x: plt.AST, y: plt.AST):
5✔
2814
        return plt.Apply(
5✔
2815
            plt.RecFun(
2816
                OLambda(
2817
                    ["f", "y", "x", "n"],
2818
                    plt.Ite(
2819
                        plt.LessThanEqualsInteger(OVar("n"), plt.Integer(0)),
2820
                        OVar("y"),
2821
                        OLet(
2822
                            [
2823
                                (
2824
                                    "n_half",
2825
                                    plt.DivideInteger(OVar("n"), plt.Integer(2)),
2826
                                )
2827
                            ],
2828
                            plt.Ite(
2829
                                # tests whether (x//2)*2 == x which is True iff x is even
2830
                                plt.EqualsInteger(
2831
                                    plt.AddInteger(OVar("n_half"), OVar("n_half")),
2832
                                    OVar("n"),
2833
                                ),
2834
                                plt.Apply(
2835
                                    OVar("f"),
2836
                                    OVar("f"),
2837
                                    OVar("y"),
2838
                                    add(OVar("x"), OVar("x")),
2839
                                    OVar("n_half"),
2840
                                ),
2841
                                plt.Apply(
2842
                                    OVar("f"),
2843
                                    OVar("f"),
2844
                                    add(OVar("y"), OVar("x")),
2845
                                    add(OVar("x"), OVar("x")),
2846
                                    OVar("n_half"),
2847
                                ),
2848
                            ),
2849
                        ),
2850
                    ),
2851
                ),
2852
            ),
2853
            zero,
2854
            x,
2855
            y,
2856
        )
2857

2858
    return RepeatedAdd
5✔
2859

2860

2861
# Builders derived from repeated_addition. Each pairs a start value with the
# operation that is applied repeatedly:
# integer 1 with integer multiplication
PowImpl = repeated_addition(plt.Integer(1), plt.MultiplyInteger)
# the empty byte string with byte string concatenation
ByteStrIntMulImpl = repeated_addition(plt.ByteString(b""), plt.AppendByteString)
# the empty text with text concatenation
StrIntMulImpl = repeated_addition(plt.Text(""), plt.AppendString)
5✔
2864

2865

2866
class PolymorphicFunction:
    """
    Base class of builtins whose type and implementation depend on the argument types.

    Subclasses override ``type_from_args`` and ``impl_from_args``.
    """

    def __new__(meta, *args, **kwargs):
        instance = super().__new__(meta)

        # replace the listed methods by their patternized counterpart on the instance
        for attr_name in ("impl_from_args",):
            patternized = patternize(getattr(instance, attr_name))
            object.__setattr__(instance, attr_name, patternized)

        return instance

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """Return the function type for the given argument types."""
        raise NotImplementedError()

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """Return the implementation for the given argument types."""
        raise NotImplementedError()
2882

2883

2884
class StrImpl(PolymorphicFunction):
    """Implementation of the builtin ``str(x)``."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """Check for exactly one instance argument and return a function type yielding a string."""
        n_args = len(args)
        assert n_args == 1, f"'str' takes only one argument, but {n_args} were given"
        assert isinstance(args[0], InstanceType), "Can only stringify instances"
        return FunctionType(args, StringInstanceType)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """Return the stringification function of the argument's own type."""
        target = args[0]
        assert isinstance(target, InstanceType), "Can only stringify instances"
        stringifier = target.typ.stringify()
        return stringifier
5✔
2897

2898

2899
class IntImpl(PolymorphicFunction):
    """Implementation of the builtin ``int(x)`` for int, bool and str arguments."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """
        Check the argument types and return a function type yielding an integer.

        Fails with an assertion if the number of arguments is not one, the
        argument is not an instance, or its type is not int, str or bool.
        """
        assert (
            len(args) == 1
        ), f"'int' takes only one argument, but {len(args)} were given"
        typ = args[0]
        assert isinstance(typ, InstanceType), "Can only create ints from instances"
        assert isinstance(
            typ.typ, (IntegerType, StringType, BoolType)
        ), "Can only create integers from int, str or bool"
        return FunctionType(args, IntegerInstanceType)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """
        Return the one-argument conversion function for the given argument type.

        * int: identity
        * bool: 1 for True, 0 for False
        * str: base-10 parsing of the UTF-8 bytes with an optional leading
          ``+``/``-`` sign; underscores between digits are skipped, any other
          non-digit aborts with a ``ValueError`` trace

        Raises NotImplementedError for any other argument type.
        """
        arg = args[0]
        assert isinstance(arg, InstanceType), "Can only create ints from instances"
        if isinstance(arg.typ, IntegerType):
            return OLambda(["x"], OVar("x"))
        elif isinstance(arg.typ, BoolType):
            return OLambda(
                ["x"], plt.IfThenElse(OVar("x"), plt.Integer(1), plt.Integer(0))
            )
        elif isinstance(arg.typ, StringType):
            return OLambda(
                ["x"],
                OLet(
                    [
                        ("e", plt.EncodeUtf8(OVar("x"))),
                        ("len", plt.LengthOfByteString(OVar("e"))),
                        (
                            # first byte, or "_" for the empty string so that the
                            # validity check below rejects it
                            "first_int",
                            plt.Ite(
                                plt.LessThanInteger(plt.Integer(0), OVar("len")),
                                plt.IndexByteString(OVar("e"), plt.Integer(0)),
                                plt.Integer(ord("_")),
                            ),
                        ),
                        (
                            # last byte, guarded like first_int: without the guard
                            # the empty string is indexed at -1, which presumably
                            # fails in the builtin before the ValueError below
                            # can be reported
                            "last_int",
                            plt.Ite(
                                plt.LessThanInteger(plt.Integer(0), OVar("len")),
                                plt.IndexByteString(
                                    OVar("e"),
                                    plt.SubtractInteger(OVar("len"), plt.Integer(1)),
                                ),
                                plt.Integer(ord("_")),
                            ),
                        ),
                        (
                            # folds the bytes from index `start` on into an integer,
                            # with the accumulator "s" starting at 0
                            "fold_start",
                            OLambda(
                                ["start"],
                                plt.FoldList(
                                    plt.Range(OVar("len"), OVar("start")),
                                    OLambda(
                                        ["s", "i"],
                                        OLet(
                                            [
                                                (
                                                    "b",
                                                    plt.IndexByteString(
                                                        OVar("e"), OVar("i")
                                                    ),
                                                )
                                            ],
                                            plt.Ite(
                                                # underscores are skipped
                                                plt.EqualsInteger(
                                                    OVar("b"), plt.Integer(ord("_"))
                                                ),
                                                OVar("s"),
                                                plt.Ite(
                                                    # anything outside "0".."9" is invalid
                                                    plt.Or(
                                                        plt.LessThanInteger(
                                                            OVar("b"),
                                                            plt.Integer(ord("0")),
                                                        ),
                                                        plt.LessThanInteger(
                                                            plt.Integer(ord("9")),
                                                            OVar("b"),
                                                        ),
                                                    ),
                                                    plt.TraceError(
                                                        "ValueError: invalid literal for int() with base 10"
                                                    ),
                                                    # s * 10 + digit value
                                                    plt.AddInteger(
                                                        plt.SubtractInteger(
                                                            OVar("b"),
                                                            plt.Integer(ord("0")),
                                                        ),
                                                        plt.MultiplyInteger(
                                                            OVar("s"),
                                                            plt.Integer(10),
                                                        ),
                                                    ),
                                                ),
                                            ),
                                        ),
                                    ),
                                    plt.Integer(0),
                                ),
                            ),
                        ),
                    ],
                    plt.Ite(
                        # invalid: leading or trailing underscore (this covers the
                        # empty string), or a sign without any digits
                        plt.Or(
                            plt.Or(
                                plt.EqualsInteger(
                                    OVar("first_int"),
                                    plt.Integer(ord("_")),
                                ),
                                plt.EqualsInteger(
                                    OVar("last_int"),
                                    plt.Integer(ord("_")),
                                ),
                            ),
                            plt.And(
                                plt.EqualsInteger(OVar("len"), plt.Integer(1)),
                                plt.Or(
                                    plt.EqualsInteger(
                                        OVar("first_int"),
                                        plt.Integer(ord("-")),
                                    ),
                                    plt.EqualsInteger(
                                        OVar("first_int"),
                                        plt.Integer(ord("+")),
                                    ),
                                ),
                            ),
                        ),
                        plt.TraceError(
                            "ValueError: invalid literal for int() with base 10"
                        ),
                        plt.Ite(
                            # leading "-": parse the remainder and negate
                            plt.EqualsInteger(
                                OVar("first_int"),
                                plt.Integer(ord("-")),
                            ),
                            plt.Negate(
                                plt.Apply(OVar("fold_start"), plt.Integer(1)),
                            ),
                            plt.Ite(
                                # leading "+": parse the remainder, otherwise everything
                                plt.EqualsInteger(
                                    OVar("first_int"),
                                    plt.Integer(ord("+")),
                                ),
                                plt.Apply(OVar("fold_start"), plt.Integer(1)),
                                plt.Apply(OVar("fold_start"), plt.Integer(0)),
                            ),
                        ),
                    ),
                ),
            )
        else:
            raise NotImplementedError(
                f"Can not derive integer from type {arg.typ.python_type()}"
            )
3050

3051

3052
class BoolImpl(PolymorphicFunction):
    """Implementation of the builtin ``bool(x)``."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """
        Check the argument types and return a function type yielding a bool.

        Fails with an assertion if the number of arguments is not one, the
        argument is not an instance, or its type is not supported.
        """
        assert (
            len(args) == 1
        ), f"'bool' takes only one argument, but {len(args)} were given"
        typ = args[0]
        assert isinstance(typ, InstanceType), "Can only create bools from instances"
        convertible = (
            IntegerType,
            StringType,
            ByteStringType,
            BoolType,
            UnitType,
            ListType,
            DictType,
        )
        assert isinstance(
            typ.typ, convertible
        ), "Can only create bools from int, str, bool, bytes, None, list or dict"
        return FunctionType(args, BoolInstanceType)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """
        Return the one-argument truthiness function for the given argument type.

        Raises NotImplementedError for unsupported argument types.
        """
        arg = args[0]
        assert isinstance(arg, InstanceType), "Can only create bools from instances"
        source = arg.typ
        if isinstance(source, BoolType):
            # already a bool
            return OLambda(["x"], OVar("x"))
        if isinstance(source, IntegerType):
            # truthy iff different from 0
            return OLambda(["x"], plt.NotEqualsInteger(OVar("x"), plt.Integer(0)))
        if isinstance(source, StringType):
            # truthy iff the UTF-8 encoding has a non-zero length
            encoded_length = plt.LengthOfByteString(plt.EncodeUtf8(OVar("x")))
            return OLambda(
                ["x"],
                plt.NotEqualsInteger(encoded_length, plt.Integer(0)),
            )
        if isinstance(source, ByteStringType):
            # truthy iff the length is non-zero
            return OLambda(
                ["x"],
                plt.NotEqualsInteger(plt.LengthOfByteString(OVar("x")), plt.Integer(0)),
            )
        if isinstance(source, (ListType, DictType)):
            # truthy iff the underlying list is not empty
            return OLambda(["x"], plt.Not(plt.NullList(OVar("x"))))
        if isinstance(source, UnitType):
            # None is always falsy
            return OLambda(["x"], plt.Bool(False))
        raise NotImplementedError(
            f"Can not derive bool from type {source.python_type()}"
        )
3100

3101

3102
class BytesImpl(PolymorphicFunction):
    """Implementation of the builtin ``bytes(x)``."""

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """
        Check the argument types and return a function type yielding bytes.

        Accepted are a single int, bytes or list-of-int instance; everything
        else fails with an assertion.
        """
        assert (
            len(args) == 1
        ), f"'bytes' takes only one argument, but {len(args)} were given"
        typ = args[0]
        assert isinstance(
            typ, InstanceType
        ), "Can only create bytes from instances, got ClassType"
        assert isinstance(
            typ.typ, (IntegerType, ByteStringType, ListType)
        ), f"Can only create bytes from int, bytes or integer lists, got {typ.python_type()}"
        if isinstance(typ.typ, ListType):
            element_typ = typ.typ.typ
            assert (
                element_typ == IntegerInstanceType
            ), f"Can only create bytes from integer lists but got a list with another type {typ.python_type()}"
        return FunctionType(args, ByteStringInstanceType)

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """
        Return the one-argument conversion function for the given argument type.

        * bytes: identity
        * int: that many zero bytes; a negative count aborts with a ``ValueError`` trace
        * list: the elements are consed onto an empty byte string by a right fold

        Raises NotImplementedError for any other argument type.
        """
        arg = args[0]
        assert isinstance(
            arg, InstanceType
        ), "Can only create bytes from instances, got ClassType"
        source = arg.typ
        if isinstance(source, ByteStringType):
            return OLambda(["x"], OVar("x"))
        if isinstance(source, IntegerType):
            zero_filled = ByteStrIntMulImpl(plt.ByteString(b"\x00"), OVar("x"))
            return OLambda(
                ["x"],
                plt.Ite(
                    plt.LessThanInteger(OVar("x"), plt.Integer(0)),
                    plt.TraceError("ValueError: negative count"),
                    zero_filled,
                ),
            )
        if isinstance(source, ListType):
            prepend = OLambda(["a", "x"], plt.ConsByteString(OVar("x"), OVar("a")))
            return OLambda(
                ["xs"],
                plt.RFoldList(OVar("xs"), prepend, plt.ByteString(b"")),
            )
        raise NotImplementedError(
            f"Can not derive bytes from type {source.python_type()}"
        )
3154

3155

3156
@dataclass(frozen=True, unsafe_hash=True)
class PolymorphicFunctionType(ClassType):
    """A special type of builtin that may act differently on different parameters"""

    # the function object providing the type and implementation per argument list
    polymorphic_function: PolymorphicFunction

    def __ge__(self, other):
        """Accept only another polymorphic function type wrapping an equal function."""
        if not isinstance(other, PolymorphicFunctionType):
            return False
        return self.polymorphic_function == other.polymorphic_function

    def python_type(self):
        """Return a display name that includes the class name of the wrapped function."""
        wrapped_name = self.polymorphic_function.__class__.__name__
        return f"PolymorphicFunctionType({wrapped_name})"
3172

3173

3174
@dataclass(frozen=True, unsafe_hash=True)
class PolymorphicFunctionInstanceType(InstanceType):
    """Instance type pairing a function type with a polymorphic function."""

    # the function type carried by this instance
    typ: FunctionType
    # the polymorphic function associated with it
    polymorphic_function: PolymorphicFunction

    def python_type(self):
        """Return the display name of the carried function type."""
        function_type = self.typ
        return function_type.python_type()
×
3181

3182

3183
# Empty-list expression for each atomic element type.
# empty_list() consults this map first and handles all other element types itself.
EmptyListMap = {
    IntegerInstanceType: plt.EmptyIntegerList(),
    ByteStringInstanceType: plt.EmptyByteStringList(),
    StringInstanceType: plt.EmptyTextList(),
    UnitInstanceType: plt.EmptyUnitList(),
    BoolInstanceType: plt.EmptyBoolList(),
}
3190

3191

3192
def empty_list(p: Type):
    """
    Return the expression for an empty list whose elements have type ``p``.

    Atomic element types are looked up in ``EmptyListMap``. Lists and dicts
    yield an empty list of lists typed by a sample value; records, ``Any``
    and unions yield an empty data list.

    Raises NotImplementedError for element types without a known empty list.
    """
    if p in EmptyListMap:
        return EmptyListMap[p]
    assert isinstance(p, InstanceType), "Can only create lists of instances"
    element = p.typ
    if isinstance(element, ListType):
        # the inner empty list supplies the sample value that types the outer list
        inner_empty = empty_list(element.typ)
        return plt.EmptyListList(uplc.BuiltinList([], inner_empty.sample_value))
    if isinstance(element, DictType):
        # a pair of two empty constructor values serves as the sample value
        sample_pair = uplc.BuiltinPair(
            uplc.PlutusConstr(0, frozenlist([])),
            uplc.PlutusConstr(0, frozenlist([])),
        )
        return plt.EmptyListList(uplc.BuiltinList([], sample_pair))
    if isinstance(element, (RecordType, AnyType, UnionType)):
        return plt.EmptyDataList()
    raise NotImplementedError(f"Empty lists of type {p} can't be constructed yet")
3216

3217

3218
# Per atomic type: builder that converts an externally supplied (data encoded)
# parameter into the internal representation of that type.
TransformExtParamsMap = {
    IntegerInstanceType: lambda x: plt.UnIData(x),
    ByteStringInstanceType: lambda x: plt.UnBData(x),
    StringInstanceType: lambda x: plt.DecodeUtf8(plt.UnBData(x)),
    # The parameter has to be passed to the constant function. Previously the
    # lambda was applied to nothing, so the result was the function itself
    # rather than the unit value. TransformOutputMap already passes ``x``.
    UnitInstanceType: lambda x: plt.Apply(OLambda(["_"], plt.Unit()), x),
    # bools arrive as integers, anything but 0 counts as True
    BoolInstanceType: lambda x: plt.NotEqualsInteger(plt.UnIData(x), plt.Integer(0)),
}
3225

3226

3227
def transform_ext_params_map(p: Type):
    """
    Return the builder converting an external parameter of type ``p`` to its internal form.

    Atomic types come from ``TransformExtParamsMap``. Lists are unwrapped and
    converted element by element, dicts are unwrapped into their pair list,
    and all other types are passed through unchanged.
    """
    assert isinstance(
        p, InstanceType
    ), "Can only transform instances, not classes as input"
    if p in TransformExtParamsMap:
        return TransformExtParamsMap[p]
    inner = p.typ
    if isinstance(inner, ListType):

        def unwrap_list(x):
            # the element converter is resolved only when the builder is called
            return plt.MapList(
                plt.UnListData(x),
                OLambda(["x"], transform_ext_params_map(inner.typ)(OVar("x"))),
                empty_list(inner.typ),
            )

        return unwrap_list
    if isinstance(inner, DictType):

        def unwrap_map(x):
            # there doesn't appear to be a constructor function to make Pair a b for any types
            # so pairs will always contain Data
            return plt.UnMapData(x)

        return unwrap_map

    def passthrough(x):
        return x

    return passthrough
5✔
3245

3246

3247
# Data encoding used for the unit value: constructor 0 with an empty field list.
OUnit = plt.ConstrData(plt.Integer(0), plt.EmptyDataList())

# Per atomic type: builder that converts an internal value of that type into
# its data encoding.
TransformOutputMap = {
    StringInstanceType: lambda x: plt.BData(plt.EncodeUtf8(x)),
    IntegerInstanceType: lambda x: plt.IData(x),
    ByteStringInstanceType: lambda x: plt.BData(x),
    # the value itself is discarded, the result is always OUnit
    UnitInstanceType: lambda x: plt.Apply(OLambda(["_"], OUnit), x),
    # bools are emitted as the integers 1 and 0
    BoolInstanceType: lambda x: plt.IData(
        plt.IfThenElse(x, plt.Integer(1), plt.Integer(0))
    ),
}
3258

3259

3260
def transform_output_map(p: Type):
    """
    Return the builder converting an internal value of type ``p`` into its data encoding.

    Atomic types come from ``TransformOutputMap``. Lists are converted element
    by element and wrapped as list data, dicts are wrapped as map data, and
    all other types are passed through unchanged.

    Raises NotImplementedError for function types, which have no data encoding.
    """
    assert isinstance(
        p, InstanceType
    ), "Can only transform instances, not classes as input"
    # ``p.typ`` is a Type, so polymorphic functions show up here wrapped in a
    # PolymorphicFunctionType. The previous check against the bare
    # PolymorphicFunction class could not match a Type; it is kept in the
    # tuple only for backward compatibility.
    if isinstance(
        p.typ, (FunctionType, PolymorphicFunctionType, PolymorphicFunction)
    ):
        raise NotImplementedError(
            "Can not map functions into PlutusData and hence not return them from a function as Anything"
        )
    if p in TransformOutputMap:
        return TransformOutputMap[p]
    if isinstance(p.typ, ListType):
        list_int_typ = p.typ.typ
        return lambda x: plt.ListData(
            plt.MapList(
                x,
                OLambda(["x"], transform_output_map(list_int_typ)(OVar("x"))),
            ),
        )
    if isinstance(p.typ, DictType):
        # there doesn't appear to be a constructor function to make Pair a b for any types
        # so pairs will always contain Data
        return lambda x: plt.MapData(x)
    return lambda x: x
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc