• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21113623600

18 Jan 2026 02:50PM UTC coverage: 64.425% (+0.3%) from 64.154%
21113623600

Pull #462

github

web-flow
Merge d503e5691 into 92e9cbdc3
Pull Request #462: adds syntax support for multi-assignments and np.empty_like

221 of 258 new or added lines in 5 files covered. (85.66%)

21 existing lines in 4 files now uncovered.

19678 of 30544 relevant lines covered (64.43%)

385.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.23
/python/docc/ast_parser.py
1
import ast
3✔
2
import copy
3✔
3
from ._sdfg import Scalar, PrimitiveType, Pointer
3✔
4
from .ast_utils import SliceRewriter, get_debug_info
3✔
5
from .expression_visitor import ExpressionVisitor
3✔
6
from .linear_algebra import LinearAlgebraHandler
3✔
7
from .convolution import ConvolutionHandler
3✔
8
from .onnx_ops import ONNXHandler
3✔
9

10

11
class ASTParser(ast.NodeVisitor):
    """Lower the statements of a Python function body into SDFG builder calls.

    Expressions are delegated to an ``ExpressionVisitor``. The statement forms
    handled here are ``return``, plain / augmented / chained assignment, ``if``,
    ``while`` and ``for ... in range(...)``. Assignments whose right-hand side is
    recognised by the linear-algebra (GEMM/dot), convolution or ONNX transpose
    handlers are offered to those handlers first.
    """

    def __init__(
        self,
        builder,
        array_info=None,
        symbol_table=None,
        filename="",
        function_name="",
        infer_return_type=False,
        globals_dict=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Create a parser that emits into ``builder``.

        Args:
            builder: SDFG builder that receives containers, blocks, memlets and
                control-flow constructs.
            array_info: Map from container name to array metadata. This class
                reads the ``"ndim"`` and ``"shapes"`` keys. The dict is shared,
                not copied. Defaults to a new empty dict.
            symbol_table: Map from container name to its sdfg type. Shared, not
                copied. Defaults to a new empty dict.
            filename: Source file name, used for debug info.
            function_name: Function name, used for debug info.
            infer_return_type: If True, the first ``return`` statement declares
                the ``_docc_ret_<i>`` output containers.
            globals_dict: Passed through to the expression visitor.
            unique_counter_ref: One-element list used as a shared mutable
                counter for generating unique names. A fresh ``[0]`` if omitted.
            structure_member_info: Passed through to the expression visitor.
        """
        self.builder = builder
        self.array_info = array_info if array_info is not None else {}
        self.symbol_table = symbol_table if symbol_table is not None else {}
        self.filename = filename
        self.function_name = function_name
        self.infer_return_type = infer_return_type
        self.globals_dict = globals_dict
        # A list (not an int) so the counter can be shared by reference with
        # the expression visitor.
        self._unique_counter_ref = (
            unique_counter_ref if unique_counter_ref is not None else [0]
        )
        self.expr_visitor = ExpressionVisitor(
            self.array_info,
            self.builder,
            self.symbol_table,
            self.globals_dict,
            unique_counter_ref=self._unique_counter_ref,
            structure_member_info=structure_member_info,
        )
        self.la_handler = LinearAlgebraHandler(
            self.builder, self.array_info, self.symbol_table, self.expr_visitor
        )
        self.conv_handler = ConvolutionHandler(
            self.builder, self.array_info, self.symbol_table, self.expr_visitor
        )
        self.onnx_handler = ONNXHandler(
            self.builder, self.array_info, self.symbol_table, self.expr_visitor
        )
        self.expr_visitor.la_handler = self.la_handler
        self.captured_return_shapes = {}  # Map param name to shape string list

    # ------------------------------------------------------------------ helpers

    def _get_unique_id(self):
        """Increment the shared counter and return its new value."""
        self._unique_counter_ref[0] += 1
        return self._unique_counter_ref[0]

    def _parse_expr(self, node):
        """Lower an expression node via the expression visitor and return its result."""
        return self.expr_visitor.visit(node)

    @staticmethod
    def _infer_constant_dtype(val):
        """Return the scalar sdfg type for a Python constant, or None if unsupported.

        ``bool`` is tested before ``int`` because ``bool`` is a subclass of
        ``int``. Testing ``int`` first would type ``True``/``False`` as Int64.
        """
        if isinstance(val, bool):
            return Scalar(PrimitiveType.Bool)
        if isinstance(val, int):
            return Scalar(PrimitiveType.Int64)
        if isinstance(val, float):
            return Scalar(PrimitiveType.Double)
        return None

    @staticmethod
    def _is_absent(node):
        """True if an optional slice component is missing or is a literal ``None``."""
        return node is None or (isinstance(node, ast.Constant) and node.value is None)

    def _dim_size(self, target_name, dim):
        """Size expression of dimension ``dim`` of ``target_name``.

        Uses ``array_info[target_name]["shapes"][dim]`` when known. Otherwise it
        falls back to the symbol name ``_<target_name>_shape_<dim>``, including
        when ``target_name`` has no array info at all.
        """
        shapes = self.array_info.get(target_name, {}).get("shapes", [])
        return shapes[dim] if dim < len(shapes) else f"_{target_name}_shape_{dim}"

    def _wrap_negative(self, expr, target_name, dim):
        """Turn a negative (end-relative) slice bound into ``size + bound``.

        A leading ``-`` is appended directly (``(N -1)``). A parenthesised
        ``(-...)`` gets an explicit ``+``, so the result is not a juxtaposition
        such as ``(N (-1))``. Non-negative expressions are returned unchanged.
        """
        if expr.startswith("-"):
            return f"({self._dim_size(target_name, dim)} {expr})"
        if expr.startswith("(-"):
            return f"({self._dim_size(target_name, dim)} + {expr})"
        return expr

    # ------------------------------------------------------------------ return

    def visit_Return(self, node):
        """Lower ``return``: copy each value into ``_docc_ret_<i>``, then exit.

        A bare ``return`` only emits the control-flow return. Scalars and 0-d
        arrays are copied with an ``assign`` tasklet. Arrays with ``ndim > 0``
        and pointer-typed values are copied through a synthesized
        ``_docc_ret_<i>[:] = value`` slice assignment. When
        ``infer_return_type`` is set, the output containers are declared on the
        first return encountered and the flag is then cleared.
        """
        if node.value is None:
            debug_info = get_debug_info(node, self.filename, self.function_name)
            self.builder.add_return("", debug_info)
            return

        # ``return a, b`` produces one output per tuple element.
        if isinstance(node.value, ast.Tuple):
            values = node.value.elts
        else:
            values = [node.value]

        parsed_values = [self._parse_expr(v) for v in values]
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if self.infer_return_type:
            for i, (res, value_node) in enumerate(zip(parsed_values, values)):
                self._declare_return_container(f"_docc_ret_{i}", res, value_node)
            self.infer_return_type = False

        for i, (res, value_node) in enumerate(zip(parsed_values, values)):
            ret_name = f"_docc_ret_{i}"
            if self._is_array_value(res):
                self._copy_array_return(node, ret_name, res, value_node)
            else:
                self._copy_scalar_return(ret_name, res, debug_info)

        # Add control flow return to exit the function/path
        self.builder.add_return("", debug_info)

    def _is_array_value(self, res):
        """True if the lowered value ``res`` must be returned by array copy."""
        if res in self.array_info:
            # 0-d arrays (scalars) are handled by scalar assignment.
            return self.array_info[res]["ndim"] > 0
        return res in self.symbol_table and isinstance(self.symbol_table[res], Pointer)

    def _declare_return_container(self, ret_name, res, value_node):
        """Declare the output argument ``ret_name`` typed after value ``res``.

        The type is taken from ``symbol_table[res]`` if present, otherwise it
        is inferred from a literal constant. It defaults to Double. Scalars are
        wrapped in ``Pointer``; other types are kept as is.
        """
        if self.builder.has_container(ret_name):
            return

        dtype = Scalar(PrimitiveType.Double)
        if res in self.symbol_table:
            dtype = self.symbol_table[res]
        elif isinstance(value_node, ast.Constant):
            inferred = self._infer_constant_dtype(value_node.value)
            if inferred is not None:
                dtype = inferred

        # Wrap Scalar in Pointer. Keep Arrays/Pointers as is.
        arg_type = Pointer(dtype) if isinstance(dtype, Scalar) else dtype
        self.builder.add_container(ret_name, arg_type, is_argument=True)
        self.symbol_table[ret_name] = arg_type

        if res in self.array_info:
            self.array_info[ret_name] = self.array_info[res]

    def _copy_scalar_return(self, ret_name, res, debug_info):
        """Emit a block that assigns scalar ``res`` to element ``0`` of ``ret_name``."""
        block = self.builder.add_block(debug_info)
        t_dst = self.builder.add_access(block, ret_name, debug_info)
        t_src, src_sub = self.expr_visitor._add_read(block, res, debug_info)
        t_task = self.builder.add_tasklet(
            block, "assign", ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_sub, None, debug_info
        )
        # Subset "0": inferred scalar outputs are declared as Pointer(Scalar).
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", "0", None, debug_info
        )

    def _copy_array_return(self, node, ret_name, res, value_node):
        """Copy array ``res`` into ``ret_name`` via ``ret_name[:, ...] = res``.

        Also records the source shape (as strings) in
        ``captured_return_shapes`` when ``res`` has array info.
        """
        if res in self.array_info:
            shape = self.array_info[res]["shapes"]
            self.captured_return_shapes[ret_name] = [str(s) for s in shape]
            # Ensure destination array info exists
            if ret_name not in self.array_info:
                self.array_info[ret_name] = self.array_info[res]

        ndim = self.array_info[ret_name]["ndim"] if ret_name in self.array_info else 1
        full_slice = ast.Slice(lower=None, upper=None, step=None)
        if ndim > 1:
            target_slice = ast.Tuple(elts=[full_slice] * ndim, ctx=ast.Load())
        else:
            target_slice = full_slice

        target_sub = ast.Subscript(
            value=ast.Name(id=ret_name, ctx=ast.Load()),
            slice=target_slice,
            ctx=ast.Store(),
        )

        if isinstance(value_node, ast.Name):
            val_node = value_node
        else:
            val_node = ast.Name(id=res, ctx=ast.Load())

        assign_node = ast.Assign(targets=[target_sub], value=val_node)
        # Attribute the synthesized copy to the return statement's location.
        ast.copy_location(assign_node, node)
        self.visit_Assign(assign_node)

    # -------------------------------------------------------------- assignment

    def visit_AugAssign(self, node):
        """Lower ``target op= value`` as ``target = target op value``.

        For a plain name that has array info, the store becomes a full-slice
        write ``target[:, ...] = target op value``, so the existing array is
        updated element-wise.
        """
        # Read-side copy of the target: same expression, Load context.
        load_target = copy.deepcopy(node.target)
        load_target.ctx = ast.Load()
        rhs = ast.BinOp(left=load_target, op=node.op, right=node.value)

        if isinstance(node.target, ast.Name) and node.target.id in self.array_info:
            ndim = self.array_info[node.target.id]["ndim"]
            slices = [ast.Slice(lower=None, upper=None, step=None) for _ in range(ndim)]
            if ndim == 1:
                slice_arg = slices[0]
            else:
                slice_arg = ast.Tuple(elts=slices, ctx=ast.Load())
            store_target = ast.Subscript(
                value=copy.deepcopy(load_target), slice=slice_arg, ctx=ast.Store()
            )
        else:
            store_target = node.target

        new_node = ast.Assign(targets=[store_target], value=rhs)
        ast.copy_location(new_node, node)
        self.visit_Assign(new_node)

    def visit_Assign(self, node):
        """Lower an assignment statement.

        Dispatch order:
          1. chained targets (``a = b = v``) through a temporary;
          2. values recognised by the GEMM/dot, convolution or ONNX-transpose
             handlers, when the handler accepts the statement;
          3. subscript targets (slice writes or single-element writes);
          4. plain variable targets.

        Raises:
            NotImplementedError: for unsupported target forms (e.g. tuple
                unpacking) or when a new variable's type cannot be inferred.
        """
        if len(node.targets) > 1:
            self._assign_chained(node)
            return

        target = node.targets[0]
        value = node.value

        # Special case: linear algebra functions. handle_dot is only tried
        # for values the GEMM handler recognises.
        if self.la_handler.is_gemm(value):
            if self.la_handler.handle_gemm(target, value):
                return
            if self.la_handler.handle_dot(target, value):
                return

        # Special case: convolution
        if self.conv_handler.is_conv(value) and self.conv_handler.handle_conv(
            target, value
        ):
            return

        # Special case: ONNX ops (Transpose)
        if self.onnx_handler.is_transpose(
            value
        ) and self.onnx_handler.handle_transpose(target, value):
            return

        if isinstance(target, ast.Subscript):
            self._assign_subscript(node, target)
            return

        if not isinstance(target, ast.Name):
            raise NotImplementedError("Only assignment to variables supported")

        self._assign_variable(node, target)

    def _assign_chained(self, node):
        """Lower ``t1 = t2 = ... = v``: evaluate ``v`` once into a temp, then assign left to right."""
        tmp_name = f"_assign_tmp_{self._get_unique_id()}"
        val_assign = ast.Assign(
            targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=node.value
        )
        ast.copy_location(val_assign, node)
        self.visit_Assign(val_assign)

        for target in node.targets:
            assign = ast.Assign(
                targets=[target], value=ast.Name(id=tmp_name, ctx=ast.Load())
            )
            ast.copy_location(assign, node)
            self.visit_Assign(assign)

    def _assign_subscript(self, node, target):
        """Lower ``x[...] = v``: a sliced write loops; a single element uses one tasklet."""
        target_name = self.expr_visitor.visit(target.value)

        if isinstance(target.slice, ast.Tuple):
            indices = target.slice.elts
        else:
            indices = [target.slice]

        if any(isinstance(idx, ast.Slice) for idx in indices):
            debug_info = get_debug_info(node, self.filename, self.function_name)
            self._handle_slice_assignment(
                target, node.value, target_name, indices, debug_info
            )
            return

        target_name_full = self._parse_expr(target)
        value_str = self._parse_expr(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        block = self.builder.add_block(debug_info)
        t_src, src_sub = self.expr_visitor._add_read(block, value_str, debug_info)

        # The lowered element access appears to have the form "name(subset)";
        # split it into the container name and the memlet subset.
        if "(" in target_name_full and target_name_full.endswith(")"):
            name, _, rest = target_name_full.partition("(")
            t_dst = self.builder.add_access(block, name, debug_info)
            dst_sub = rest[:-1]
        else:
            t_dst = self.builder.add_access(block, target_name_full, debug_info)
            dst_sub = ""

        t_task = self.builder.add_tasklet(
            block, "assign", ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_sub, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", dst_sub, None, debug_info
        )

    def _assign_variable(self, node, target):
        """Lower ``name = v``, declaring ``name`` on first assignment.

        The type of a new variable is inferred from a literal constant, or
        copied from the symbol-table type of the lowered value.
        Pointer-to-pointer assignments become a reference memlet. Assignments
        involving a Scalar use an ``assign`` tasklet. Other type combinations
        emit nothing here (NOTE(review): confirm that is intended).
        """
        target_name = target.id
        value_str = self._parse_expr(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if not self.builder.has_container(target_name):
            if isinstance(node.value, ast.Constant):
                val = node.value.value
                dtype = self._infer_constant_dtype(val)
                if dtype is None:
                    raise NotImplementedError(f"Cannot infer type for {val}")
            else:
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if value_str not in self.symbol_table:
                    raise NotImplementedError(f"Cannot infer type for {value_str}")
                dtype = self.symbol_table[value_str]

            self.builder.add_container(target_name, dtype, False)
            self.symbol_table[target_name] = dtype

        if value_str in self.array_info:
            self.array_info[target_name] = self.array_info[value_str]

        # Distinguish assignments: scalar -> tasklet, pointer -> reference_memlet
        src_type = self.symbol_table.get(value_str)
        dst_type = self.symbol_table[target_name]
        if src_type and isinstance(src_type, Pointer) and isinstance(dst_type, Pointer):
            block = self.builder.add_block(debug_info)
            t_src = self.builder.add_access(block, value_str, debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_reference_memlet(
                block, t_src, t_dst, "0", src_type, debug_info
            )
            return

        if (src_type and isinstance(src_type, Scalar)) or isinstance(dst_type, Scalar):
            block = self.builder.add_block(debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            t_task = self.builder.add_tasklet(
                block, "assign", ["_in"], ["_out"], debug_info
            )

            # A value with no symbol-table entry is treated as a constant literal.
            if src_type:
                t_src = self.builder.add_access(block, value_str, debug_info)
            else:
                t_src = self.builder.add_constant(
                    block, value_str, dst_type, debug_info
                )

            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", "", None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )

    # ------------------------------------------------------------ control flow

    def visit_If(self, node):
        """Lower ``if``/``else``; the condition is tested as ``<cond> != false``."""
        cond = self._parse_expr(node.test)
        debug_info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_if(f"{cond} != false", debug_info)

        for stmt in node.body:
            self.visit(stmt)

        if node.orelse:
            self.builder.begin_else(debug_info)
            for stmt in node.orelse:
                self.visit(stmt)

        self.builder.end_if()

    def visit_While(self, node):
        """Lower ``while``; the condition is tested as ``<cond> != false``.

        Raises:
            NotImplementedError: for ``while ... else``. Previously the else
                body was silently dropped.
        """
        if node.orelse:
            raise NotImplementedError("while-else loops are not supported")

        cond = self._parse_expr(node.test)
        debug_info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_while(f"{cond} != false", debug_info)

        for stmt in node.body:
            self.visit(stmt)

        self.builder.end_while()

    def visit_For(self, node):
        """Lower ``for var in range(...)`` into a builder for-loop.

        ``var`` is declared as an Int64 container if it does not exist yet.
        Constant steps (including ``-c``) are emitted as literals so no tasklet
        is created for them.

        Raises:
            NotImplementedError: for non-name loop targets, iterables other than
                a call to the name ``range``, or ``for ... else``.
            ValueError: for a wrong number of ``range`` arguments or a literal
                zero step.
        """
        if not isinstance(node.target, ast.Name):
            raise NotImplementedError("Only simple for loops supported")

        var = node.target.id

        iter_node = node.iter
        # Check the callee is a bare Name before reading ``.id`` (e.g.
        # ``np.arange`` is an ast.Attribute and has no ``id``).
        is_range = (
            isinstance(iter_node, ast.Call)
            and isinstance(iter_node.func, ast.Name)
            and iter_node.func.id == "range"
        )
        if not is_range:
            raise NotImplementedError("Only range() loops supported")

        if node.orelse:
            raise NotImplementedError("for-else loops are not supported")

        args = iter_node.args
        if len(args) == 1:
            start = "0"
            end = self._parse_expr(args[0])
            step = "1"
        elif len(args) == 2:
            start = self._parse_expr(args[0])
            end = self._parse_expr(args[1])
            step = "1"
        elif len(args) == 3:
            start = self._parse_expr(args[0])
            end = self._parse_expr(args[1])

            # Special handling for step to avoid creating tasklets for constants
            step_node = args[2]
            if isinstance(step_node, ast.Constant):
                if step_node.value == 0:
                    raise ValueError("range() step must not be zero")
                step = str(step_node.value)
            elif (
                isinstance(step_node, ast.UnaryOp)
                and isinstance(step_node.op, ast.USub)
                and isinstance(step_node.operand, ast.Constant)
            ):
                if step_node.operand.value == 0:
                    raise ValueError("range() step must not be zero")
                step = f"-{step_node.operand.value}"
            else:
                step = self._parse_expr(step_node)
        else:
            raise ValueError("Invalid range arguments")

        if not self.builder.has_container(var):
            self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[var] = Scalar(PrimitiveType.Int64)

        debug_info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_for(var, start, end, step, debug_info)

        for stmt in node.body:
            self.visit(stmt)

        self.builder.end_for()

    # --------------------------------------------------------- slice assignment

    def _handle_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Lower ``target[...] = value`` where at least one index is a slice.

        Missing trailing dimensions of a known array are treated as ``:``. Each
        slice dimension becomes a zero-based loop ``_slice_iter_<k>_<id>`` over
        ``stop - start`` iterations. The target index for that dimension is
        ``start + iter * step``. ``SliceRewriter`` rewrites the value side in
        terms of the same loop variables, and one assignment is emitted in the
        innermost loop.

        Raises:
            NotImplementedError: if a slice has a step other than ``1``. The
                loop trip count does not account for the step, so such slices
                would index out of bounds.
        """
        if debug_info is None:
            # The original referenced DebugInfo without importing it (NameError).
            # NOTE(review): assumes ._sdfg exports DebugInfo - confirm.
            from ._sdfg import DebugInfo

            debug_info = DebugInfo()

        if target_name in self.array_info:
            ndim = self.array_info[target_name]["ndim"]
            if len(indices) < ndim:
                indices = list(indices) + [
                    ast.Slice(lower=None, upper=None, step=None)
                    for _ in range(ndim - len(indices))
                ]

        loop_vars = []
        new_target_indices = []

        for dim, idx in enumerate(indices):
            if not isinstance(idx, ast.Slice):
                new_target_indices.append(idx)
                continue

            loop_var = f"_slice_iter_{len(loop_vars)}_{self._get_unique_id()}"
            loop_vars.append(loop_var)
            if not self.builder.has_container(loop_var):
                self.builder.add_container(
                    loop_var, Scalar(PrimitiveType.Int64), False
                )
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            if self._is_absent(idx.lower):
                start_str = "0"
            else:
                start_str = self._wrap_negative(
                    self._parse_expr(idx.lower), target_name, dim
                )

            if self._is_absent(idx.upper):
                stop_str = self._dim_size(target_name, dim)
            else:
                stop_str = self._wrap_negative(
                    self._parse_expr(idx.upper), target_name, dim
                )

            step_str = "1"
            if not self._is_absent(idx.step):
                step_str = self._parse_expr(idx.step)
            if step_str != "1":
                raise NotImplementedError(
                    f"Slice assignment with step {step_str} is not supported"
                )

            count_str = f"({stop_str} - {start_str})"
            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)

            new_target_indices.append(
                ast.Name(id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load())
            )

        rewriter = SliceRewriter(loop_vars, self.array_info, self.expr_visitor)
        new_value = rewriter.visit(copy.deepcopy(value))

        new_target = copy.deepcopy(target)
        if len(new_target_indices) == 1:
            new_target.slice = new_target_indices[0]
        else:
            new_target.slice = ast.Tuple(elts=new_target_indices, ctx=ast.Load())

        target_str = self._parse_expr(new_target)
        value_str = self._parse_expr(new_value)
        self.builder.add_assignment(target_str, value_str, debug_info)

        for _ in loop_vars:
            self.builder.end_for()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc