• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21112243720

18 Jan 2026 01:08PM UTC coverage: 64.188% (+0.03%) from 64.154%
21112243720

Pull #462

github

web-flow
Merge d6bf7a605 into 92e9cbdc3
Pull Request #462: adds syntax support for multi-assignments and np.empty_like

31 of 33 new or added lines in 2 files covered. (93.94%)

197 existing lines in 5 files now uncovered.

19497 of 30375 relevant lines covered (64.19%)

387.69 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.4
/python/docc/ast_parser.py
1
import ast
3✔
2
import copy
3✔
3
from ._sdfg import Scalar, PrimitiveType, Pointer
3✔
4
from .ast_utils import SliceRewriter, get_debug_info
3✔
5
from .expression_visitor import ExpressionVisitor
3✔
6
from .linear_algebra import LinearAlgebraHandler
3✔
7
from .convolution import ConvolutionHandler
3✔
8
from .onnx_ops import ONNXHandler
3✔
9

10

11
class ASTParser(ast.NodeVisitor):
    """Lower the statements of a Python function body onto an SDFG builder.

    Each supported statement (``return``, plain/augmented/chained assignment,
    ``if``, ``while`` and ``for ... in range(...)``) is translated into calls on
    ``self.builder``. Expressions are delegated to an ``ExpressionVisitor``;
    linear-algebra, convolution and transpose calls on the right-hand side of
    an assignment are first offered to dedicated handlers.
    """

    def __init__(
        self,
        builder,
        array_info=None,
        symbol_table=None,
        filename="",
        function_name="",
        infer_return_type=False,
        globals_dict=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Create a parser bound to ``builder``.

        Args:
            builder: SDFG builder that receives every emitted block, tasklet,
                memlet, container and control-flow region.
            array_info: Mapping from array name to metadata; this class reads
                the ``"ndim"`` and ``"shapes"`` entries. The same dict object
                is shared with the expression visitor and the op handlers.
            symbol_table: Mapping from container name to its type (e.g.
                ``Scalar`` or ``Pointer``). Shared the same way as
                ``array_info``.
            filename: Source file name recorded in debug info.
            function_name: Function name recorded in debug info.
            infer_return_type: When true, the first ``return`` whose type can
                be determined sets the builder's return type.
            globals_dict: Forwarded to the ``ExpressionVisitor``.
            unique_counter_ref: One-element list used as a mutable counter for
                fresh names; shared with the ``ExpressionVisitor`` so generated
                names do not collide. A new ``[0]`` is used when omitted.
            structure_member_info: Forwarded to the ``ExpressionVisitor``.
        """
        self.builder = builder
        self.array_info = array_info if array_info is not None else {}
        self.symbol_table = symbol_table if symbol_table is not None else {}
        self.filename = filename
        self.function_name = function_name
        self.infer_return_type = infer_return_type
        self.globals_dict = globals_dict
        self._unique_counter_ref = (
            unique_counter_ref if unique_counter_ref is not None else [0]
        )
        self.expr_visitor = ExpressionVisitor(
            self.array_info,
            self.builder,
            self.symbol_table,
            self.globals_dict,
            unique_counter_ref=self._unique_counter_ref,
            structure_member_info=structure_member_info,
        )
        self.la_handler = LinearAlgebraHandler(
            self.builder, self.array_info, self.symbol_table, self.expr_visitor
        )
        self.conv_handler = ConvolutionHandler(
            self.builder, self.array_info, self.symbol_table, self.expr_visitor
        )
        self.onnx_handler = ONNXHandler(
            self.builder, self.array_info, self.symbol_table, self.expr_visitor
        )
        # Give the expression visitor the same linear-algebra handler instance.
        self.expr_visitor.la_handler = self.la_handler

    def _get_unique_id(self):
        """Increment the shared counter and return its new value."""
        self._unique_counter_ref[0] += 1
        return self._unique_counter_ref[0]

    def _parse_expr(self, node):
        """Lower an expression node and return the visitor's result for it."""
        return self.expr_visitor.visit(node)

    @staticmethod
    def _constant_dtype(val):
        """Return the scalar type for a Python literal, or ``None`` if unsupported.

        ``bool`` is tested before ``int`` because ``bool`` is a subclass of
        ``int``; testing ``int`` first would type ``True``/``False`` as Int64.
        """
        if isinstance(val, bool):
            return Scalar(PrimitiveType.Bool)
        if isinstance(val, int):
            return Scalar(PrimitiveType.Int64)
        if isinstance(val, float):
            return Scalar(PrimitiveType.Double)
        return None

    def visit_Return(self, node):
        """Lower a ``return`` statement.

        A bare ``return`` emits nothing. A literal return value uses
        ``add_constant_return``; any other value uses ``add_return``. When
        ``infer_return_type`` is set, the first return with a known type also
        sets the builder's return type (and shape, for arrays).

        Raises:
            NotImplementedError: For a literal of an unsupported Python type.
        """
        if node.value is None:
            return

        res = self._parse_expr(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if isinstance(node.value, ast.Constant):
            val = node.value.value
            dtype = self._constant_dtype(val)
            if dtype is None:
                raise NotImplementedError(
                    f"Unsupported constant return type: {type(val)}"
                )
            self.builder.add_constant_return(res, dtype, debug_info)
            if self.infer_return_type:
                self.builder.set_return_type(dtype)
                self.infer_return_type = False
            return

        self.builder.add_return(res, debug_info)
        if self.infer_return_type and res in self.symbol_table:
            self.builder.set_return_type(self.symbol_table[res])
            if res in self.array_info:
                # The builder is handed the shape as a list of strings.
                shape = self.array_info[res]["shapes"]
                self.builder.set_return_shape([str(s) for s in shape])
            self.infer_return_type = False

    def visit_AugAssign(self, node):
        """Lower ``target op= value`` by rewriting it as a plain assignment.

        A bare name found in ``array_info`` becomes the whole-array slice
        assignment ``target[:, ..., :] = target op value``. Every other target
        becomes ``target = target op value``.
        """
        # The right-hand side *reads* the target, so it gets a Load-context copy
        # instead of reusing the Store-context target node.
        left = copy.deepcopy(node.target)
        left.ctx = ast.Load()
        value = ast.BinOp(left=left, op=node.op, right=node.value)

        if isinstance(node.target, ast.Name) and node.target.id in self.array_info:
            ndim = self.array_info[node.target.id]["ndim"]
            slices = [
                ast.Slice(lower=None, upper=None, step=None) for _ in range(ndim)
            ]
            if ndim == 1:
                slice_arg = slices[0]
            else:
                slice_arg = ast.Tuple(elts=slices, ctx=ast.Load())
            target = ast.Subscript(
                value=ast.Name(id=node.target.id, ctx=ast.Load()),
                slice=slice_arg,
                ctx=ast.Store(),
            )
        else:
            target = node.target

        new_node = ast.Assign(targets=[target], value=value)
        # Synthesized nodes carry no source position, but visit_Assign builds
        # debug info from the statement; inherit the original location.
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.visit_Assign(new_node)

    def visit_Assign(self, node):
        """Lower an assignment statement.

        Dispatch order:
          1. chained targets (``a = b = v``) go through a temporary;
          2. GEMM/dot, convolution and transpose calls are offered to their
             handlers, which return a truthy value when they emitted code;
          3. subscript targets become element stores or slice loops;
          4. plain names become scalar copies or pointer references.

        Raises:
            NotImplementedError: For targets other than names/subscripts, or
                when a new variable's type cannot be inferred.
        """
        if len(node.targets) > 1:
            self._assign_chained(node)
            return

        target = node.targets[0]

        # Special case: linear algebra functions
        if self.la_handler.is_gemm(node.value):
            if self.la_handler.handle_gemm(target, node.value):
                return
            if self.la_handler.handle_dot(target, node.value):
                return

        # Special case: convolution
        if self.conv_handler.is_conv(node.value):
            if self.conv_handler.handle_conv(target, node.value):
                return

        # Special case: ONNX ops (Transpose)
        if self.onnx_handler.is_transpose(node.value):
            if self.onnx_handler.handle_transpose(target, node.value):
                return

        if isinstance(target, ast.Subscript):
            self._assign_subscript(node, target)
            return

        if not isinstance(target, ast.Name):
            raise NotImplementedError("Only assignment to variables supported")

        self._assign_variable(node, target.id)

    def _assign_chained(self, node):
        """Lower ``a = b = value`` by evaluating ``value`` once into a temporary."""
        tmp_name = f"_assign_tmp_{self._get_unique_id()}"
        val_assign = ast.Assign(
            targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=node.value
        )
        ast.copy_location(val_assign, node)
        self.visit_Assign(val_assign)

        for target in node.targets:
            assign = ast.Assign(
                targets=[target], value=ast.Name(id=tmp_name, ctx=ast.Load())
            )
            ast.copy_location(assign, node)
            self.visit_Assign(assign)

    def _assign_subscript(self, node, target):
        """Lower ``array[...] = value`` as an element store or a slice loop nest."""
        target_name = self.expr_visitor.visit(target.value)

        if isinstance(target.slice, ast.Tuple):
            indices = target.slice.elts
        else:
            indices = [target.slice]

        if any(isinstance(idx, ast.Slice) for idx in indices):
            debug_info = get_debug_info(node, self.filename, self.function_name)
            self._handle_slice_assignment(
                target, node.value, target_name, indices, debug_info
            )
            return

        target_name_full = self._parse_expr(target)
        value_str = self._parse_expr(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        block = self.builder.add_block(debug_info)
        t_src, src_sub = self.expr_visitor._add_read(block, value_str, debug_info)

        # The parsed target is expected as "name(subset)"; split it into the
        # container name and the memlet subset.
        if "(" in target_name_full and target_name_full.endswith(")"):
            name, _, rest = target_name_full.partition("(")
            t_dst = self.builder.add_access(block, name, debug_info)
            dst_sub = rest[:-1]
        else:
            t_dst = self.builder.add_access(block, target_name_full, debug_info)
            dst_sub = ""

        t_task = self.builder.add_tasklet(
            block, "assign", ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_sub, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", dst_sub, None, debug_info
        )

    def _assign_variable(self, node, target_name):
        """Lower ``name = value``, declaring ``name`` on its first assignment."""
        value_str = self._parse_expr(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if not self.builder.has_container(target_name):
            if isinstance(node.value, ast.Constant):
                val = node.value.value
                dtype = self._constant_dtype(val)
                if dtype is None:
                    raise NotImplementedError(f"Cannot infer type for {val}")
            elif value_str in self.symbol_table:
                dtype = self.symbol_table[value_str]
            else:
                # Explicit error instead of `assert`, which is stripped under -O.
                raise NotImplementedError(f"Cannot infer type for {value_str}")
            self.builder.add_container(target_name, dtype, False)
            self.symbol_table[target_name] = dtype

        # The target shares the source's array metadata dict.
        if value_str in self.array_info:
            self.array_info[target_name] = self.array_info[value_str]

        # Distinguish assignments: scalar -> tasklet, pointer -> reference_memlet
        src_type = self.symbol_table.get(value_str)
        dst_type = self.symbol_table[target_name]

        if src_type and isinstance(src_type, Pointer) and isinstance(dst_type, Pointer):
            block = self.builder.add_block(debug_info)
            t_src = self.builder.add_access(block, value_str, debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_reference_memlet(
                block, t_src, t_dst, "0", src_type, debug_info
            )
            return

        if (src_type and isinstance(src_type, Scalar)) or isinstance(
            dst_type, Scalar
        ):
            block = self.builder.add_block(debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            t_task = self.builder.add_tasklet(
                block, "assign", ["_in"], ["_out"], debug_info
            )
            if src_type:
                t_src = self.builder.add_access(block, value_str, debug_info)
            else:
                # Source is not a known container: emit it as a constant of
                # the destination's type.
                t_src = self.builder.add_constant(
                    block, value_str, dst_type, debug_info
                )
            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", "", None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )
            return

        # NOTE(review): other type combinations fall through without emitting
        # an assignment — confirm this is intended.

    def visit_If(self, node):
        """Lower ``if``/``else`` into a builder conditional region."""
        cond = self._parse_expr(node.test)
        debug_info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_if(f"{cond} != false", debug_info)

        for stmt in node.body:
            self.visit(stmt)

        if node.orelse:
            self.builder.begin_else(debug_info)
            for stmt in node.orelse:
                self.visit(stmt)

        self.builder.end_if()

    def visit_While(self, node):
        """Lower a ``while`` loop into a builder while region.

        NOTE(review): a ``while ... else`` clause (``node.orelse``) is ignored.
        """
        cond = self._parse_expr(node.test)
        debug_info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_while(f"{cond} != false", debug_info)

        for stmt in node.body:
            self.visit(stmt)

        self.builder.end_while()

    def visit_For(self, node):
        """Lower ``for <name> in range(...)`` into a builder for region.

        ``range`` may take one, two or three positional arguments. A literal
        step (optionally negated) is passed through as text, so no tasklet is
        created for it. The loop variable is declared as an ``Int64`` container
        the first time it is seen.

        NOTE(review): a ``for ... else`` clause (``node.orelse``) is ignored.

        Raises:
            NotImplementedError: For a non-name loop target or a non-``range``
                iterable.
            ValueError: For a wrong number of ``range`` arguments, or a literal
                zero step (which Python's ``range`` rejects too).
        """
        if not isinstance(node.target, ast.Name):
            raise NotImplementedError("Only simple for loops supported")

        var = node.target.id

        # Check the callee is a plain name before reading `.id`: calls such as
        # `np.arange(...)` have an ast.Attribute func without an `id`.
        iterable = node.iter
        if not (
            isinstance(iterable, ast.Call)
            and isinstance(iterable.func, ast.Name)
            and iterable.func.id == "range"
        ):
            raise NotImplementedError("Only range() loops supported")

        args = iterable.args
        if len(args) == 1:
            start = "0"
            end = self._parse_expr(args[0])
            step = "1"
        elif len(args) == 2:
            start = self._parse_expr(args[0])
            end = self._parse_expr(args[1])
            step = "1"
        elif len(args) == 3:
            start = self._parse_expr(args[0])
            end = self._parse_expr(args[1])

            # Special handling for step to avoid creating tasklets for constants
            step_node = args[2]
            if isinstance(step_node, ast.Constant):
                if step_node.value == 0:
                    raise ValueError("range() arg 3 must not be zero")
                step = str(step_node.value)
            elif (
                isinstance(step_node, ast.UnaryOp)
                and isinstance(step_node.op, ast.USub)
                and isinstance(step_node.operand, ast.Constant)
            ):
                if step_node.operand.value == 0:
                    raise ValueError("range() arg 3 must not be zero")
                step = f"-{step_node.operand.value}"
            else:
                step = self._parse_expr(step_node)
        else:
            raise ValueError("Invalid range arguments")

        if not self.builder.has_container(var):
            self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[var] = Scalar(PrimitiveType.Int64)

        debug_info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_for(var, start, end, step, debug_info)

        for stmt in node.body:
            self.visit(stmt)

        self.builder.end_for()

    def _dim_size(self, target_name, i):
        """Return the extent of dimension ``i`` of ``target_name``.

        Uses ``array_info[target_name]["shapes"][i]`` when available; otherwise
        falls back to the symbolic name ``_{target_name}_shape_{i}``. Arrays
        missing from ``array_info`` also use the fallback instead of raising
        ``KeyError``.
        """
        shapes = self.array_info.get(target_name, {}).get("shapes", [])
        return shapes[i] if i < len(shapes) else f"_{target_name}_shape_{i}"

    @staticmethod
    def _is_absent(bound):
        """Return True if a slice bound is omitted or written as literal ``None``."""
        return bound is None or (
            isinstance(bound, ast.Constant) and bound.value is None
        )

    def _handle_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Lower ``target[...] = value`` where at least one index is a slice.

        Every slice index gets a fresh Int64 loop variable and a for region
        running ``0 .. (stop - start)``; the corresponding target index becomes
        ``start + loop_var * step``. The right-hand side is rewritten by
        ``SliceRewriter`` in terms of the same loop variables. A single
        ``add_assignment`` is emitted in the innermost loop, then all loops are
        closed.

        Args:
            target: The ``ast.Subscript`` being assigned to (deep-copied).
            value: Right-hand-side expression node (deep-copied, not mutated).
            target_name: Parsed name of the subscripted array.
            indices: Index nodes of ``target``. Padded with full slices when
                fewer than the array's ``ndim`` are given.
            debug_info: Debug info for emitted constructs; derived from
                ``target`` when omitted.
        """
        if debug_info is None:
            debug_info = get_debug_info(target, self.filename, self.function_name)

        if target_name in self.array_info:
            ndim = self.array_info[target_name]["ndim"]
            if len(indices) < ndim:
                # Missing trailing indices select the whole dimension.
                indices = list(indices) + [
                    ast.Slice(lower=None, upper=None, step=None)
                    for _ in range(ndim - len(indices))
                ]

        loop_vars = []
        new_target_indices = []

        for i, idx in enumerate(indices):
            if not isinstance(idx, ast.Slice):
                new_target_indices.append(idx)
                continue

            loop_var = f"_slice_iter_{len(loop_vars)}_{self._get_unique_id()}"
            loop_vars.append(loop_var)
            if not self.builder.has_container(loop_var):
                self.builder.add_container(
                    loop_var, Scalar(PrimitiveType.Int64), False
                )
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            start_str = "0"
            if not self._is_absent(idx.lower):
                start_str = self._parse_expr(idx.lower)
                if start_str.startswith("-"):
                    # Negative bounds count from the end of the dimension.
                    start_str = f"({self._dim_size(target_name, i)} {start_str})"

            if self._is_absent(idx.upper):
                stop_str = self._dim_size(target_name, i)
            else:
                stop_str = self._parse_expr(idx.upper)
                if stop_str.startswith(("-", "(-")):
                    stop_str = f"({self._dim_size(target_name, i)} {stop_str})"

            step_str = "1"
            if not self._is_absent(idx.step):
                step_str = self._parse_expr(idx.step)

            # NOTE(review): the trip count ignores step_str, so a non-unit step
            # iterates (stop - start) times and start + loop_var * step can run
            # past stop. Confirm how the builder's expression language spells
            # ceiling division before fixing.
            count_str = f"({stop_str} - {start_str})"
            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)

            # The index expression is passed through as the id of an ast.Name;
            # presumably the expression visitor emits Name ids verbatim.
            new_target_indices.append(
                ast.Name(id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load())
            )

        rewriter = SliceRewriter(loop_vars, self.array_info, self.expr_visitor)
        new_value = rewriter.visit(copy.deepcopy(value))

        new_target = copy.deepcopy(target)
        if len(new_target_indices) == 1:
            new_target.slice = new_target_indices[0]
        else:
            new_target.slice = ast.Tuple(elts=new_target_indices, ctx=ast.Load())

        target_str = self._parse_expr(new_target)
        value_str = self._parse_expr(new_value)
        self.builder.add_assignment(target_str, value_str, debug_info)

        for _ in loop_vars:
            self.builder.end_for()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc