• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 21746009216

06 Feb 2026 09:43AM UTC coverage: 66.359% (-0.1%) from 66.484%
21746009216

push

github

web-flow
Merge pull request #506 from daisytuner/npbench-crc16

adds crc16 npbench benchmark

53 of 130 new or added lines in 3 files covered. (40.77%)

1 existing line in 1 file now uncovered.

23114 of 34832 relevant lines covered (66.36%)

375.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.79
/python/docc/python/ast_parser.py
1
import ast
4✔
2
import copy
4✔
3
from docc.sdfg import Scalar, PrimitiveType, Pointer, TaskletCode
4✔
4
from docc.python.ast_utils import (
4✔
5
    SliceRewriter,
6
    get_debug_info,
7
    contains_ufunc_outer,
8
    normalize_negative_index,
9
)
10
from docc.python.expression_visitor import ExpressionVisitor
4✔
11
from docc.python.linear_algebra import LinearAlgebraHandler
4✔
12
from docc.python.convolution import ConvolutionHandler
4✔
13
from docc.python.onnx_ops import ONNXHandler
4✔
14

15

16
class ASTParser(ast.NodeVisitor):
4✔
17
    def __init__(
        self,
        builder,
        array_info=None,
        symbol_table=None,
        filename="",
        function_name="",
        infer_return_type=False,
        globals_dict=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Set up parser state and the helper visitor/handlers it delegates to.

        Args:
            builder: Builder that receives blocks, containers and control flow.
            array_info: Per-name array metadata (``ndim``, ``shapes``, ...);
                a fresh dict is used when ``None``.
            symbol_table: Maps container names to their types; a fresh dict is
                used when ``None``.
            filename: Source file name, passed to ``get_debug_info``.
            function_name: Function name, passed to ``get_debug_info``.
            infer_return_type: When true, the first visited ``return`` creates
                the ``_docc_ret_<i>`` output containers.
            globals_dict: Forwarded to the expression visitor.
            unique_counter_ref: One-element list used as a shared mutable
                counter for unique temporary names; ``[0]`` when ``None``.
            structure_member_info: Forwarded unchanged to the expression visitor.
        """
        # Explicit ``is None`` tests: a caller-supplied (even empty) dict/list
        # must be shared by identity, not replaced by a fresh one.
        if array_info is None:
            array_info = {}
        if symbol_table is None:
            symbol_table = {}
        if unique_counter_ref is None:
            unique_counter_ref = [0]

        self.builder = builder
        self.array_info = array_info
        self.symbol_table = symbol_table
        self.filename = filename
        self.function_name = function_name
        self.infer_return_type = infer_return_type
        self.globals_dict = globals_dict
        self._unique_counter_ref = unique_counter_ref

        self.expr_visitor = ExpressionVisitor(
            self.array_info,
            self.builder,
            self.symbol_table,
            self.globals_dict,
            unique_counter_ref=self._unique_counter_ref,
            structure_member_info=structure_member_info,
        )

        # Every statement-level handler works on the same builder, metadata
        # tables and expression visitor.
        shared_state = (
            self.builder,
            self.array_info,
            self.symbol_table,
            self.expr_visitor,
        )
        self.la_handler = LinearAlgebraHandler(*shared_state)
        self.conv_handler = ConvolutionHandler(*shared_state)
        self.onnx_handler = ONNXHandler(*shared_state)

        # Give the expression visitor access to the linear-algebra handler.
        self.expr_visitor.la_handler = self.la_handler

        # Return container name (``_docc_ret_<i>``) -> list of shape strings.
        self.captured_return_shapes = {}
58

59
    def _get_unique_id(self):
4✔
60
        self._unique_counter_ref[0] += 1
4✔
61
        return self._unique_counter_ref[0]
4✔
62

63
    def _parse_expr(self, node):
4✔
64
        return self.expr_visitor.visit(node)
4✔
65

66
    def visit_Return(self, node):
        """Lower a ``return`` statement.

        Returned value ``i`` is written into the output container
        ``_docc_ret_<i>``. When ``self.infer_return_type`` is set, any missing
        output containers are first created as arguments (scalar types wrapped
        in ``Pointer``, array/pointer types kept as-is) and the flag is cleared,
        so only the first visited return defines them. Scalars are copied with
        an assign tasklet; arrays with ``ndim > 0`` and pointer-typed values are
        copied by lowering ``_docc_ret_<i>[:, ...] = value`` via
        :meth:`visit_Assign`. A control-flow return is added at the end.
        """
        if node.value is None:
            debug_info = get_debug_info(node, self.filename, self.function_name)
            self.builder.add_return("", debug_info)
            return

        if isinstance(node.value, ast.Tuple):
            values = node.value.elts
        else:
            values = [node.value]

        parsed_values = [self._parse_expr(v) for v in values]
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if self.infer_return_type:
            for i, res in enumerate(parsed_values):
                ret_name = f"_docc_ret_{i}"
                if not self.builder.exists(ret_name):
                    dtype = Scalar(PrimitiveType.Double)
                    if res in self.symbol_table:
                        dtype = self.symbol_table[res]
                    elif isinstance(values[i], ast.Constant):
                        val = values[i].value
                        # bool must be tested before int: bool is a subclass of
                        # int, so testing int first would type True/False as Int64.
                        if isinstance(val, bool):
                            dtype = Scalar(PrimitiveType.Bool)
                        elif isinstance(val, int):
                            dtype = Scalar(PrimitiveType.Int64)
                        elif isinstance(val, float):
                            dtype = Scalar(PrimitiveType.Double)

                    # Wrap Scalar in Pointer. Keep Arrays/Pointers as is.
                    arg_type = dtype
                    if isinstance(dtype, Scalar):
                        arg_type = Pointer(dtype)

                    self.builder.add_container(ret_name, arg_type, is_argument=True)
                    self.symbol_table[ret_name] = arg_type

                    if res in self.array_info:
                        self.array_info[ret_name] = self.array_info[res]

            self.infer_return_type = False

        for i, res in enumerate(parsed_values):
            ret_name = f"_docc_ret_{i}"

            if res in self.array_info:
                # 0-d arrays (scalars) are handled by scalar assignment.
                is_array_return = self.array_info[res]["ndim"] > 0
            elif res in self.symbol_table:
                is_array_return = isinstance(self.symbol_table[res], Pointer)
            else:
                is_array_return = False

            if not is_array_return:
                # Simple scalar assignment into element 0 of the output.
                block = self.builder.add_block(debug_info)
                t_dst = self.builder.add_access(block, ret_name, debug_info)

                t_src, src_sub = self.expr_visitor._add_read(block, res, debug_info)

                t_task = self.builder.add_tasklet(
                    block, TaskletCode.assign, ["_in"], ["_out"], debug_info
                )
                self.builder.add_memlet(
                    block, t_src, "void", t_task, "_in", src_sub, None, debug_info
                )
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "0", None, debug_info
                )
                continue

            # Array assignment (copy). Record the shape for metadata first.
            if res in self.array_info:
                res_info = self.array_info[res]
                # Prefer runtime shapes if available (for indirect access
                # patterns); fall back to regular shapes otherwise.
                if "shapes_runtime" in res_info:
                    shape = res_info["shapes_runtime"]
                else:
                    shape = res_info["shapes"]
                self.captured_return_shapes[ret_name] = [str(s) for s in shape]

                # Ensure destination array info exists
                if ret_name not in self.array_info:
                    self.array_info[ret_name] = res_info

            ndim = 1
            if ret_name in self.array_info:
                ndim = self.array_info[ret_name]["ndim"]

            slice_node = ast.Slice(lower=None, upper=None, step=None)
            if ndim > 1:
                target_slice = ast.Tuple(elts=[slice_node] * ndim, ctx=ast.Load())
            else:
                target_slice = slice_node

            target_sub = ast.Subscript(
                value=ast.Name(id=ret_name, ctx=ast.Load()),
                slice=target_slice,
                ctx=ast.Store(),
            )

            # Reuse the original Name node when possible; otherwise refer to
            # the container the value was lowered into.
            if isinstance(values[i], ast.Name):
                val_node = values[i]
            else:
                val_node = ast.Name(id=res, ctx=ast.Load())

            assign_node = ast.Assign(targets=[target_sub], value=val_node)
            # Give the synthetic copy the return statement's source location so
            # get_debug_info sees a real line for it.
            ast.copy_location(assign_node, node)
            self.visit_Assign(assign_node)

        # Add control flow return to exit the function/path
        self.builder.add_return("", debug_info)
185

186
    def visit_Expr(self, node):
4✔
187
        """Handle expression statements (e.g., bare function calls)."""
188
        # Expression statements are typically function calls without assignment
189
        # Like: build_up_b(...) or pressure_poisson(...)
190
        if isinstance(node.value, ast.Call):
×
191
            # This is a function call statement - handle it through expression visitor
192
            # which will inline the function
193
            self._parse_expr(node.value)
×
194
        else:
195
            # For other expression statements, just evaluate them
196
            self._parse_expr(node.value)
×
197

198
    def visit_AugAssign(self, node):
4✔
199
        if isinstance(node.target, ast.Name) and node.target.id in self.array_info:
4✔
200
            # Convert to slice assignment: target[:] = target op value
201
            ndim = self.array_info[node.target.id]["ndim"]
4✔
202

203
            slices = []
4✔
204
            for _ in range(ndim):
4✔
205
                slices.append(ast.Slice(lower=None, upper=None, step=None))
4✔
206

207
            if ndim == 1:
4✔
208
                slice_arg = slices[0]
×
209
            else:
210
                slice_arg = ast.Tuple(elts=slices, ctx=ast.Load())
4✔
211

212
            slice_node = ast.Subscript(
4✔
213
                value=node.target, slice=slice_arg, ctx=ast.Store()
214
            )
215

216
            new_node = ast.Assign(
4✔
217
                targets=[slice_node],
218
                value=ast.BinOp(left=node.target, op=node.op, right=node.value),
219
            )
220
            self.visit_Assign(new_node)
4✔
221
        else:
222
            new_node = ast.Assign(
4✔
223
                targets=[node.target],
224
                value=ast.BinOp(left=node.target, op=node.op, right=node.value),
225
            )
226
            self.visit_Assign(new_node)
4✔
227

228
    def visit_Assign(self, node):
4✔
229
        if len(node.targets) > 1:
4✔
230
            tmp_name = f"_assign_tmp_{self._get_unique_id()}"
4✔
231
            # Assign value to temporary
232
            val_assign = ast.Assign(
4✔
233
                targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=node.value
234
            )
235
            ast.copy_location(val_assign, node)
4✔
236
            self.visit_Assign(val_assign)
4✔
237

238
            # Assign temporary to targets
239
            for target in node.targets:
4✔
240
                assign = ast.Assign(
4✔
241
                    targets=[target], value=ast.Name(id=tmp_name, ctx=ast.Load())
242
                )
243
                ast.copy_location(assign, node)
4✔
244
                self.visit_Assign(assign)
4✔
245
            return
4✔
246

247
        target = node.targets[0]
4✔
248

249
        # Handle tuple unpacking: I, J, K = expr1, expr2, expr3
250
        if isinstance(target, ast.Tuple):
4✔
251
            if isinstance(node.value, ast.Tuple):
4✔
252
                # Unpacking tuple to tuple: a, b, c = x, y, z
253
                if len(target.elts) != len(node.value.elts):
4✔
254
                    raise ValueError("Tuple unpacking size mismatch")
×
255
                for tgt, val in zip(target.elts, node.value.elts):
4✔
256
                    assign = ast.Assign(targets=[tgt], value=val)
4✔
257
                    ast.copy_location(assign, node)
4✔
258
                    self.visit_Assign(assign)
4✔
259
            else:
260
                raise NotImplementedError(
×
261
                    "Tuple unpacking from non-tuple values not supported"
262
                )
263
            return
4✔
264

265
        # Special case: linear algebra functions
266
        if self.la_handler.is_gemm(node.value):
4✔
267
            if self.la_handler.handle_gemm(target, node.value):
4✔
268
                return
×
269
            if self.la_handler.handle_dot(target, node.value):
4✔
270
                return
×
271

272
        # Special case: outer product
273
        if self.la_handler.is_outer(node.value):
4✔
274
            if self.la_handler.handle_outer(target, node.value):
4✔
275
                return
4✔
276

277
        # Special case: convolution
278
        if self.conv_handler.is_conv(node.value):
4✔
279
            if self.conv_handler.handle_conv(target, node.value):
4✔
280
                return
4✔
281

282
        # Special case: ONNX ops (Transpose)
283
        if self.onnx_handler.is_transpose(node.value):
4✔
284
            if self.onnx_handler.handle_transpose(target, node.value):
4✔
285
                return
4✔
286

287
        # Special case:
288
        if isinstance(target, ast.Subscript):
4✔
289
            target_name = self.expr_visitor.visit(target.value)
4✔
290

291
            indices = []
4✔
292
            if isinstance(target.slice, ast.Tuple):
4✔
293
                indices = target.slice.elts
4✔
294
            else:
295
                indices = [target.slice]
4✔
296

297
            has_slice = False
4✔
298
            for idx in indices:
4✔
299
                if isinstance(idx, ast.Slice):
4✔
300
                    has_slice = True
4✔
301
                    break
4✔
302

303
            if has_slice:
4✔
304
                debug_info = get_debug_info(node, self.filename, self.function_name)
4✔
305
                self._handle_slice_assignment(
4✔
306
                    target, node.value, target_name, indices, debug_info
307
                )
308
                return
4✔
309

310
            target_name_full = self._parse_expr(target)
4✔
311
            value_str = self._parse_expr(node.value)
4✔
312
            debug_info = get_debug_info(node, self.filename, self.function_name)
4✔
313

314
            block = self.builder.add_block(debug_info)
4✔
315
            t_src, src_sub = self.expr_visitor._add_read(block, value_str, debug_info)
4✔
316

317
            if "(" in target_name_full and target_name_full.endswith(")"):
4✔
318
                name = target_name_full.split("(")[0]
4✔
319
                subset = target_name_full[target_name_full.find("(") + 1 : -1]
4✔
320
                t_dst = self.builder.add_access(block, name, debug_info)
4✔
321
                dst_sub = subset
4✔
322
            else:
323
                t_dst = self.builder.add_access(block, target_name_full, debug_info)
×
324
                dst_sub = ""
×
325

326
            t_task = self.builder.add_tasklet(
4✔
327
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
328
            )
329

330
            self.builder.add_memlet(
4✔
331
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
332
            )
333
            self.builder.add_memlet(
4✔
334
                block, t_task, "_out", t_dst, "void", dst_sub, None, debug_info
335
            )
336
            return
4✔
337

338
        # Variable assignments
339
        if not isinstance(target, ast.Name):
4✔
340
            raise NotImplementedError("Only assignment to variables supported")
×
341

342
        target_name = target.id
4✔
343
        value_str = self._parse_expr(node.value)
4✔
344
        debug_info = get_debug_info(node, self.filename, self.function_name)
4✔
345

346
        if not self.builder.exists(target_name):
4✔
347
            if isinstance(node.value, ast.Constant):
4✔
348
                val = node.value.value
4✔
349
                if isinstance(val, int):
4✔
350
                    dtype = Scalar(PrimitiveType.Int64)
4✔
351
                elif isinstance(val, float):
4✔
352
                    dtype = Scalar(PrimitiveType.Double)
4✔
353
                elif isinstance(val, bool):
×
354
                    dtype = Scalar(PrimitiveType.Bool)
×
355
                else:
356
                    raise NotImplementedError(f"Cannot infer type for {val}")
×
357

358
                self.builder.add_container(target_name, dtype, False)
4✔
359
                self.symbol_table[target_name] = dtype
4✔
360
            else:
361
                assert value_str in self.symbol_table
4✔
362
                self.builder.add_container(
4✔
363
                    target_name, self.symbol_table[value_str], False
364
                )
365
                self.symbol_table[target_name] = self.symbol_table[value_str]
4✔
366

367
        if value_str in self.array_info:
4✔
368
            self.array_info[target_name] = self.array_info[value_str]
4✔
369

370
        # Distinguish assignments: scalar -> tasklet, pointer -> reference_memlet
371
        src_type = self.symbol_table.get(value_str)
4✔
372
        dst_type = self.symbol_table[target_name]
4✔
373
        if src_type and isinstance(src_type, Pointer) and isinstance(dst_type, Pointer):
4✔
374
            block = self.builder.add_block(debug_info)
4✔
375
            t_src = self.builder.add_access(block, value_str, debug_info)
4✔
376
            t_dst = self.builder.add_access(block, target_name, debug_info)
4✔
377
            self.builder.add_reference_memlet(
4✔
378
                block, t_src, t_dst, "0", src_type, debug_info
379
            )
380
            return
4✔
381
        elif (src_type and isinstance(src_type, Scalar)) or isinstance(
4✔
382
            dst_type, Scalar
383
        ):
384
            block = self.builder.add_block(debug_info)
4✔
385
            t_dst = self.builder.add_access(block, target_name, debug_info)
4✔
386
            t_task = self.builder.add_tasklet(
4✔
387
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
388
            )
389

390
            if src_type:
4✔
391
                t_src = self.builder.add_access(block, value_str, debug_info)
4✔
392
            else:
393
                t_src = self.builder.add_constant(
4✔
394
                    block, value_str, dst_type, debug_info
395
                )
396

397
            self.builder.add_memlet(
4✔
398
                block, t_src, "void", t_task, "_in", "", None, debug_info
399
            )
400
            self.builder.add_memlet(
4✔
401
                block, t_task, "_out", t_dst, "void", "", None, debug_info
402
            )
403

404
            return
4✔
405

406
    def visit_If(self, node):
        """Lower an ``if``/``else`` statement into a builder if-region.

        The test is lowered once and the region is entered when the result
        compares unequal to ``false``. ``begin_else`` is emitted only when the
        statement has an ``orelse`` body (an ``elif`` appears there as a nested
        ``If`` node).
        """
        condition = self._parse_expr(node.test)
        info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_if(f"{condition} != false", info)

        for statement in node.body:
            self.visit(statement)

        else_body = node.orelse
        if else_body:
            self.builder.begin_else(info)
            for statement in else_body:
                self.visit(statement)

        self.builder.end_if()
420

421
    def visit_While(self, node):
        """Lower a ``while`` loop into a builder while-region.

        The loop runs while the lowered test compares unequal to ``false``.
        NOTE(review): the test is lowered once, before ``begin_while``; if the
        expression visitor emits computation blocks for it, those are not
        re-executed per iteration. Confirm that conditions lower to pure
        expression strings.
        """
        loop_condition = self._parse_expr(node.test)
        info = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_while(f"{loop_condition} != false", info)

        for statement in node.body:
            self.visit(statement)

        self.builder.end_while()
430

431
    def visit_For(self, node):
4✔
432
        if not isinstance(node.target, ast.Name):
4✔
433
            raise NotImplementedError("Only simple for loops supported")
×
434

435
        var = node.target.id
4✔
436
        debug_info = get_debug_info(node, self.filename, self.function_name)
4✔
437

438
        # Check if iterating over a range() call
439
        if (
4✔
440
            isinstance(node.iter, ast.Call)
441
            and isinstance(node.iter.func, ast.Name)
442
            and node.iter.func.id == "range"
443
        ):
444
            args = node.iter.args
4✔
445
            if len(args) == 1:
4✔
446
                start = "0"
4✔
447
                end = self._parse_expr(args[0])
4✔
448
                step = "1"
4✔
449
            elif len(args) == 2:
4✔
450
                start = self._parse_expr(args[0])
4✔
451
                end = self._parse_expr(args[1])
4✔
452
                step = "1"
4✔
453
            elif len(args) == 3:
4✔
454
                start = self._parse_expr(args[0])
4✔
455
                end = self._parse_expr(args[1])
4✔
456

457
                # Special handling for step to avoid creating tasklets for constants
458
                step_node = args[2]
4✔
459
                if isinstance(step_node, ast.Constant):
4✔
460
                    step = str(step_node.value)
4✔
461
                elif (
4✔
462
                    isinstance(step_node, ast.UnaryOp)
463
                    and isinstance(step_node.op, ast.USub)
464
                    and isinstance(step_node.operand, ast.Constant)
465
                ):
466
                    step = f"-{step_node.operand.value}"
4✔
467
                else:
NEW
468
                    step = self._parse_expr(step_node)
×
469
            else:
NEW
470
                raise ValueError("Invalid range arguments")
×
471

472
            if not self.builder.exists(var):
4✔
473
                self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
4✔
474
                self.symbol_table[var] = Scalar(PrimitiveType.Int64)
4✔
475

476
            self.builder.begin_for(var, start, end, step, debug_info)
4✔
477

478
            for stmt in node.body:
4✔
479
                self.visit(stmt)
4✔
480

481
            self.builder.end_for()
4✔
482
            return
4✔
483

484
        # Check if iterating over an ndarray (for x in array)
NEW
485
        if isinstance(node.iter, ast.Name):
×
NEW
486
            iter_name = node.iter.id
×
NEW
487
            if iter_name in self.array_info:
×
NEW
488
                arr_info = self.array_info[iter_name]
×
NEW
489
                if arr_info["ndim"] < 1:
×
NEW
490
                    raise NotImplementedError("Cannot iterate over 0-dimensional array")
×
491

492
                # Get the size of the first dimension
NEW
493
                arr_size = arr_info["shapes"][0]
×
494

495
                # Create a hidden index variable for the loop
NEW
496
                idx_var = f"_iter_idx_{self._get_unique_id()}"
×
NEW
497
                if not self.builder.exists(idx_var):
×
NEW
498
                    self.builder.add_container(
×
499
                        idx_var, Scalar(PrimitiveType.Int64), False
500
                    )
NEW
501
                    self.symbol_table[idx_var] = Scalar(PrimitiveType.Int64)
×
502

503
                # Determine the type of the loop variable (element type)
504
                # For a 1D array, it's a scalar; for ND array, it's a view of N-1 dimensions
NEW
505
                if arr_info["ndim"] == 1:
×
506
                    # Element is a scalar - get the element type from the array's type
NEW
507
                    arr_type = self.symbol_table.get(iter_name)
×
NEW
508
                    if isinstance(arr_type, Pointer):
×
NEW
509
                        elem_type = arr_type.pointee_type
×
510
                    else:
NEW
511
                        elem_type = Scalar(PrimitiveType.Double)  # Default fallback
×
512

NEW
513
                    if not self.builder.exists(var):
×
NEW
514
                        self.builder.add_container(var, elem_type, False)
×
NEW
515
                        self.symbol_table[var] = elem_type
×
516
                else:
517
                    # For multi-dimensional arrays, create a view/slice
518
                    # The loop variable becomes a pointer to the sub-array
NEW
519
                    inner_shapes = arr_info["shapes"][1:]
×
NEW
520
                    inner_ndim = arr_info["ndim"] - 1
×
521

NEW
522
                    arr_type = self.symbol_table.get(iter_name)
×
NEW
523
                    if isinstance(arr_type, Pointer):
×
NEW
524
                        elem_type = arr_type  # Keep as pointer type for views
×
525
                    else:
NEW
526
                        elem_type = Pointer(Scalar(PrimitiveType.Double))
×
527

NEW
528
                    if not self.builder.exists(var):
×
NEW
529
                        self.builder.add_container(var, elem_type, False)
×
NEW
530
                        self.symbol_table[var] = elem_type
×
531

532
                    # Register the view in array_info
NEW
533
                    self.array_info[var] = {"ndim": inner_ndim, "shapes": inner_shapes}
×
534

535
                # Begin the for loop
NEW
536
                self.builder.begin_for(idx_var, "0", str(arr_size), "1", debug_info)
×
537

538
                # Generate the assignment: var = array[idx_var]
539
                # Create an AST node for the assignment and visit it
NEW
540
                assign_node = ast.Assign(
×
541
                    targets=[ast.Name(id=var, ctx=ast.Store())],
542
                    value=ast.Subscript(
543
                        value=ast.Name(id=iter_name, ctx=ast.Load()),
544
                        slice=ast.Name(id=idx_var, ctx=ast.Load()),
545
                        ctx=ast.Load(),
546
                    ),
547
                )
NEW
548
                ast.copy_location(assign_node, node)
×
NEW
549
                self.visit_Assign(assign_node)
×
550

551
                # Visit the loop body
NEW
552
                for stmt in node.body:
×
NEW
553
                    self.visit(stmt)
×
554

NEW
555
                self.builder.end_for()
×
NEW
556
                return
×
557

NEW
558
        raise NotImplementedError(
×
559
            f"Only range() loops and iteration over ndarrays supported, got: {ast.dump(node.iter)}"
560
        )
561

562
    def _get_max_array_ndim_in_expr(self, node):
4✔
563
        """Get the maximum array dimensionality in an expression."""
564
        max_ndim = 0
4✔
565

566
        class NdimVisitor(ast.NodeVisitor):
4✔
567
            def __init__(self, array_info):
4✔
568
                self.array_info = array_info
4✔
569
                self.max_ndim = 0
4✔
570

571
            def visit_Name(self, node):
4✔
572
                if node.id in self.array_info:
4✔
573
                    ndim = self.array_info[node.id].get("ndim", 0)
4✔
574
                    self.max_ndim = max(self.max_ndim, ndim)
4✔
575
                return self.generic_visit(node)
4✔
576

577
        visitor = NdimVisitor(self.array_info)
4✔
578
        visitor.visit(node)
4✔
579
        return visitor.max_ndim
4✔
580

581
    def _handle_broadcast_slice_assignment(
        self, target, value, target_name, indices, target_ndim, value_ndim, debug_info
    ):
        """Handle slice assignment with broadcasting (e.g. ``2D -= 1D``).

        The leading ``target_ndim - value_ndim`` dimensions of *target_name*
        become explicit outer loops. Inside them, a fresh pointer container
        ``_row_view_<id>`` is bound by reference memlet to the row-major
        linearised start of the current trailing sub-array. The now
        same-rank assignment ``row_view[:, ...] = value`` is then lowered via
        :meth:`_handle_slice_assignment`.

        Args:
            target: Original subscript target (not used directly).
            value: Right-hand-side expression AST.
            target_name: Container being assigned. Its ``array_info`` shapes
                give the loop bounds; missing ones fall back to the symbolic
                name ``_<target_name>_shape_<i>``.
            indices: Original target indices. NOTE(review): these are not
                consulted, so every dimension appears to be treated as a full
                ``:`` slice (e.g. ``A[1:3, :] -= v`` would cover all rows).
            target_ndim: Number of sliced target dimensions.
            value_ndim: Maximum array rank appearing in *value*.
            debug_info: Debug info attached to every generated element.
        """
        # Number of broadcast dimensions (outer loops)
        broadcast_dims = target_ndim - value_ndim

        shapes = self.array_info[target_name].get("shapes", [])

        def size_of_dim(dim_idx):
            # Shapes may hold non-string sizes; render them as strings for the
            # builder API (as visit_For does with str(arr_size)).
            if dim_idx < len(shapes):
                return str(shapes[dim_idx])
            return f"_{target_name}_shape_{dim_idx}"

        # Create outer loops for broadcast dimensions
        outer_loop_vars = []
        for i in range(broadcast_dims):
            loop_var = f"_bcast_iter_{i}_{self._get_unique_id()}"
            outer_loop_vars.append(loop_var)

            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(loop_var, "0", size_of_dim(i), "1", debug_info)

        # Create a row view (reference) for the inner dimensions
        row_view_name = f"_row_view_{self._get_unique_id()}"

        # Get inner shape for the row view
        inner_shapes = shapes[broadcast_dims:] if len(shapes) > broadcast_dims else []

        # Determine element type from the target
        target_type = self.symbol_table.get(target_name)
        if isinstance(target_type, Pointer) and target_type.has_pointee_type():
            element_type = target_type.pointee_type
        else:
            element_type = Scalar(PrimitiveType.Double)

        # Create pointer type for row view
        row_type = Pointer(element_type)
        self.builder.add_container(row_view_name, row_type, False)
        self.symbol_table[row_view_name] = row_type

        # Register row view in array_info
        self.array_info[row_view_name] = {"ndim": value_ndim, "shapes": inner_shapes}

        # Row-major offset of the current row:
        # ((i0 * n1 + i1) * n2 + ...) * (product of the inner dimension sizes)
        linear_idx = outer_loop_vars[0] if outer_loop_vars else "0"
        for dim_idx in range(1, broadcast_dims):
            linear_idx = (
                f"({linear_idx}) * ({size_of_dim(dim_idx)}) + {outer_loop_vars[dim_idx]}"
            )
        for dim_idx in range(broadcast_dims, target_ndim):
            linear_idx = f"({linear_idx}) * ({size_of_dim(dim_idx)})"

        # Bind the row view: row_view = &target[linear_idx]
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, target_name, debug_info)
        t_dst = self.builder.add_access(block, row_view_name, debug_info)
        self.builder.add_reference_memlet(
            block, t_src, t_dst, linear_idx, row_type, debug_info
        )

        # Inner indices: full slices over every remaining dimension
        inner_indices = [
            ast.Slice(lower=None, upper=None, step=None) for _ in range(value_ndim)
        ]

        # Create new target using row view
        new_target = ast.Subscript(
            value=ast.Name(id=row_view_name, ctx=ast.Load()),
            slice=(
                ast.Tuple(elts=inner_indices, ctx=ast.Load())
                if len(inner_indices) > 1
                else inner_indices[0]
            ),
            ctx=ast.Store(),
        )

        # Recursively handle the inner assignment (now same-dimension)
        self._handle_slice_assignment(
            new_target, value, row_view_name, inner_indices, debug_info
        )

        # Close outer loops
        for _ in outer_loop_vars:
            self.builder.end_for()
683

684
    def _handle_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Lower a slice assignment ``target[indices] = value`` into nested loops.

        Every ``ast.Slice`` in ``indices`` opens one zero-based ``for`` loop whose
        trip count is ``stop - start``. The target element is addressed as
        ``start + loop_var * step``. Non-slice (point) indices are kept after
        negative-index normalization. The RHS is rewritten by ``SliceRewriter``
        in terms of the generated loop variables and emitted as a single
        assignment in the innermost loop. All loops are closed afterwards.

        Before any loop is opened, two cases are delegated:

        * the RHS contains a ufunc ``outer`` call
          (``_handle_ufunc_outer_slice_assignment``);
        * the target has more slice dimensions than the highest-rank array on
          the RHS (``_handle_broadcast_slice_assignment``).

        Args:
            target: The ``ast.Subscript`` being assigned to.
            value: AST of the right-hand side.
            target_name: Name of the container being written.
            indices: Index expressions of ``target``. When ``target_name`` is in
                ``self.array_info``, missing trailing dimensions are padded with
                full slices up to the recorded ``ndim``.
            debug_info: Debug info attached to emitted nodes. A fresh
                ``DebugInfo`` is created when omitted.
        """
        if debug_info is None:
            # DebugInfo is not among this module's top-level imports, so it is
            # imported locally here.
            from docc.sdfg import DebugInfo

            debug_info = DebugInfo()

        if target_name in self.array_info:
            ndim = self.array_info[target_name]["ndim"]
            if len(indices) < ndim:
                indices = list(indices)
                for _ in range(ndim - len(indices)):
                    indices.append(ast.Slice(lower=None, upper=None, step=None))

        # Check if the RHS contains a ufunc outer operation.
        # If so, handle it specially to avoid the loop transformation,
        # which would destroy the slice shape information.
        has_outer, _, _ = contains_ufunc_outer(value)
        if has_outer:
            self._handle_ufunc_outer_slice_assignment(
                target, value, target_name, indices, debug_info
            )
            return

        # Count slice dimensions to determine effective target dimensionality
        # (slice indices produce array dimensions, point indices collapse them).
        target_slice_ndim = sum(1 for idx in indices if isinstance(idx, ast.Slice))
        value_max_ndim = self._get_max_array_ndim_in_expr(value)

        if (
            target_slice_ndim > 0
            and value_max_ndim > 0
            and target_slice_ndim > value_max_ndim
        ):
            # Broadcasting case: use row-by-row approach with reference memlets.
            self._handle_broadcast_slice_assignment(
                target,
                value,
                target_name,
                indices,
                target_slice_ndim,
                value_max_ndim,
                debug_info,
            )
            return

        # Use .get() so an unknown target falls back to the symbolic
        # "_<name>_shape_<i>" sizes instead of raising KeyError.
        shapes = self.array_info.get(target_name, {}).get("shapes", [])

        def dim_size_str(dim):
            # Size of target dimension `dim` as an expression string.
            if dim < len(shapes):
                return str(shapes[dim])
            return f"_{target_name}_shape_{dim}"

        def is_absent(node):
            # `a[None:]`, `a[:None]` and `a[::None]` mean the same as omitting
            # the bound.
            return node is None or (
                isinstance(node, ast.Constant) and node.value is None
            )

        def from_end(bound_str, dim):
            # A bound that parses to a negative expression counts from the end
            # of the dimension. An explicit "+" keeps the parenthesized form
            # "(-1)" well-formed ("(N + (-1))" rather than "(N (-1))").
            if bound_str.startswith("-") or bound_str.startswith("(-"):
                return f"({dim_size_str(dim)} + {bound_str})"
            return bound_str

        loop_vars = []
        new_target_indices = []

        for i, idx in enumerate(indices):
            if not isinstance(idx, ast.Slice):
                # Point index: normalize negative indices against the dim size.
                dim_size = shapes[i] if i < len(shapes) else f"_{target_name}_shape_{i}"
                new_target_indices.append(normalize_negative_index(idx, dim_size))
                continue

            loop_var = f"_slice_iter_{len(loop_vars)}_{self._get_unique_id()}"
            loop_vars.append(loop_var)

            if not self.builder.exists(loop_var):
                self.builder.add_container(
                    loop_var, Scalar(PrimitiveType.Int64), False
                )
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            start_str = "0"
            if not is_absent(idx.lower):
                start_str = from_end(self._parse_expr(idx.lower), i)

            if not is_absent(idx.upper):
                stop_str = from_end(self._parse_expr(idx.upper), i)
            else:
                stop_str = dim_size_str(i)

            step_str = "1"
            if not is_absent(idx.step):
                step_str = self._parse_expr(idx.step)

            # NOTE(review): the trip count ignores the step. For step != 1 the
            # loop runs (stop - start) times, so `start + k * step` overshoots
            # `stop`. A fix needs ceiling division in the builder's expression
            # language — TODO confirm the right operator there.
            count_str = f"({stop_str} - {start_str})"

            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            new_target_indices.append(
                ast.Name(id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load())
            )

        # Rewrite RHS slices in terms of the loop variables opened above.
        rewriter = SliceRewriter(loop_vars, self.array_info, self.expr_visitor)
        new_value = rewriter.visit(copy.deepcopy(value))

        new_target = copy.deepcopy(target)
        if len(new_target_indices) == 1:
            new_target.slice = new_target_indices[0]
        else:
            new_target.slice = ast.Tuple(elts=new_target_indices, ctx=ast.Load())

        target_str = self._parse_expr(new_target)
        value_str = self._parse_expr(new_value)
        self.builder.add_assignment(target_str, value_str, debug_info)

        for _ in loop_vars:
            self.builder.end_for()
812

813
    def _handle_ufunc_outer_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Handle slice assignment where RHS contains a ufunc outer operation.

        Example: path[:] = np.minimum(path[:], np.add.outer(path[:, k], path[k, :]))

        The strategy is:
        1. Evaluate the entire RHS expression, which will create a temporary array
           containing the result of the ufunc outer (potentially wrapped in other ops)
        2. Copy the temporary result to the target slice

        This avoids the loop transformation that would destroy slice shape info.

        The temporary is treated as a flat row-major buffer. Its extents are the
        trip counts of the copy loops, one per slice in ``indices``. The target
        is addressed row-major using its own dimension sizes.

        Args:
            target: The ``ast.Subscript`` being assigned to.
            value: AST of the right-hand side (contains a ufunc ``outer``).
            target_name: Name of the container being written.
            indices: Index expressions of ``target``, one per target dimension.
            debug_info: Debug info attached to emitted nodes. A fresh
                ``DebugInfo`` is created when omitted.
        """
        if debug_info is None:
            from docc.sdfg import DebugInfo

            debug_info = DebugInfo()

        # Evaluate the full RHS expression
        # This will:
        # - Create temp arrays for ufunc outer results
        # - Apply any wrapping operations (np.minimum, etc.)
        # - Return the name of the final result array
        result_name = self._parse_expr(value)

        # Now we need to copy result to target slice
        # Count slice dimensions to determine if we need loops
        target_slice_ndim = sum(1 for idx in indices if isinstance(idx, ast.Slice))

        if target_slice_ndim == 0:
            # No slices on target - just simple assignment
            target_str = self._parse_expr(target)
            block = self.builder.add_block(debug_info)
            t_src, src_sub = self.expr_visitor._add_read(block, result_name, debug_info)
            t_dst = self.builder.add_access(block, target_str, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )
            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )
            return

        # We have slices on the target - need to create loops for copying
        target_info = self.array_info.get(target_name, {})
        target_shapes = target_info.get("shapes", [])

        def dim_size(dim):
            # Size of target dimension `dim` (recorded shape or symbolic name).
            if dim < len(target_shapes):
                return target_shapes[dim]
            return f"_{target_name}_shape_{dim}"

        def is_absent(node):
            # `a[None:]`, `a[:None]` and `a[::None]` mean the same as omitting
            # the bound.
            return node is None or (
                isinstance(node, ast.Constant) and node.value is None
            )

        def from_end(bound_str, dim):
            # Negative bounds count from the end of the dimension. This is the
            # same normalization _handle_slice_assignment applies.
            if bound_str.startswith("-") or bound_str.startswith("(-"):
                return f"({dim_size(dim)} + {bound_str})"
            return bound_str

        def row_major(parts, extents):
            # Horner-form row-major linear index: ((p0 * e1 + p1) * e2 + p2)...
            # extents[0] is never needed. For one part, the part is returned
            # unchanged.
            linear = None
            for part, extent in zip(parts, extents):
                if linear is None:
                    linear = f"{part}"
                else:
                    linear = f"(({linear}) * ({extent}) + ({part}))"
            return "0" if linear is None else linear

        loop_vars = []
        # Trip count of each copy loop. These are the extents of the flat
        # temporary, which has the shape of the target slice, not of the whole
        # target array.
        loop_counts = []
        new_target_indices = []

        for i, idx in enumerate(indices):
            if isinstance(idx, ast.Slice):
                loop_var = f"_copy_iter_{len(loop_vars)}_{self._get_unique_id()}"
                loop_vars.append(loop_var)

                if not self.builder.exists(loop_var):
                    self.builder.add_container(
                        loop_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

                start_str = "0"
                if not is_absent(idx.lower):
                    start_str = from_end(self._parse_expr(idx.lower), i)

                if not is_absent(idx.upper):
                    stop_str = from_end(self._parse_expr(idx.upper), i)
                else:
                    stop_str = dim_size(i)

                step_str = "1"
                if not is_absent(idx.step):
                    step_str = self._parse_expr(idx.step)

                # NOTE(review): as in _handle_slice_assignment, the trip count
                # ignores the step — TODO confirm ceiling-division syntax.
                count_str = f"({stop_str} - {start_str})"
                loop_counts.append(count_str)

                self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

                new_target_indices.append(
                    ast.Name(
                        id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load()
                    )
                )
            else:
                # Handle non-slice indices - need to normalize negative indices
                normalized_idx = normalize_negative_index(idx, dim_size(i))
                new_target_indices.append(normalized_idx)

        # Create assignment block: target[...] = result[...]
        block = self.builder.add_block(debug_info)

        # Access nodes
        t_src = self.builder.add_access(block, result_name, debug_info)
        t_dst = self.builder.add_access(block, target_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        # Source index: the result is flat, so linearize the loop variables over
        # the loop trip counts. Using the target's shapes here would be wrong
        # when a point index precedes a slice or a slice is not full-width.
        src_index = row_major(loop_vars, loop_counts)

        # Target index: row-major over the target's own dimensions
        # (2D with shape (M, N): i * N + j).
        target_index_parts = []
        for idx in new_target_indices:
            if isinstance(idx, ast.Name):
                target_index_parts.append(idx.id)
            else:
                target_index_parts.append(self._parse_expr(idx))
        target_index = row_major(
            target_index_parts,
            [dim_size(d) for d in range(len(target_index_parts))],
        )

        # Connect memlets
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_index, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", target_index, None, debug_info
        )

        # End loops
        for _ in loop_vars:
            self.builder.end_for()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc