• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpShin / opshin / 18409503285

10 Oct 2025 02:25PM UTC coverage: 92.68% (-0.2%) from 92.835%
18409503285

push

github

nielstron
Version bump

1265 of 1480 branches covered (85.47%)

Branch coverage included in aggregate %.

4800 of 5064 relevant lines covered (94.79%)

2.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.96
/opshin/optimize/optimize_union_expansion.py
1
from _ast import BoolOp, Call, FunctionDef, If, UnaryOp
3✔
2
from ast import *
3✔
3
from typing import Any, List
3✔
4
from ..util import CompilingNodeTransformer
3✔
5
from copy import deepcopy
3✔
6

7
"""
3✔
8
Expand union types
9
"""
10

11

12
def type_to_suffix(typ: expr) -> str:
    """
    Render a type annotation node as a string without spaces, brackets,
    commas or dots.

    The node is unparsed to source text and then rewritten as follows:
    spaces are removed, ``__`` becomes ``___``, ``[`` becomes ``_l_``,
    ``]`` becomes ``_r_``, ``,`` becomes ``_c_`` and ``.`` becomes ``_d_``.

    :param typ: the annotation expression to encode
    :return: the encoded text, or ``"UnknownType"`` if the node cannot be
        unparsed
    """
    try:
        text = unparse(typ)
    except Exception:
        # Best effort: anything that cannot be rendered gets a fixed placeholder
        return "UnknownType"

    # Spaces are stripped and "__" is widened before any marker is inserted
    text = text.replace(" ", "").replace("__", "___")
    for symbol, marker in (
        ("[", "_l_"),
        ("]", "_r_"),
        (",", "_c_"),
        (".", "_d_"),
    ):
        text = text.replace(symbol, marker)
    return text
25

26

27
class RemoveDeadCode(CompilingNodeTransformer):
    """
    Fold ``isinstance`` checks on variables whose type is known and prune the
    code that becomes unreachable as a result.

    ``isinstance(x, T)`` calls are replaced by a boolean ``Constant`` whenever
    ``x`` is listed in ``arg_types``. Boolean operators, ``not`` and
    conditional expressions over such constants are folded, and ``if``
    statements with a constant test are replaced by their live branch.
    """

    def __init__(self, arg_types: dict[str, str]):
        """
        :param arg_types: maps a variable name to the name of its known type;
            this is the string compared against ``T`` in ``isinstance(x, T)``
        """
        self.arg_types = arg_types

    @staticmethod
    def _non_empty(stmts: list) -> list:
        """
        Return ``stmts`` unchanged, or ``[Pass()]`` if pruning removed every
        statement. Function and ``if`` bodies must hold at least one statement
        to form a valid AST.
        """
        return stmts if stmts else [Pass()]

    def visit_FunctionDef(self, node: FunctionDef) -> Any:
        """Prune the function body; all other parts of the definition are left untouched."""
        node.body = self._non_empty(self.visit_sequence(node.body))
        return node

    def visit_sequence(self, stmts):
        """
        Visit every statement of ``stmts`` and return the resulting list.

        An ``if`` statement whose test folded to a constant is replaced by the
        statements of the branch that is taken, which may be none at all.
        """
        new_stmts = []
        for stmt in stmts:
            visited = self.visit(stmt)
            if isinstance(visited, If) and isinstance(visited.test, Constant):
                # Inline the live branch, the other one can never run
                live = visited.body if visited.test.value else visited.orelse
                new_stmts.extend(live)
            else:
                new_stmts.append(visited)
        return new_stmts

    def visit_If(self, node: If) -> Any:
        """
        Fold the test and prune both branches.

        The test may be any expression; the forms that can fold to a constant
        here are ``isinstance(...)`` calls (``ast.Call``), their ``and``/``or``
        combinations (``ast.BoolOp``), negations (``ast.UnaryOp``) and
        conditional expressions (``ast.IfExp``).

        The node itself is always returned; replacing it by its live branch is
        done by ``visit_sequence`` of the enclosing statement list.
        """
        node.test = self.visit(node.test)
        # The body of an `if` is mandatory, the `else` part may be empty
        node.body = self._non_empty(self.visit_sequence(node.body))
        node.orelse = self.visit_sequence(node.orelse)
        return node

    def visit_Call(self, node: Call) -> Any:
        """
        Replace ``isinstance(x, T)`` by a boolean constant when ``x`` has an
        entry in ``arg_types`` and both ``x`` and ``T`` are plain names.
        Any other call is returned with only its children visited.
        """
        node = self.generic_visit(node)
        is_isinstance_call = (
            isinstance(node.func, Name)
            and node.func.id == "isinstance"
            and len(node.args) == 2
        )
        if not is_isinstance_call:
            return node
        arg, typ = node.args
        if not (isinstance(arg, Name) and isinstance(typ, Name)):
            return node
        known_type = self.arg_types.get(arg.id)
        if known_type is None:
            return node
        # `typ` is a Name at this point, so its id can be read directly
        return Constant(value=(known_type == typ.id))

    def visit_BoolOp(self, node: BoolOp) -> Any:
        """
        Simplify ``and`` / ``or`` chains that contain constant operands.

        - A constant that decides the result (falsy in ``and``, truthy in
          ``or``) turns the whole expression into ``Constant(False)`` or
          ``Constant(True)`` respectively, wherever it occurs in the chain.
        - All other constants are neutral and are dropped.
        - If only constants were present, the result is a boolean constant;
          if a single operand remains, that operand is returned directly.
        """
        node.values = [self.visit(v) for v in node.values]

        if isinstance(node.op, (And, Or)):
            # Truthiness of a constant operand that decides the whole chain:
            # True for `or`, False for `and`
            decisive = isinstance(node.op, Or)
            if any(
                isinstance(v, Constant) and bool(v.value) is decisive
                for v in node.values
            ):
                return Constant(value=decisive)  # short-circuit
            # Every constant still present is neutral for this operator
            node.values = [v for v in node.values if not isinstance(v, Constant)]
            if not node.values:
                # Only neutral constants were present
                return Constant(value=not decisive)

        # If only one value remains, return it directly
        if len(node.values) == 1:
            return node.values[0]
        return node

    def visit_UnaryOp(self, node: UnaryOp) -> Any:
        """Fold ``not <constant>`` into a boolean constant; other unary operations are kept."""
        node.operand = self.visit(node.operand)

        # Only handle 'not' operations for now
        if isinstance(node.op, Not) and isinstance(node.operand, Constant):
            return Constant(value=not bool(node.operand.value))

        return node

    def visit_IfExp(self, node: IfExp) -> Any:
        """Replace ``a if <constant> else b`` by the operand that is selected."""
        node.test = self.visit(node.test)
        node.body = self.visit(node.body)
        node.orelse = self.visit(node.orelse)

        # Simplify if the test condition is a constant
        if isinstance(node.test, Constant):
            return node.body if node.test.value else node.orelse

        return node
×
144

145

146
class OptimizeUnionExpansion(CompilingNodeTransformer):
    """
    Add specialised copies of every function that has ``Union[...]``
    annotated arguments.

    For each such function, one variant per combination of union members is
    inserted directly after the original definition. In every variant the
    ``isinstance`` checks on the specialised arguments are folded and dead
    branches are removed (see ``RemoveDeadCode``). The names of all variants
    are recorded on the original definition as ``expanded_variants``.
    """

    step = "Expanding Unions"

    def visit(self, node):
        """Expand the functions in every statement list of ``node``, then visit the node as usual."""
        for field in ("body", "orelse", "finalbody"):
            stmts = getattr(node, field, None)
            # `body`/`orelse` may also be single expressions (e.g. on IfExp)
            if isinstance(stmts, list):
                setattr(node, field, self.visit_sequence(stmts))
        return super().visit(node)

    def is_Union_annotation(self, ann: expr):
        """
        Return the member type nodes if ``ann`` is of the form
        ``Union[A, B, ...]``, otherwise ``False``.

        A union with a single member (``Union[A]``) has no tuple slice and
        offers no alternatives to specialise on, so it yields ``False`` too.
        """
        if (
            isinstance(ann, Subscript)
            and isinstance(ann.value, Name)
            and ann.value.id == "Union"
            and isinstance(ann.slice, Tuple)
        ):
            return ann.slice.elts
        return False

    def split_functions(
        self, stmt: FunctionDef, args: list, arg_types: dict, naming=""
    ) -> List[FunctionDef]:
        """
        Recursively generate variants of a function with all possible
        combinations of expanded union types for its arguments.

        :param stmt: the function to specialise
        :param args: one entry per positional argument of ``stmt``: the list
            of union member nodes, or a falsy value if the argument is not
            (or no longer) to be expanded
        :param arg_types: names of the arguments specialised so far, mapped
            to the name of the type chosen for them
        :param naming: name prefix of the generated variants
        :return: the generated variants; ``stmt`` itself is not included
        """
        # Handle only one Union per recursion level: the first one left
        union_index = next(
            (i for i, members in enumerate(args) if members), None
        )
        if union_index is None:
            return []

        members = args[union_index]
        remaining = deepcopy(args)
        remaining[union_index] = False
        arg_name = stmt.args.args[union_index].arg

        new_functions = []
        for typ in members:
            # Only compute the unparsed suffix for types that are not plain names
            typ_str = typ.id if hasattr(typ, "id") else type_to_suffix(typ)
            new_f = deepcopy(stmt)
            # Copy the member node so the variant shares no node with the
            # Union annotation it was taken from
            new_f.args.args[union_index].annotation = deepcopy(typ)
            new_f.name = f"{naming}_{typ_str}"
            new_arg_types = {**arg_types, arg_name: typ_str}
            new_f = RemoveDeadCode(new_arg_types).visit(new_f)
            new_functions.append(new_f)
            # Specialise the remaining unions on top of this variant
            new_functions.extend(
                self.split_functions(new_f, remaining, new_arg_types, new_f.name)
            )
        # Look for variation where this arg is still Union
        new_functions.extend(
            self.split_functions(stmt, remaining, arg_types, f"{naming}_Union")
        )
        return new_functions

    def visit_sequence(self, body):
        """
        Return ``body`` with the variants of every function definition
        inserted right after that definition.
        """
        new_body = []
        for stmt in body:
            new_body.append(stmt)
            if not isinstance(stmt, FunctionDef):
                continue
            args = [
                self.is_Union_annotation(arg.annotation) for arg in stmt.args.args
            ]
            # "+" cannot occur in a Python identifier, so generated names
            # cannot clash with names written in the source
            new_funcs = self.split_functions(stmt, args, {}, stmt.name + "+")
            # track variants
            stmt.expanded_variants = [f.name for f in new_funcs]
            new_body.extend(new_funcs)
        return new_body
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc