• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 806

pending completion
806

push

travis-ci-com

nielstron
Update docs

3519 of 3781 relevant lines covered (93.07%)

3.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.1
/opshin/typed_ast.py
1
import typing
4✔
2
from ast import *
4✔
3
from dataclasses import dataclass
4✔
4

5
from frozenlist import FrozenList
4✔
6

7
import pluthon as plt
4✔
8
import uplc.ast as uplc
4✔
9

10

11
def distinct(xs: list):
    """Return True iff no element of ``xs`` occurs more than once.

    Uniqueness is checked through a ``set``, so the elements must be hashable.
    """
    unique_elements = set(xs)
    return len(unique_elements) == len(xs)
14

15

16
def FrozenFrozenList(l: list):
    """Copy the elements of ``l`` into a new ``FrozenList`` and freeze it.

    The frozen list can no longer be mutated. It is used e.g. as the member
    list of a ``UnionType`` built during attribute type inference.
    """
    frozen = FrozenList(l)
    frozen.freeze()
    return frozen
20

21

22
class Type:
    """Base class of all types known to the type inference.

    Every hook here fails by default; concrete types override the ones they
    support.
    """

    def constr_type(self) -> "InstanceType":
        """Return the type of this class's constructor.

        :raises TypeInferenceError: always, in this base class.
        """
        message = f"Object of type {self.__class__} does not have a constructor"
        raise TypeInferenceError(message)

    def constr(self) -> plt.AST:
        """Return the pluthon implementation of this class's constructor.

        :raises NotImplementedError: always, in this base class.
        """
        message = f"Constructor of {self.__class__} not implemented"
        raise NotImplementedError(message)

    def attribute_type(self, attr) -> "Type":
        """Return the type of the attribute named ``attr``.

        :raises TypeInferenceError: always, in this base class.
        """
        message = f"Object of type {self.__class__} does not have attribute {attr}"
        raise TypeInferenceError(message)

    def attribute(self, attr) -> plt.AST:
        """Return the pluthon implementation of attribute ``attr``.

        Implementations must return a lambda whose first argument is the
        object the attribute is read from.

        :raises NotImplementedError: always, in this base class.
        """
        message = f"Attribute {attr} not implemented for type {self}"
        raise NotImplementedError(message)

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """Return the pluthon implementation of ``self <op> o``.

        Implementations must return a lambda taking this object as first
        argument and the object it is compared against as second argument.

        :raises NotImplementedError: always, in this base class.
        """
        op_name = type(op).__name__
        own_name = self.__class__.__name__
        other_name = o.__class__.__name__
        raise NotImplementedError(
            f"Comparison {op_name} for {own_name} and {other_name} is not implemented. This is likely intended because it would always evaluate to False."
        )
48

49

50
@dataclass(frozen=True, unsafe_hash=True)
class Record:
    """Definition of a record: its name, constructor index and ordered fields."""

    # name of the record (used in type inference error messages)
    name: str
    # constructor index; RecordType.constr emits it as the ConstrData tag
    constructor: int
    # ordered (field name, field type) pairs. Hashing a Record hashes this
    # value too, so a plain list makes hash() fail; a frozen FrozenList is
    # needed wherever the record must be hashable.
    fields: typing.Union[typing.List[typing.Tuple[str, Type]], FrozenList]
55

56

57
@dataclass(frozen=True, unsafe_hash=True)
class ClassType(Type):
    """Base for types describing a class (InstanceType wraps one of these)."""

    def __ge__(self, other):
        """Subclasses define ``self >= other`` as "other may stand in for self".

        :raises NotImplementedError: raw class types have no such order.
        """
        reason = "Comparison between raw classtypes impossible"
        raise NotImplementedError(reason)
61

62

63
@dataclass(frozen=True, unsafe_hash=True)
class AnyType(ClassType):
    """Top of the partial order on class types.

    Function types (plain and polymorphic) are excluded: nothing, not even
    Any, compares as greater-or-equal to them.
    """

    def __ge__(self, other):
        if not isinstance(other, ClassType):
            return False
        if isinstance(other, FunctionType):
            return False
        return not isinstance(other, PolymorphicFunctionType)
73

74

75
@dataclass(frozen=True, unsafe_hash=True)
class AtomicType(ClassType):
    """Base for primitive types whose order follows the Python class hierarchy."""

    def __ge__(self, other):
        # ``other`` may replace ``self`` only when it is an instance of
        # self's own class (subclasses included).
        own_class = type(self)
        return isinstance(other, own_class)
80

81

82
@dataclass(frozen=True, unsafe_hash=True)
class RecordType(ClassType):
    """Class type of a record; its values are PlutusData built with ConstrData."""

    # the record definition (name, constructor index, ordered fields)
    record: Record

    def constr_type(self) -> "InstanceType":
        """Constructor type: a function from the field types, in order, to an instance of this record."""
        return InstanceType(
            FunctionType([f[1] for f in self.record.fields], InstanceType(self))
        )

    def constr(self) -> plt.AST:
        """Build the constructor as a pluthon lambda.

        The lambda takes one parameter per field (named after the field) plus
        a trailing ``_`` parameter, and returns
        ``ConstrData(constructor, [each field wrapped to PlutusData])``.
        """
        # wrap all constructor values to PlutusData
        build_constr_params = plt.EmptyDataList()
        # cons from the last field to the first so the list ends up in field order
        for n, t in reversed(self.record.fields):
            build_constr_params = plt.MkCons(
                transform_output_map(t)(plt.Var(n)), build_constr_params
            )
        # then build a constr type with this PlutusData
        return plt.Lambda(
            [n for n, _ in self.record.fields] + ["_"],
            plt.ConstrData(plt.Integer(self.record.constructor), build_constr_params),
        )

    def attribute_type(self, attr: str) -> Type:
        """The types of the named attributes of this class.

        ``CONSTR_ID`` is an integer; any other name must be a record field.

        :raises TypeInferenceError: if the record has no field named ``attr``.
        """
        if attr == "CONSTR_ID":
            return IntegerInstanceType
        for n, t in self.record.fields:
            if n == attr:
                return t
        raise TypeInferenceError(
            f"Type {self.record.name} does not have attribute {attr}"
        )

    def attribute(self, attr: str) -> plt.AST:
        """The attributes of this class. Need to be a lambda that expects as first argument the object itself"""
        if attr == "CONSTR_ID":
            # access to constructor
            return plt.Lambda(
                ["self"],
                plt.Constructor(plt.Var("self")),
            )
        # raises TypeInferenceError for unknown fields, so `next` below always finds a match
        attr_typ = self.attribute_type(attr)
        pos = next(i for i, (n, _) in enumerate(self.record.fields) if n == attr)
        # access to normal fields: read field `pos` and convert it to the field's type
        return plt.Lambda(
            ["self"],
            transform_ext_params_map(attr_typ)(
                plt.NthField(
                    plt.Var("self"),
                    plt.Integer(pos),
                ),
            ),
        )

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """The implementation of comparing this type to type o via operator op. Returns a lambda that expects as first argument the object itself and as second the comparison."""
        # this will reject comparisons that will always be false - most likely due to faults during programming
        if (isinstance(o, RecordType) and o.record == self.record) or (
            isinstance(o, UnionType) and self in o.typs
        ):
            # records are compared structurally as PlutusData
            if isinstance(op, Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsData)
            if isinstance(op, NotEq):
                return plt.Lambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                            plt.Var("x"),
                            plt.Var("y"),
                        )
                    ),
                )
        # membership test `x in y` for a list whose element type accepts this record
        if (
            isinstance(o, ListType)
            and isinstance(o.typ, InstanceType)
            and o.typ.typ >= self
        ):
            if isinstance(op, In):
                return plt.Lambda(
                    ["x", "y"],
                    plt.EqualsData(
                        plt.Var("x"),
                        plt.FindList(
                            plt.Var("y"),
                            plt.Apply(
                                plt.BuiltIn(uplc.BuiltInFun.EqualsData), plt.Var("x")
                            ),
                            # this simply ensures the default is always unequal to the searched value
                            plt.ConstrData(
                                plt.AddInteger(
                                    plt.Constructor(plt.Var("x")), plt.Integer(1)
                                ),
                                plt.MkNilData(plt.Unit()),
                            ),
                        ),
                    ),
                )
        return super().cmp(op, o)

    def __ge__(self, other):
        # Can only substitute for its own type, records need to be equal
        # if someone wants to be funny, they can implement <= to be true if all fields match up to some point
        return isinstance(other, self.__class__) and other.record == self.record
186

187

188
@dataclass(frozen=True, unsafe_hash=True)
class UnionType(ClassType):
    """Class type of a union of record types (``Union[A, B, ...]``).

    The concrete member of a value can be read at runtime through its
    ``CONSTR_ID`` attribute.
    """

    typs: typing.List[RecordType]

    def _fields_by_position(self):
        """Yield ``(names, types)`` for each field position shared by all members.

        ``names`` and ``types`` are tuples with one entry per member record,
        holding that member's field name and field type at the position.
        Positions past the shortest member's field list are not yielded,
        because ``zip`` truncates.
        """
        for fields_at_position in zip(*(t.record.fields for t in self.typs)):
            names, types = zip(*fields_at_position)
            yield names, types

    def attribute_type(self, attr) -> "Type":
        """Return the type of attribute ``attr`` common to all union members.

        ``CONSTR_ID`` is always an integer. Any other attribute must be a field
        with that name at the same position in every member record. Its type is
        chosen as follows:
        - the maximal member field type, if there is one;
        - else a union of the field types, if they are all records with
          distinct constructor ids;
        - else ``Any``.

        :raises TypeInferenceError: if no such common field exists.
        """
        if attr == "CONSTR_ID":
            return IntegerInstanceType
        for attr_names, attr_types in self._fields_by_position():
            # need to have a common field with the same name, in the same position!
            if any(attr_name != attr for attr_name in attr_names):
                continue
            for at in attr_types:
                # return the maximum element if there is one
                if all(at >= at2 for at2 in attr_types):
                    return at
            # return the union type of all possible instantiations if all possible values are record types
            if all(
                isinstance(at, InstanceType) and isinstance(at.typ, RecordType)
                for at in attr_types
            ) and distinct([at.typ.record.constructor for at in attr_types]):
                return InstanceType(
                    UnionType(FrozenFrozenList([at.typ for at in attr_types]))
                )
            # return Anytype
            return InstanceType(AnyType())
        raise TypeInferenceError(
            f"Can not access attribute {attr} of Union type. Cast to desired type with an 'if isinstance(_, _):' branch."
        )

    def attribute(self, attr: str) -> plt.AST:
        """Return a lambda reading attribute ``attr`` from a union value (passed first)."""
        if attr == "CONSTR_ID":
            # access to constructor
            return plt.Lambda(
                ["self"],
                plt.Constructor(plt.Var("self")),
            )
        # raises TypeInferenceError if there is no common field, so `next` below finds one
        attr_typ = self.attribute_type(attr)
        # the common field sits at the same position in every member record
        pos = next(
            i
            for i, (names, _) in enumerate(self._fields_by_position())
            if all(n == attr for n in names)
        )
        # access to normal fields
        return plt.Lambda(
            ["self"],
            transform_ext_params_map(attr_typ)(
                plt.NthField(
                    plt.Var("self"),
                    plt.Integer(pos),
                ),
            ),
        )

    def __ge__(self, other):
        """Return True iff every value of type ``other`` may stand in for this union."""
        if isinstance(other, UnionType):
            # every member of the other union must be covered by some member of
            # this one (consistent with the single-type case below)
            return all(any(t >= ot for t in self.typs) for ot in other.typs)
        return any(t >= other for t in self.typs)

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """The implementation of comparing this type to type o via operator op. Returns a lambda that expects as first argument the object itself and as second the comparison."""
        # this will reject comparisons that will always be false - most likely due to faults during programming
        # note we require that there is an overlap between the possible types for unions
        if (isinstance(o, RecordType) and o in self.typs) or (
            isinstance(o, UnionType) and set(self.typs).intersection(o.typs)
        ):
            if isinstance(op, Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsData)
            if isinstance(op, NotEq):
                return plt.Lambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsData),
                            plt.Var("x"),
                            plt.Var("y"),
                        )
                    ),
                )
        raise NotImplementedError(
            f"Can not compare {o} and {self} with operation {op.__class__}. Note that comparisons that always return false are also rejected."
        )
275

276

277
@dataclass(frozen=True, unsafe_hash=True)
class TupleType(ClassType):
    """Class type of a fixed-length tuple with one type per position."""

    typs: typing.List[Type]

    def __ge__(self, other):
        """Return True iff ``other`` is a tuple type of the same length whose
        element types may each stand in for the corresponding element type here."""
        # zip() silently truncates to the shorter input, so compare lengths
        # explicitly; otherwise e.g. a 1-tuple would be accepted for a 2-tuple.
        return (
            isinstance(other, TupleType)
            and len(self.typs) == len(other.typs)
            and all(t >= ot for t, ot in zip(self.typs, other.typs))
        )
285

286

287
@dataclass(frozen=True, unsafe_hash=True)
class PairType(ClassType):
    """Internal type of a built-in PlutusData pair (e.g. a dict entry)."""

    # type of the first component
    l_typ: Type
    # type of the second component
    r_typ: Type

    def __ge__(self, other):
        if not isinstance(other, PairType):
            return False
        # both components must be substitutable, position by position
        return bool(self.l_typ >= other.l_typ and self.r_typ >= other.r_typ)
299

300

301
@dataclass(frozen=True, unsafe_hash=True)
class ListType(ClassType):
    """Class type of a homogeneous list with element type ``typ``."""

    typ: Type

    def __ge__(self, other):
        if not isinstance(other, ListType):
            return False
        # element types are compared in the same direction as the lists
        return self.typ >= other.typ
307

308

309
@dataclass(frozen=True, unsafe_hash=True)
class DictType(ClassType):
    """Class type of a dict, represented as a list of (key, value) data pairs."""

    # type of the keys
    key_typ: Type
    # type of the values
    value_typ: Type

    def attribute_type(self, attr) -> "Type":
        """Return the type of the dict method ``attr`` (get, keys, values or items).

        :raises TypeInferenceError: for any other attribute name.
        """
        if attr == "get":
            # get(key, default) -> value
            return InstanceType(
                FunctionType([self.key_typ, self.value_typ], self.value_typ)
            )
        if attr == "keys":
            # keys() -> list of keys
            return InstanceType(FunctionType([], InstanceType(ListType(self.key_typ))))
        if attr == "values":
            # values() -> list of values
            return InstanceType(
                FunctionType([], InstanceType(ListType(self.value_typ)))
            )
        if attr == "items":
            # items() -> list of (key, value) pairs
            return InstanceType(
                FunctionType(
                    [],
                    InstanceType(
                        ListType(InstanceType(PairType(self.key_typ, self.value_typ)))
                    ),
                )
            )
        raise TypeInferenceError(
            f"Type of attribute '{attr}' is unknown for type Dict."
        )

    def attribute(self, attr) -> plt.AST:
        """Return the pluthon implementation of the dict method ``attr``.

        Each implementation is a lambda whose first parameter is the dict
        itself and whose last parameter is the placeholder ``_``.

        :raises NotImplementedError: for an unknown attribute name.
        """
        if attr == "get":
            # Find the first pair whose key equals `key` (compared as data).
            # If none matches, FindList returns the synthetic (key, default)
            # pair, so SndPair yields `default`.
            return plt.Lambda(
                ["self", "key", "default", "_"],
                transform_ext_params_map(self.value_typ)(
                    plt.SndPair(
                        plt.FindList(
                            plt.Var("self"),
                            plt.Lambda(
                                ["x"],
                                plt.EqualsData(
                                    transform_output_map(self.key_typ)(plt.Var("key")),
                                    plt.FstPair(plt.Var("x")),
                                ),
                            ),
                            # this is a bit ugly... we wrap - only to later unwrap again
                            plt.MkPairData(
                                transform_output_map(self.key_typ)(plt.Var("key")),
                                transform_output_map(self.value_typ)(
                                    plt.Var("default")
                                ),
                            ),
                        ),
                    ),
                ),
            )
        if attr == "keys":
            # map every pair to its (converted) first component
            return plt.Lambda(
                ["self", "_"],
                plt.MapList(
                    plt.Var("self"),
                    plt.Lambda(
                        ["x"],
                        transform_ext_params_map(self.key_typ)(
                            plt.FstPair(plt.Var("x"))
                        ),
                    ),
                    empty_list(self.key_typ),
                ),
            )
        if attr == "values":
            # map every pair to its (converted) second component
            return plt.Lambda(
                ["self", "_"],
                plt.MapList(
                    plt.Var("self"),
                    plt.Lambda(
                        ["x"],
                        transform_ext_params_map(self.value_typ)(
                            plt.SndPair(plt.Var("x"))
                        ),
                    ),
                    empty_list(self.value_typ),
                ),
            )
        if attr == "items":
            # the dict already is the list of pairs, so return it unchanged
            return plt.Lambda(
                ["self", "_"],
                plt.Var("self"),
            )
        raise NotImplementedError(f"Attribute '{attr}' of Dict is unknown.")

    def __ge__(self, other):
        # keys and values must both be substitutable
        return (
            isinstance(other, DictType)
            and self.key_typ >= other.key_typ
            and self.value_typ >= other.value_typ
        )
405

406

407
@dataclass(frozen=True, unsafe_hash=True)
class FunctionType(ClassType):
    """Class type of a function with positional argument types and a return type."""

    argtyps: typing.List[Type]
    rettyp: Type

    def __ge__(self, other):
        """Return True iff ``other`` is a function type of the same arity that
        compares as substitutable argument-by-argument and on the return type."""
        # zip() truncates to the shorter argument list, so the arity must be
        # compared explicitly; otherwise functions of different arity could
        # compare as >=.
        # NOTE(review): arguments are compared as ``a >= oa`` and the return as
        # ``other.rettyp >= self.rettyp``. Textbook function subtyping uses the
        # opposite directions (contravariant arguments, covariant return).
        # Kept unchanged here — confirm this is intended.
        return (
            isinstance(other, FunctionType)
            and len(self.argtyps) == len(other.argtyps)
            and all(a >= oa for a, oa in zip(self.argtyps, other.argtyps))
            and other.rettyp >= self.rettyp
        )
418

419

420
@dataclass(frozen=True, unsafe_hash=True)
class InstanceType(Type):
    """Type of a value that is an instance of the wrapped class type ``typ``.

    Attribute access and comparison are delegated to the wrapped class type.
    Instances themselves cannot be constructed.
    """

    typ: ClassType

    def constr_type(self) -> FunctionType:
        raise TypeInferenceError(f"Can not construct an instance {self}")

    def constr(self) -> plt.AST:
        raise NotImplementedError(f"Can not construct an instance {self}")

    def attribute_type(self, attr) -> Type:
        wrapped = self.typ
        return wrapped.attribute_type(attr)

    def attribute(self, attr) -> plt.AST:
        wrapped = self.typ
        return wrapped.attribute(attr)

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """Compare two instances via their class types; anything else uses the base behaviour."""
        if not isinstance(o, InstanceType):
            return super().cmp(op, o)
        return self.typ.cmp(op, o.typ)

    def __ge__(self, other):
        if not isinstance(other, InstanceType):
            return False
        return self.typ >= other.typ
444

445

446
@dataclass(frozen=True, unsafe_hash=True)
class IntegerType(AtomicType):
    """Class type of integers."""

    def constr_type(self) -> InstanceType:
        """``int(x)`` is typed as a function from a string to an integer."""
        return InstanceType(FunctionType([StringInstanceType], InstanceType(self)))

    def constr(self) -> plt.AST:
        """Build ``int(x)`` for a string ``x`` as a base-10 parser.

        Accepts an optional leading ``-`` followed by decimal digits; ``_``
        bytes after the first digit are skipped. Fails with
        ``ValueError: invalid literal for int() with base 10`` on:
        - a string starting with ``_``;
        - a bare ``-``;
        - ``-`` followed by ``_``;
        - any other non-digit byte.
        """
        # TODO we need to strip the string implicitely before parsing it
        # NOTE(review): unlike Python's int(), a leading '+' is rejected and
        # trailing/repeated underscores ("1_", "1__2") are accepted.
        invalid_literal = "ValueError: invalid literal for int() with base 10"
        return plt.Lambda(
            ["x", "_"],
            plt.Let(
                [
                    ("e", plt.EncodeUtf8(plt.Var("x"))),
                    ("first_int", plt.IndexByteString(plt.Var("e"), plt.Integer(0))),
                    ("len", plt.LengthOfByteString(plt.Var("e"))),
                    (
                        "fold_start",
                        # fold_start(start) accumulates the decimal value of e's
                        # bytes from index `start` on (presumably Range(len, start)
                        # yields start..len-1 — confirm), skipping '_' and failing
                        # on any other non-digit byte
                        plt.Lambda(
                            ["start"],
                            plt.FoldList(
                                plt.Range(plt.Var("len"), plt.Var("start")),
                                plt.Lambda(
                                    ["s", "i"],
                                    plt.Let(
                                        [
                                            (
                                                "b",
                                                plt.IndexByteString(
                                                    plt.Var("e"), plt.Var("i")
                                                ),
                                            )
                                        ],
                                        plt.Ite(
                                            plt.EqualsInteger(
                                                plt.Var("b"), plt.Integer(ord("_"))
                                            ),
                                            plt.Var("s"),
                                            plt.Ite(
                                                plt.Or(
                                                    plt.LessThanInteger(
                                                        plt.Var("b"),
                                                        plt.Integer(ord("0")),
                                                    ),
                                                    plt.LessThanInteger(
                                                        plt.Integer(ord("9")),
                                                        plt.Var("b"),
                                                    ),
                                                ),
                                                plt.TraceError(invalid_literal),
                                                plt.AddInteger(
                                                    plt.SubtractInteger(
                                                        plt.Var("b"),
                                                        plt.Integer(ord("0")),
                                                    ),
                                                    plt.MultiplyInteger(
                                                        plt.Var("s"), plt.Integer(10)
                                                    ),
                                                ),
                                            ),
                                        ),
                                    ),
                                ),
                                plt.Integer(0),
                            ),
                        ),
                    ),
                ],
                plt.Ite(
                    plt.Or(
                        plt.EqualsInteger(plt.Var("len"), plt.Integer(0)),
                        plt.EqualsInteger(
                            plt.Var("first_int"),
                            plt.Integer(ord("_")),
                        ),
                    ),
                    plt.TraceError(invalid_literal),
                    plt.Ite(
                        plt.EqualsInteger(
                            plt.Var("first_int"),
                            plt.Integer(ord("-")),
                        ),
                        # A sign must be followed by at least one byte, and that
                        # byte may not be '_'. Previously "-" parsed as 0 and
                        # "-_1" as -1; Python's int() rejects both.
                        # The Ite branches are only evaluated when taken (the
                        # TraceError branches above rely on this too), so byte 1
                        # is only indexed once len > 1 is known.
                        plt.Ite(
                            plt.EqualsInteger(plt.Var("len"), plt.Integer(1)),
                            plt.TraceError(invalid_literal),
                            plt.Ite(
                                plt.EqualsInteger(
                                    plt.IndexByteString(plt.Var("e"), plt.Integer(1)),
                                    plt.Integer(ord("_")),
                                ),
                                plt.TraceError(invalid_literal),
                                plt.Negate(
                                    plt.Apply(plt.Var("fold_start"), plt.Integer(1)),
                                ),
                            ),
                        ),
                        plt.Apply(plt.Var("fold_start"), plt.Integer(0)),
                    ),
                ),
            ),
        )

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """The implementation of comparing this type to type o via operator op. Returns a lambda that expects as first argument the object itself and as second the comparison."""
        if isinstance(o, BoolType):
            if isinstance(op, Eq):
                # 1 == True
                # 0 == False
                # all other comparisons are False
                return plt.Lambda(
                    ["x", "y"],
                    plt.Ite(
                        plt.Var("y"),
                        plt.EqualsInteger(plt.Var("x"), plt.Integer(1)),
                        plt.EqualsInteger(plt.Var("x"), plt.Integer(0)),
                    ),
                )
        if isinstance(o, IntegerType):
            if isinstance(op, Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsInteger)
            if isinstance(op, NotEq):
                return plt.Lambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsInteger),
                            plt.Var("y"),
                            plt.Var("x"),
                        )
                    ),
                )
            if isinstance(op, LtE):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanEqualsInteger)
            if isinstance(op, Lt):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanInteger)
            if isinstance(op, Gt):
                # x > y  <=>  y < x
                return plt.Lambda(
                    ["x", "y"],
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.LessThanInteger),
                        plt.Var("y"),
                        plt.Var("x"),
                    ),
                )
            if isinstance(op, GtE):
                # x >= y  <=>  y <= x
                return plt.Lambda(
                    ["x", "y"],
                    plt.Apply(
                        plt.BuiltIn(uplc.BuiltInFun.LessThanEqualsInteger),
                        plt.Var("y"),
                        plt.Var("x"),
                    ),
                )
        if (
            isinstance(o, ListType)
            and isinstance(o.typ, InstanceType)
            and isinstance(o.typ.typ, IntegerType)
        ):
            if isinstance(op, In):
                return plt.Lambda(
                    ["x", "y"],
                    plt.EqualsInteger(
                        plt.Var("x"),
                        plt.FindList(
                            plt.Var("y"),
                            plt.Apply(
                                plt.BuiltIn(uplc.BuiltInFun.EqualsInteger), plt.Var("x")
                            ),
                            # this simply ensures the default is always unequal to the searched value
                            plt.AddInteger(plt.Var("x"), plt.Integer(1)),
                        ),
                    ),
                )
        return super().cmp(op, o)
611

612

613
@dataclass(frozen=True, unsafe_hash=True)
class StringType(AtomicType):
    """Class type of (UTF-8) strings."""

    def constr_type(self) -> InstanceType:
        """``str(x)`` is typed as a function from an integer to a string."""
        return InstanceType(FunctionType([IntegerInstanceType], InstanceType(self)))

    def constr(self) -> plt.AST:
        """Build ``str(x)`` for an integer ``x``: its base-10 representation."""
        # constructs a string representation of an integer
        return plt.Lambda(
            ["x", "_"],
            plt.DecodeUtf8(
                plt.Let(
                    [
                        (
                            "strlist",
                            # digit bytes of a positive i, least significant first
                            plt.RecFun(
                                plt.Lambda(
                                    ["f", "i"],
                                    plt.Ite(
                                        plt.LessThanEqualsInteger(
                                            plt.Var("i"), plt.Integer(0)
                                        ),
                                        plt.EmptyIntegerList(),
                                        plt.MkCons(
                                            plt.AddInteger(
                                                plt.ModInteger(
                                                    plt.Var("i"), plt.Integer(10)
                                                ),
                                                plt.Integer(ord("0")),
                                            ),
                                            plt.Apply(
                                                plt.Var("f"),
                                                plt.Var("f"),
                                                plt.DivideInteger(
                                                    plt.Var("i"), plt.Integer(10)
                                                ),
                                            ),
                                        ),
                                    ),
                                ),
                            ),
                        ),
                        (
                            "mkstr",
                            # prepend each digit while folding left, which restores
                            # most-significant-first order
                            plt.Lambda(
                                ["i"],
                                plt.FoldList(
                                    plt.Apply(plt.Var("strlist"), plt.Var("i")),
                                    plt.Lambda(
                                        ["b", "i"],
                                        plt.ConsByteString(plt.Var("i"), plt.Var("b")),
                                    ),
                                    plt.ByteString(b""),
                                ),
                            ),
                        ),
                    ],
                    plt.Ite(
                        # strlist yields nothing for 0, so zero is special-cased
                        plt.EqualsInteger(plt.Var("x"), plt.Integer(0)),
                        plt.ByteString(b"0"),
                        plt.Ite(
                            plt.LessThanInteger(plt.Var("x"), plt.Integer(0)),
                            plt.ConsByteString(
                                plt.Integer(ord("-")),
                                plt.Apply(plt.Var("mkstr"), plt.Negate(plt.Var("x"))),
                            ),
                            plt.Apply(plt.Var("mkstr"), plt.Var("x")),
                        ),
                    ),
                )
            ),
        )

    def attribute_type(self, attr) -> Type:
        """``encode()`` returns bytes; other attributes use the base behaviour."""
        if attr == "encode":
            return InstanceType(FunctionType([], ByteStringInstanceType))
        return super().attribute_type(attr)

    def attribute(self, attr) -> plt.AST:
        if attr == "encode":
            # No codec -> only the default (utf8) is allowed
            return plt.Lambda(["x", "_"], plt.EncodeUtf8(plt.Var("x")))
        return super().attribute(attr)

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """Support ``==`` and ``!=`` between strings; anything else uses the base behaviour."""
        if isinstance(o, StringType):
            if isinstance(op, Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsString)
            if isinstance(op, NotEq):
                # mirrors the NotEq implementations of IntegerType/ByteStringType
                return plt.Lambda(
                    ["x", "y"],
                    plt.Not(
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsString),
                            plt.Var("y"),
                            plt.Var("x"),
                        )
                    ),
                )
        return super().cmp(op, o)
701

702

703
@dataclass(frozen=True, unsafe_hash=True)
class ByteStringType(AtomicType):
    """The builtin ``bytes`` type, backed by UPLC bytestrings."""

    def constr_type(self) -> InstanceType:
        """``bytes(...)`` takes a list of integers and yields a bytes instance."""
        int_list_typ = InstanceType(ListType(IntegerInstanceType))
        return InstanceType(FunctionType([int_list_typ], InstanceType(self)))

    def constr(self) -> plt.AST:
        """Build the bytestring by right-folding the integer list, prepending each value."""
        prepend_byte = plt.Lambda(
            ["a", "x"], plt.ConsByteString(plt.Var("x"), plt.Var("a"))
        )
        return plt.Lambda(
            ["xs", "_"],
            plt.RFoldList(plt.Var("xs"), prepend_byte, plt.ByteString(b"")),
        )

    def attribute_type(self, attr) -> Type:
        """``decode`` is a zero-argument method returning a string; others are delegated."""
        if attr != "decode":
            return super().attribute_type(attr)
        return InstanceType(FunctionType([], StringInstanceType))

    def attribute(self, attr) -> plt.AST:
        """Implementation of attribute ``attr``; only ``decode`` is handled here."""
        if attr != "decode":
            return super().attribute(attr)
        # decode() accepts no codec argument, so UTF-8 is the only supported encoding
        return plt.Lambda(["x", "_"], plt.DecodeUtf8(plt.Var("x")))

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """Comparison of these bytes against a value of type ``o`` via ``op``.

        Bytes support ==, !=, <, <=, > and >= against bytes. ``in`` is supported
        against a list of bytes. Everything else is delegated to the parent class.
        """

        def swapped(fun: uplc.BuiltInFun) -> plt.AST:
            # apply `fun` with its operands exchanged: (x, y) -> fun(y, x)
            return plt.Apply(plt.BuiltIn(fun), plt.Var("y"), plt.Var("x"))

        if isinstance(o, ByteStringType):
            if isinstance(op, Eq):
                return plt.BuiltIn(uplc.BuiltInFun.EqualsByteString)
            if isinstance(op, NotEq):
                return plt.Lambda(
                    ["x", "y"], plt.Not(swapped(uplc.BuiltInFun.EqualsByteString))
                )
            if isinstance(op, Lt):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanByteString)
            if isinstance(op, LtE):
                return plt.BuiltIn(uplc.BuiltInFun.LessThanEqualsByteString)
            if isinstance(op, Gt):
                # x > y  <=>  y < x
                return plt.Lambda(
                    ["x", "y"], swapped(uplc.BuiltInFun.LessThanByteString)
                )
            if isinstance(op, GtE):
                # x >= y  <=>  y <= x
                return plt.Lambda(
                    ["x", "y"], swapped(uplc.BuiltInFun.LessThanEqualsByteString)
                )

        is_bytes_list = (
            isinstance(o, ListType)
            and isinstance(o.typ, InstanceType)
            and isinstance(o.typ.typ, ByteStringType)
        )
        if is_bytes_list and isinstance(op, In):
            # Look up the first element equal to x. The fallback (x with a zero
            # byte prepended) can never equal x, so the outer equality is False
            # exactly when x is absent from the list.
            return plt.Lambda(
                ["x", "y"],
                plt.EqualsByteString(
                    plt.Var("x"),
                    plt.FindList(
                        plt.Var("y"),
                        plt.Apply(
                            plt.BuiltIn(uplc.BuiltInFun.EqualsByteString),
                            plt.Var("x"),
                        ),
                        plt.ConsByteString(plt.Integer(0), plt.Var("x")),
                    ),
                ),
            )
        return super().cmp(op, o)
792

793

794
@dataclass(frozen=True, unsafe_hash=True)
class BoolType(AtomicType):
    """The builtin ``bool`` type."""

    def constr_type(self) -> "InstanceType":
        """``bool(i)`` converts an integer to a boolean."""
        return InstanceType(FunctionType([IntegerInstanceType], BoolInstanceType))

    def constr(self) -> plt.AST:
        # constructs a boolean from an integer: every non-zero value is True
        return plt.Lambda(
            ["x", "_"], plt.Not(plt.EqualsInteger(plt.Var("x"), plt.Integer(0)))
        )

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """Comparison of this bool against a value of type ``o`` via ``op``.

        Supports ``==`` and ``!=`` against integers and against bools. Every
        other combination is delegated to the parent class.
        """
        if isinstance(o, IntegerType):

            def bool_equals_int() -> plt.AST:
                # 1 == True
                # 0 == False
                # all other comparisons are False
                return plt.Ite(
                    plt.Var("y"),
                    plt.EqualsInteger(plt.Var("x"), plt.Integer(1)),
                    plt.EqualsInteger(plt.Var("x"), plt.Integer(0)),
                )

            # the bool operand is bound to "y", the integer operand to "x"
            if isinstance(op, Eq):
                return plt.Lambda(["y", "x"], bool_equals_int())
            if isinstance(op, NotEq):
                return plt.Lambda(["y", "x"], plt.Not(bool_equals_int()))
        if isinstance(o, BoolType):
            if isinstance(op, Eq):
                return plt.Lambda(["x", "y"], plt.Iff(plt.Var("x"), plt.Var("y")))
            if isinstance(op, NotEq):
                return plt.Lambda(
                    ["x", "y"], plt.Not(plt.Iff(plt.Var("x"), plt.Var("y")))
                )
        return super().cmp(op, o)
823

824

825
@dataclass(frozen=True, unsafe_hash=True)
class UnitType(AtomicType):
    """The unit type, i.e. the type of ``None``."""

    def cmp(self, op: cmpop, o: "Type") -> plt.AST:
        """Compare two unit values.

        A unit value can only ever equal another unit value, so ``==`` is
        constantly True and ``!=`` constantly False. Every other combination is
        delegated to the parent class.
        """
        if isinstance(o, UnitType):
            for op_cls, outcome in ((Eq, True), (NotEq, False)):
                if isinstance(op, op_cls):
                    return plt.Lambda(["x", "y"], plt.Bool(outcome))
        return super().cmp(op, o)
834

835

836
# Canonical instance types of the builtin atomic types.
IntegerInstanceType = InstanceType(IntegerType())
StringInstanceType = InstanceType(StringType())
ByteStringInstanceType = InstanceType(ByteStringType())
BoolInstanceType = InstanceType(BoolType())
UnitInstanceType = InstanceType(UnitType())

# Python-level type name -> atomic (class) type.
ATOMIC_TYPES = dict(
    int=IntegerType(),
    str=StringType(),
    bytes=ByteStringType(),
    Unit=UnitType(),
    bool=BoolType(),
)

# ``None`` is represented by the unit type.
NoneInstanceType = UnitInstanceType
852

853

854
class InaccessibleType(ClassType):
    """Marker type that prevents a function from being overwritten."""
858

859

860
class PolymorphicFunction:
    """Interface for builtins whose signature and implementation depend on the argument types.

    Subclasses must override both methods.
    """

    def type_from_args(self, args: typing.List[Type]) -> FunctionType:
        """Return the function type of a call with argument types ``args``."""
        raise NotImplementedError

    def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
        """Return the implementation specialised to argument types ``args``."""
        raise NotImplementedError


@dataclass(frozen=True, unsafe_hash=True)
class PolymorphicFunctionType(ClassType):
    """A special type of builtin that may act differently on different parameters"""

    # decides the concrete type and implementation for each call
    polymorphic_function: PolymorphicFunction


@dataclass(frozen=True, unsafe_hash=True)
class PolymorphicFunctionInstanceType(InstanceType):
    """Instance type pairing a concrete function type with its polymorphic origin."""

    # the concrete (monomorphic) function type
    typ: FunctionType
    # the polymorphic builtin this instance was derived from
    polymorphic_function: PolymorphicFunction
879

880

881
class TypedAST(AST):
    """Base class of AST nodes that carry a type annotation ``typ``."""

    typ: Type


class typedexpr(TypedAST, expr):
    """Typed counterpart of ``ast.expr``."""


class typedstmt(TypedAST, stmt):
    """Typed counterpart of ``ast.stmt``."""

    # Statements always have type None
    typ = NoneInstanceType


class typedarg(TypedAST, arg):
    """Typed counterpart of ``ast.arg``."""


class typedarguments(TypedAST, arguments):
    """Typed counterpart of ``ast.arguments``; every argument is a ``typedarg``."""

    args: typing.List[typedarg]
    vararg: typing.Optional[typedarg]
    kwonlyargs: typing.List[typedarg]
    kw_defaults: typing.List[typing.Optional[typedexpr]]
    kwarg: typing.Optional[typedarg]
    defaults: typing.List[typedexpr]
905

906

907
class TypedModule(typedstmt, Module):
    """Typed module: a sequence of typed statements."""

    body: typing.List[typedstmt]


class TypedFunctionDef(typedstmt, FunctionDef):
    """Typed function definition."""

    body: typing.List[typedstmt]
    args: arguments


class TypedIf(typedstmt, If):
    """Typed ``if`` statement."""

    test: typedexpr
    body: typing.List[typedstmt]
    orelse: typing.List[typedstmt]


class TypedReturn(typedstmt, Return):
    """Typed ``return`` statement."""

    value: typedexpr


class TypedExpression(typedexpr, Expression):
    """Typed ``Expression`` wrapper node."""

    body: typedexpr


class TypedCall(typedexpr, Call):
    """Typed function call."""

    func: typedexpr
    args: typing.List[typedexpr]


class TypedExpr(typedstmt, Expr):
    """Typed expression statement."""

    value: typedexpr


class TypedAssign(typedstmt, Assign):
    """Typed assignment."""

    targets: typing.List[typedexpr]
    value: typedexpr


class TypedClassDef(typedstmt, ClassDef):
    """Typed class definition; ``class_typ`` is the type the class defines."""

    class_typ: Type


class TypedAnnAssign(typedstmt, AnnAssign):
    """Typed annotated assignment."""

    target: typedexpr
    annotation: Type
    value: typedexpr


class TypedWhile(typedstmt, While):
    """Typed ``while`` loop."""

    test: typedexpr
    body: typing.List[typedstmt]
    orelse: typing.List[typedstmt]


class TypedFor(typedstmt, For):
    """Typed ``for`` loop."""

    target: typedexpr
    iter: typedexpr
    body: typing.List[typedstmt]
    orelse: typing.List[typedstmt]


class TypedPass(typedstmt, Pass):
    """Typed ``pass`` statement."""
969

970

971
class TypedName(typedexpr, Name):
    """Typed name reference."""


class TypedConstant(TypedAST, Constant):
    """Typed constant.

    NOTE(review): unlike the other expression nodes this derives from TypedAST
    rather than typedexpr; confirm whether that is intentional.
    """


class TypedTuple(typedexpr, Tuple):
    """Typed tuple display."""


class TypedList(typedexpr, List):
    """Typed list display."""


class typedcomprehension(typedexpr, comprehension):
    """Typed ``for ... in ... if ...`` clause of a comprehension."""

    target: typedexpr
    iter: typedexpr
    ifs: typing.List[typedexpr]


class TypedListComp(typedexpr, ListComp):
    """Typed list comprehension."""

    generators: typing.List[typedcomprehension]
    elt: typedexpr


class TypedDict(typedexpr, Dict):
    """Typed dict display."""
1000

1001

1002
class TypedIfExp(typedexpr, IfExp):
    """Typed conditional expression ``body if test else orelse``.

    ``ast.IfExp`` is an expression node, so this derives from ``typedexpr``
    like the other expression nodes. It must not derive from ``typedstmt``:
    that would make it an ``ast.stmt`` and give it the statement default
    ``typ`` of None.
    """

    test: typedexpr
    body: typedexpr
    orelse: typedexpr
1006

1007

1008
class TypedCompare(typedexpr, Compare):
    """Typed comparison chain."""

    left: typedexpr
    ops: typing.List[cmpop]
    comparators: typing.List[typedexpr]


class TypedBinOp(typedexpr, BinOp):
    """Typed binary operation."""

    left: typedexpr
    right: typedexpr


class TypedBoolOp(typedexpr, BoolOp):
    """Typed boolean operation (``and`` / ``or``)."""

    values: typing.List[typedexpr]


class TypedUnaryOp(typedexpr, UnaryOp):
    """Typed unary operation."""

    operand: typedexpr


class TypedSubscript(typedexpr, Subscript):
    """Typed subscript access."""

    value: typedexpr


class TypedAttribute(typedexpr, Attribute):
    """Typed attribute access."""

    value: typedexpr
    pos: int


class TypedAssert(typedstmt, Assert):
    """Typed ``assert`` statement."""

    test: typedexpr
    msg: typedexpr


class RawPlutoExpr(typedexpr):
    """Expression node embedding a raw pluthon AST ``expr`` of type ``typ``."""

    typ: Type
    expr: plt.AST
1044

1045

1046
class TypeInferenceError(AssertionError):
    """Raised when typing fails, e.g. a type lacks a requested constructor or attribute."""
1048

1049

1050
# Element instance type -> pluthon constant for an empty builtin list of that
# element type. Element types not listed here are handled by empty_list.
EmptyListMap = dict(
    [
        (IntegerInstanceType, plt.EmptyIntegerList()),
        (ByteStringInstanceType, plt.EmptyByteStringList()),
        (StringInstanceType, plt.EmptyTextList()),
        (UnitInstanceType, plt.EmptyUnitList()),
        (BoolInstanceType, plt.EmptyBoolList()),
    ]
)
1057

1058

1059
def empty_list(p: Type):
    """Return a pluthon constant for an empty builtin list whose elements have type ``p``.

    Raises NotImplementedError for element types without a known empty-list
    representation.
    """
    known = EmptyListMap.get(p)
    if known is not None:
        return known
    assert isinstance(p, InstanceType), "Can only create lists of instances"
    elem_typ = p.typ
    if isinstance(elem_typ, ListType):
        # NOTE(review): `inner` is a pluthon AST; confirm it actually exposes
        # `sample_value`, otherwise this branch raises AttributeError.
        inner = empty_list(elem_typ.typ)
        return plt.EmptyListList(uplc.BuiltinList([], inner.sample_value))
    if isinstance(elem_typ, DictType):
        # dict values are lists of pairs; the sample pair fixes the element type
        sample_pair = uplc.BuiltinPair(
            uplc.PlutusConstr(0, FrozenList([])),
            uplc.PlutusConstr(0, FrozenList([])),
        )
        return plt.EmptyListList(uplc.BuiltinList([], sample_pair))
    if isinstance(elem_typ, (RecordType, AnyType)):
        return plt.EmptyDataList()
    raise NotImplementedError(f"Empty lists of type {p} can't be constructed yet")
1079

1080

1081
# Converters from the PlutusData encoding of a parameter to the internal
# representation, for the atomic instance types.
TransformExtParamsMap = {
    IntegerInstanceType: lambda x: plt.UnIData(x),
    ByteStringInstanceType: lambda x: plt.UnBData(x),
    StringInstanceType: lambda x: plt.DecodeUtf8(plt.UnBData(x)),
    # The encoded value carries no information, but the constant lambda must
    # still be applied to it (as in TransformOutputMap). Otherwise the result is
    # an unapplied function instead of the unit value.
    UnitInstanceType: lambda x: plt.Apply(plt.Lambda(["_"], plt.Unit()), x),
    # any non-zero integer decodes to True
    BoolInstanceType: lambda x: plt.NotEqualsInteger(plt.UnIData(x), plt.Integer(0)),
}
1088

1089

1090
def transform_ext_params_map(p: Type):
    """Return a function that converts a PlutusData-encoded value of type ``p``.

    The function maps a pluthon term holding the Data encoding to the internal
    representation of ``p``. Lists are converted element-wise, dicts are unpacked
    into a map of Data pairs, and all other types are passed through unchanged.
    """
    assert isinstance(
        p, InstanceType
    ), "Can only transform instances, not classes as input"
    atomic = TransformExtParamsMap.get(p)
    if atomic is not None:
        return atomic
    outer = p.typ
    if isinstance(outer, ListType):
        elem_typ = outer.typ

        def from_list_data(x):
            return plt.MapList(
                plt.UnListData(x),
                plt.Lambda(["x"], transform_ext_params_map(elem_typ)(plt.Var("x"))),
                empty_list(elem_typ),
            )

        return from_list_data
    if isinstance(outer, DictType):
        # no builtin constructs a Pair of arbitrary types, so entries stay Data
        return lambda x: plt.UnMapData(x)
    return lambda x: x
1108

1109

1110
# Converters from the internal representation of an atomic instance type to
# its PlutusData encoding.
TransformOutputMap = dict(
    [
        (StringInstanceType, lambda x: plt.BData(plt.EncodeUtf8(x))),
        (IntegerInstanceType, lambda x: plt.IData(x)),
        (ByteStringInstanceType, lambda x: plt.BData(x)),
        # x is still applied (and thus evaluated); the result is always Constr 0 []
        (
            UnitInstanceType,
            lambda x: plt.Apply(
                plt.Lambda(["_"], plt.ConstrData(plt.Integer(0), plt.EmptyDataList())),
                x,
            ),
        ),
        # booleans are encoded as the integers 1 / 0
        (
            BoolInstanceType,
            lambda x: plt.IData(plt.IfThenElse(x, plt.Integer(1), plt.Integer(0))),
        ),
    ]
)
1121

1122

1123
def transform_output_map(p: Type):
    """Return a function that converts a value of type ``p`` into PlutusData.

    The function maps a pluthon term of type ``p`` to its Data encoding. Lists
    are converted element-wise, dicts are wrapped as Data maps, and all other
    types are passed through unchanged.

    Raises NotImplementedError if ``p`` is a (possibly polymorphic) function
    type, since functions have no Data encoding.
    """
    assert isinstance(
        p, InstanceType
    ), "Can only transform instances, not classes as input"
    # Polymorphic builtins are typed as PolymorphicFunctionType. The
    # PolymorphicFunction class itself is not a Type, so checking p.typ against
    # it could never match.
    if isinstance(p.typ, (FunctionType, PolymorphicFunctionType)):
        raise NotImplementedError(
            "Can not map functions into PlutusData and hence not return them from a function as Anything"
        )
    if p in TransformOutputMap:
        return TransformOutputMap[p]
    if isinstance(p.typ, ListType):
        list_int_typ = p.typ.typ
        return lambda x: plt.ListData(
            plt.MapList(
                x,
                plt.Lambda(["x"], transform_output_map(list_int_typ)(plt.Var("x"))),
            ),
        )
    if isinstance(p.typ, DictType):
        # there doesn't appear to be a constructor function to make Pair a b for any types
        # so pairs will always contain Data
        return lambda x: plt.MapData(x)
    return lambda x: x
1146

1147

1148
class TypedNodeTransformer(NodeTransformer):
    """NodeTransformer that dispatches ``TypedX`` nodes to ``visit_X``.

    A ``TypedName`` node is therefore handled by ``visit_Name``, the same
    method that handles a plain ``ast.Name``.
    """

    def visit(self, node):
        """Visit a node."""
        prefix = "Typed"
        name = node.__class__.__name__
        if name.startswith(prefix):
            name = name[len(prefix) :]
        handler = getattr(self, f"visit_{name}", self.generic_visit)
        return handler(node)
1157

1158

1159
class TypedNodeVisitor(NodeVisitor):
    """NodeVisitor that dispatches ``TypedX`` nodes to ``visit_X``.

    A ``TypedName`` node is therefore handled by ``visit_Name``, the same
    method that handles a plain ``ast.Name``.
    """

    def visit(self, node):
        """Visit a node."""
        prefix = "Typed"
        name = node.__class__.__name__
        if name.startswith(prefix):
            name = name[len(prefix) :]
        handler = getattr(self, f"visit_{name}", self.generic_visit)
        return handler(node)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc