• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21478974613

29 Jan 2026 12:55PM UTC coverage: 65.778% (-0.07%) from 65.843%
21478974613

push

github

web-flow
Merge pull request #485 from daisytuner/npbench-cavity-flow

Adds support for npbench's cavity_flow

59 of 130 new or added lines in 6 files covered. (45.38%)

1 existing line in 1 file now uncovered.

22446 of 34124 relevant lines covered (65.78%)

382.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

75.05
/python/docc/ast_parser.py
1
import ast
4✔
2
import copy
4✔
3
from ._sdfg import Scalar, PrimitiveType, Pointer, TaskletCode
4✔
4
from .ast_utils import (
4✔
5
    SliceRewriter,
6
    get_debug_info,
7
    contains_ufunc_outer,
8
    normalize_negative_index,
9
)
10
from .expression_visitor import ExpressionVisitor
4✔
11
from .linear_algebra import LinearAlgebraHandler
4✔
12
from .convolution import ConvolutionHandler
4✔
13
from .onnx_ops import ONNXHandler
4✔
14

15

16
class ASTParser(ast.NodeVisitor):
4✔
17
    def __init__(
        self,
        builder,
        array_info=None,
        symbol_table=None,
        filename="",
        function_name="",
        infer_return_type=False,
        globals_dict=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Store the parser state and create the helper visitor/handlers.

        Args:
            builder: Object that receives containers, blocks and control flow.
            array_info: Dict of per-array metadata; a new dict when ``None``.
            symbol_table: Dict mapping names to types; a new dict when ``None``.
            filename: Forwarded to ``get_debug_info`` by the visit methods.
            function_name: Forwarded to ``get_debug_info`` by the visit methods.
            infer_return_type: Whether ``visit_Return`` should declare the
                ``_docc_ret_<i>`` containers itself.
            globals_dict: Forwarded unchanged to ``ExpressionVisitor``.
            unique_counter_ref: One-element list used as a shared counter;
                ``[0]`` when ``None``.
            structure_member_info: Forwarded unchanged to ``ExpressionVisitor``.
        """
        # Replace missing optional arguments with fresh mutable defaults.
        if array_info is None:
            array_info = {}
        if symbol_table is None:
            symbol_table = {}
        if unique_counter_ref is None:
            unique_counter_ref = [0]

        self.builder = builder
        self.array_info = array_info
        self.symbol_table = symbol_table
        self.filename = filename
        self.function_name = function_name
        self.infer_return_type = infer_return_type
        self.globals_dict = globals_dict
        self._unique_counter_ref = unique_counter_ref

        self.expr_visitor = ExpressionVisitor(
            self.array_info,
            self.builder,
            self.symbol_table,
            self.globals_dict,
            unique_counter_ref=self._unique_counter_ref,
            structure_member_info=structure_member_info,
        )

        # Every handler takes the same four collaborators, in the same order.
        handler_args = (
            self.builder,
            self.array_info,
            self.symbol_table,
            self.expr_visitor,
        )
        self.la_handler = LinearAlgebraHandler(*handler_args)
        self.conv_handler = ConvolutionHandler(*handler_args)
        self.onnx_handler = ONNXHandler(*handler_args)

        # The expression visitor gets a back-reference to the linear algebra handler.
        self.expr_visitor.la_handler = self.la_handler
        # Map param name to shape string list
        self.captured_return_shapes = {}
58

59
    def _get_unique_id(self):
4✔
60
        self._unique_counter_ref[0] += 1
4✔
61
        return self._unique_counter_ref[0]
4✔
62

63
    def _parse_expr(self, node):
4✔
64
        return self.expr_visitor.visit(node)
4✔
65

66
    def visit_Return(self, node):
        """Lower a ``return`` statement into writes to ``_docc_ret_<i>`` containers.

        A bare ``return`` only emits ``builder.add_return``. Otherwise every
        returned expression (each element of a tuple, or the single value) is
        parsed and written to the container ``_docc_ret_<i>``:

        * While ``self.infer_return_type`` is set, missing ``_docc_ret_<i>``
          containers are declared as arguments. The type comes from the symbol
          table, else from a constant's Python type, else defaults to Double.
          Scalars are wrapped in a ``Pointer``. The flag is cleared afterwards.
        * Scalar results are assigned through an ``assign`` tasklet.
        * Array results are copied by synthesizing ``_docc_ret_<i>[:] = value``
          and dispatching it to ``visit_Assign``; the shape is recorded in
          ``self.captured_return_shapes``.

        A final ``builder.add_return`` is always emitted.
        """
        if node.value is None:
            debug_info = get_debug_info(node, self.filename, self.function_name)
            self.builder.add_return("", debug_info)
            return

        if isinstance(node.value, ast.Tuple):
            values = node.value.elts
        else:
            values = [node.value]

        parsed_values = [self._parse_expr(v) for v in values]
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if self.infer_return_type:
            for i, res in enumerate(parsed_values):
                ret_name = f"_docc_ret_{i}"
                if not self.builder.exists(ret_name):
                    dtype = Scalar(PrimitiveType.Double)
                    if res in self.symbol_table:
                        dtype = self.symbol_table[res]
                    elif isinstance(values[i], ast.Constant):
                        val = values[i].value
                        # bool is a subclass of int, so it must be tested
                        # first; otherwise True/False are typed as Int64.
                        if isinstance(val, bool):
                            dtype = Scalar(PrimitiveType.Bool)
                        elif isinstance(val, int):
                            dtype = Scalar(PrimitiveType.Int64)
                        elif isinstance(val, float):
                            dtype = Scalar(PrimitiveType.Double)

                    # Wrap Scalar in Pointer. Keep Arrays/Pointers as is.
                    arg_type = dtype
                    if isinstance(dtype, Scalar):
                        arg_type = Pointer(dtype)

                    self.builder.add_container(ret_name, arg_type, is_argument=True)
                    self.symbol_table[ret_name] = arg_type

                    if res in self.array_info:
                        self.array_info[ret_name] = self.array_info[res]

            # Only the first return statement declares the containers.
            self.infer_return_type = False

        for i, res in enumerate(parsed_values):
            ret_name = f"_docc_ret_{i}"

            is_array_return = False
            if res in self.array_info:
                # Only treat as array return if it has dimensions
                # 0-d arrays (scalars) should be handled by scalar assignment
                if self.array_info[res]["ndim"] > 0:
                    is_array_return = True
            elif res in self.symbol_table:
                if isinstance(self.symbol_table[res], Pointer):
                    is_array_return = True

            # Simple Scalar Assignment
            if not is_array_return:
                block = self.builder.add_block(debug_info)
                t_dst = self.builder.add_access(block, ret_name, debug_info)

                t_src, src_sub = self.expr_visitor._add_read(block, res, debug_info)

                t_task = self.builder.add_tasklet(
                    block, TaskletCode.assign, ["_in"], ["_out"], debug_info
                )
                self.builder.add_memlet(
                    block, t_src, "void", t_task, "_in", src_sub, None, debug_info
                )
                # The return container is written at element "0".
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "0", None, debug_info
                )

            # Array Assignment (Copy)
            else:
                # Record shape for metadata
                if res in self.array_info:
                    # Prefer runtime shapes if available (for indirect access patterns)
                    # Fall back to regular shapes otherwise
                    if "shapes_runtime" in self.array_info[res]:
                        shape = self.array_info[res]["shapes_runtime"]
                    else:
                        shape = self.array_info[res]["shapes"]
                    # Convert to string expressions
                    self.captured_return_shapes[ret_name] = [str(s) for s in shape]

                    # Ensure destination array info exists
                    if ret_name not in self.array_info:
                        self.array_info[ret_name] = self.array_info[res]

                # Copy Logic using visit_Assign
                ndim = 1
                if ret_name in self.array_info:
                    ndim = self.array_info[ret_name]["ndim"]

                slice_node = ast.Slice(lower=None, upper=None, step=None)
                if ndim > 1:
                    target_slice = ast.Tuple(elts=[slice_node] * ndim, ctx=ast.Load())
                else:
                    target_slice = slice_node

                target_sub = ast.Subscript(
                    value=ast.Name(id=ret_name, ctx=ast.Load()),
                    slice=target_slice,
                    ctx=ast.Store(),
                )

                # Value node reconstruction
                if isinstance(values[i], ast.Name):
                    val_node = values[i]
                else:
                    # The expression was already lowered; refer to its result by name.
                    val_node = ast.Name(id=res, ctx=ast.Load())

                assign_node = ast.Assign(targets=[target_sub], value=val_node)
                self.visit_Assign(assign_node)

        # Add control flow return to exit the function/path
        self.builder.add_return("", debug_info)
185

186
    def visit_Expr(self, node):
4✔
187
        """Handle expression statements (e.g., bare function calls)."""
188
        # Expression statements are typically function calls without assignment
189
        # Like: build_up_b(...) or pressure_poisson(...)
NEW
190
        if isinstance(node.value, ast.Call):
×
191
            # This is a function call statement - handle it through expression visitor
192
            # which will inline the function
NEW
193
            self._parse_expr(node.value)
×
194
        else:
195
            # For other expression statements, just evaluate them
NEW
196
            self._parse_expr(node.value)
×
197

198
    def visit_AugAssign(self, node):
4✔
199
        if isinstance(node.target, ast.Name) and node.target.id in self.array_info:
4✔
200
            # Convert to slice assignment: target[:] = target op value
201
            ndim = self.array_info[node.target.id]["ndim"]
4✔
202

203
            slices = []
4✔
204
            for _ in range(ndim):
4✔
205
                slices.append(ast.Slice(lower=None, upper=None, step=None))
4✔
206

207
            if ndim == 1:
4✔
208
                slice_arg = slices[0]
×
209
            else:
210
                slice_arg = ast.Tuple(elts=slices, ctx=ast.Load())
4✔
211

212
            slice_node = ast.Subscript(
4✔
213
                value=node.target, slice=slice_arg, ctx=ast.Store()
214
            )
215

216
            new_node = ast.Assign(
4✔
217
                targets=[slice_node],
218
                value=ast.BinOp(left=node.target, op=node.op, right=node.value),
219
            )
220
            self.visit_Assign(new_node)
4✔
221
        else:
222
            new_node = ast.Assign(
4✔
223
                targets=[node.target],
224
                value=ast.BinOp(left=node.target, op=node.op, right=node.value),
225
            )
226
            self.visit_Assign(new_node)
4✔
227

228
    def visit_Assign(self, node):
        """Lower an assignment statement into builder calls.

        Cases are handled in this order:

        1. Several targets (``a = b = expr``): the value is first assigned to
           a fresh ``_assign_tmp_<id>`` name, then the temporary is assigned to
           every target.
        2. Tuple target with tuple value: split into element-wise assignments.
        3. Values recognised by the linear algebra, convolution or ONNX
           handlers: the handler emits the code if it accepts the statement.
        4. Subscript target: slice targets go to ``_handle_slice_assignment``;
           element targets are written through an ``assign`` tasklet.
        5. Plain name target: the container is declared on first use (type
           from the constant, or from the symbol table entry of the parsed
           value), then either a reference memlet (pointer to pointer) or an
           ``assign`` tasklet (scalar) is emitted. Any other type combination
           emits nothing.

        Raises:
            ValueError: Tuple unpacking with differing element counts.
            NotImplementedError: Tuple unpacking from a non-tuple value, a
                target that is neither a name, subscript nor tuple, or a
                constant whose type cannot be mapped.
        """
        if len(node.targets) > 1:
            tmp_name = f"_assign_tmp_{self._get_unique_id()}"
            # Assign value to temporary
            val_assign = ast.Assign(
                targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=node.value
            )
            ast.copy_location(val_assign, node)
            self.visit_Assign(val_assign)

            # Assign temporary to targets
            for target in node.targets:
                assign = ast.Assign(
                    targets=[target], value=ast.Name(id=tmp_name, ctx=ast.Load())
                )
                ast.copy_location(assign, node)
                self.visit_Assign(assign)
            return

        target = node.targets[0]

        # Handle tuple unpacking: I, J, K = expr1, expr2, expr3
        if isinstance(target, ast.Tuple):
            if not isinstance(node.value, ast.Tuple):
                raise NotImplementedError(
                    "Tuple unpacking from non-tuple values not supported"
                )
            # Unpacking tuple to tuple: a, b, c = x, y, z
            if len(target.elts) != len(node.value.elts):
                raise ValueError("Tuple unpacking size mismatch")
            for tgt, val in zip(target.elts, node.value.elts):
                assign = ast.Assign(targets=[tgt], value=val)
                ast.copy_location(assign, node)
                self.visit_Assign(assign)
            return

        # Special case: linear algebra functions
        if self.la_handler.is_gemm(node.value):
            if self.la_handler.handle_gemm(target, node.value):
                return
            if self.la_handler.handle_dot(target, node.value):
                return

        # Special case: outer product
        if self.la_handler.is_outer(node.value):
            if self.la_handler.handle_outer(target, node.value):
                return

        # Special case: convolution
        if self.conv_handler.is_conv(node.value):
            if self.conv_handler.handle_conv(target, node.value):
                return

        # Special case: ONNX ops (Transpose)
        if self.onnx_handler.is_transpose(node.value):
            if self.onnx_handler.handle_transpose(target, node.value):
                return

        # Special case: subscripted targets (element or slice writes)
        if isinstance(target, ast.Subscript):
            target_name = self.expr_visitor.visit(target.value)

            if isinstance(target.slice, ast.Tuple):
                indices = target.slice.elts
            else:
                indices = [target.slice]

            if any(isinstance(idx, ast.Slice) for idx in indices):
                debug_info = get_debug_info(node, self.filename, self.function_name)
                self._handle_slice_assignment(
                    target, node.value, target_name, indices, debug_info
                )
                return

            target_name_full = self._parse_expr(target)
            value_str = self._parse_expr(node.value)
            debug_info = get_debug_info(node, self.filename, self.function_name)

            block = self.builder.add_block(debug_info)
            t_src, src_sub = self.expr_visitor._add_read(block, value_str, debug_info)

            # A parsed target of the form "name(subset)" is split into the
            # container name and the subset used on the output memlet.
            if "(" in target_name_full and target_name_full.endswith(")"):
                name = target_name_full.split("(")[0]
                subset = target_name_full[target_name_full.find("(") + 1 : -1]
                t_dst = self.builder.add_access(block, name, debug_info)
                dst_sub = subset
            else:
                t_dst = self.builder.add_access(block, target_name_full, debug_info)
                dst_sub = ""

            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", dst_sub, None, debug_info
            )
            return

        # Variable assignments
        if not isinstance(target, ast.Name):
            raise NotImplementedError("Only assignment to variables supported")

        target_name = target.id
        value_str = self._parse_expr(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if not self.builder.exists(target_name):
            if isinstance(node.value, ast.Constant):
                val = node.value.value
                # bool is a subclass of int, so it must be tested first;
                # otherwise True/False are declared as Int64.
                if isinstance(val, bool):
                    dtype = Scalar(PrimitiveType.Bool)
                elif isinstance(val, int):
                    dtype = Scalar(PrimitiveType.Int64)
                elif isinstance(val, float):
                    dtype = Scalar(PrimitiveType.Double)
                else:
                    raise NotImplementedError(f"Cannot infer type for {val}")

                self.builder.add_container(target_name, dtype, False)
                self.symbol_table[target_name] = dtype
            else:
                # The new variable inherits the type of the parsed value.
                assert value_str in self.symbol_table
                self.builder.add_container(
                    target_name, self.symbol_table[value_str], False
                )
                self.symbol_table[target_name] = self.symbol_table[value_str]

        if value_str in self.array_info:
            self.array_info[target_name] = self.array_info[value_str]

        # Distinguish assignments: scalar -> tasklet, pointer -> reference_memlet
        src_type = self.symbol_table.get(value_str)
        dst_type = self.symbol_table[target_name]
        if src_type and isinstance(src_type, Pointer) and isinstance(dst_type, Pointer):
            block = self.builder.add_block(debug_info)
            t_src = self.builder.add_access(block, value_str, debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_reference_memlet(
                block, t_src, t_dst, "0", src_type, debug_info
            )
            return
        elif (src_type and isinstance(src_type, Scalar)) or isinstance(
            dst_type, Scalar
        ):
            block = self.builder.add_block(debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            if src_type:
                t_src = self.builder.add_access(block, value_str, debug_info)
            else:
                # No symbol table entry: the parsed value is emitted as a
                # constant of the destination type.
                t_src = self.builder.add_constant(
                    block, value_str, dst_type, debug_info
                )

            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", "", None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )

            return
405

406
    def visit_If(self, node):
        """Emit an if/else region: condition, then-body and optional else-body.

        The parsed test is compared against ``false`` to form the branch
        condition. The else branch is only opened when ``node.orelse`` is
        non-empty.
        """
        condition = self._parse_expr(node.test)
        location = get_debug_info(node, self.filename, self.function_name)

        self.builder.begin_if(f"{condition} != false", location)
        for child in node.body:
            self.visit(child)

        otherwise = node.orelse
        if otherwise:
            self.builder.begin_else(location)
            for child in otherwise:
                self.visit(child)

        self.builder.end_if()
420

421
    def visit_While(self, node):
        """Emit a while region whose condition is the parsed test ``!= false``.

        The test expression is parsed once, before the loop is opened; the
        body statements are then visited in order and the loop is closed.
        """
        condition = self._parse_expr(node.test)
        location = get_debug_info(node, self.filename, self.function_name)

        self.builder.begin_while(f"{condition} != false", location)
        for child in node.body:
            self.visit(child)
        self.builder.end_while()
430

431
    def visit_For(self, node):
4✔
432
        if not isinstance(node.target, ast.Name):
4✔
433
            raise NotImplementedError("Only simple for loops supported")
×
434

435
        var = node.target.id
4✔
436

437
        if not isinstance(node.iter, ast.Call) or node.iter.func.id != "range":
4✔
438
            raise NotImplementedError("Only range() loops supported")
×
439

440
        args = node.iter.args
4✔
441
        if len(args) == 1:
4✔
442
            start = "0"
4✔
443
            end = self._parse_expr(args[0])
4✔
444
            step = "1"
4✔
445
        elif len(args) == 2:
4✔
446
            start = self._parse_expr(args[0])
4✔
447
            end = self._parse_expr(args[1])
4✔
448
            step = "1"
4✔
449
        elif len(args) == 3:
4✔
450
            start = self._parse_expr(args[0])
4✔
451
            end = self._parse_expr(args[1])
4✔
452

453
            # Special handling for step to avoid creating tasklets for constants
454
            step_node = args[2]
4✔
455
            if isinstance(step_node, ast.Constant):
4✔
456
                step = str(step_node.value)
4✔
457
            elif (
4✔
458
                isinstance(step_node, ast.UnaryOp)
459
                and isinstance(step_node.op, ast.USub)
460
                and isinstance(step_node.operand, ast.Constant)
461
            ):
462
                step = f"-{step_node.operand.value}"
4✔
463
            else:
464
                step = self._parse_expr(step_node)
×
465
        else:
466
            raise ValueError("Invalid range arguments")
×
467

468
        if not self.builder.exists(var):
4✔
469
            self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
4✔
470
            self.symbol_table[var] = Scalar(PrimitiveType.Int64)
4✔
471

472
        debug_info = get_debug_info(node, self.filename, self.function_name)
4✔
473
        self.builder.begin_for(var, start, end, step, debug_info)
4✔
474

475
        for stmt in node.body:
4✔
476
            self.visit(stmt)
4✔
477

478
        self.builder.end_for()
4✔
479

480
    def _get_max_array_ndim_in_expr(self, node):
4✔
481
        """Get the maximum array dimensionality in an expression."""
482
        max_ndim = 0
4✔
483

484
        class NdimVisitor(ast.NodeVisitor):
4✔
485
            def __init__(self, array_info):
4✔
486
                self.array_info = array_info
4✔
487
                self.max_ndim = 0
4✔
488

489
            def visit_Name(self, node):
4✔
490
                if node.id in self.array_info:
4✔
491
                    ndim = self.array_info[node.id].get("ndim", 0)
4✔
492
                    self.max_ndim = max(self.max_ndim, ndim)
4✔
493
                return self.generic_visit(node)
4✔
494

495
        visitor = NdimVisitor(self.array_info)
4✔
496
        visitor.visit(node)
4✔
497
        return visitor.max_ndim
4✔
498

499
    def _handle_broadcast_slice_assignment(
        self, target, value, target_name, indices, target_ndim, value_ndim, debug_info
    ):
        """Handle slice assignment with broadcasting (e.g., 2D -= 1D).

        One for-loop is opened per leading (broadcast) dimension of the
        target. Inside the loops a pointer container ``_row_view_<id>`` is
        made to reference the target at the linearized offset of the current
        outer indices, and the assignment is re-issued against that view via
        ``_handle_slice_assignment`` with ``value_ndim`` full slices.

        ``target`` and ``indices`` are accepted for signature symmetry but are
        not read here.
        """
        shapes = self.array_info[target_name].get("shapes", [])
        # Number of broadcast dimensions (outer loops)
        n_outer = target_ndim - value_ndim

        def extent(axis):
            # Known extent of the axis, or the symbolic fallback name for it.
            if axis < len(shapes):
                return shapes[axis]
            return f"_{target_name}_shape_{axis}"

        # Create outer loops for broadcast dimensions
        outer_vars = []
        for axis in range(n_outer):
            iter_name = f"_bcast_iter_{axis}_{self._get_unique_id()}"
            outer_vars.append(iter_name)

            if not self.builder.exists(iter_name):
                self.builder.add_container(
                    iter_name, Scalar(PrimitiveType.Int64), False
                )
                self.symbol_table[iter_name] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(iter_name, "0", extent(axis), "1", debug_info)

        # Create a row view (reference) for the inner dimensions
        view_name = f"_row_view_{self._get_unique_id()}"

        # Shapes of the dimensions the row view still spans
        if len(shapes) > n_outer:
            view_shapes = shapes[n_outer:]
        else:
            view_shapes = []

        # Determine element type from the target; Double when it is unknown
        declared = self.symbol_table.get(target_name)
        if isinstance(declared, Pointer) and declared.has_pointee_type():
            elem_type = declared.pointee_type
        else:
            elem_type = Scalar(PrimitiveType.Double)

        # Create pointer type for row view and register it everywhere
        view_type = Pointer(elem_type)
        self.builder.add_container(view_name, view_type, False)
        self.symbol_table[view_name] = view_type
        self.array_info[view_name] = {"ndim": value_ndim, "shapes": view_shapes}

        # Compute linearized index for reference
        # For target[i, j] with shape (n, m), linear index for row i is i * m
        offset = outer_vars[0] if outer_vars else "0"
        for axis in range(1, n_outer):
            offset = f"({offset}) * ({extent(axis)}) + {outer_vars[axis]}"

        # Multiply by inner dimension sizes to get the start of the row
        for axis in range(n_outer, target_ndim):
            offset = f"({offset}) * ({extent(axis)})"

        # Create the reference memlet block: row_view = &target[offset]
        ref_block = self.builder.add_block(debug_info)
        src_access = self.builder.add_access(ref_block, target_name, debug_info)
        dst_access = self.builder.add_access(ref_block, view_name, debug_info)
        self.builder.add_reference_memlet(
            ref_block, src_access, dst_access, offset, view_type, debug_info
        )

        # Inner indices: one full slice per remaining dimension
        inner_indices = [
            ast.Slice(lower=None, upper=None, step=None) for _ in range(value_ndim)
        ]
        if len(inner_indices) > 1:
            view_slice = ast.Tuple(elts=inner_indices, ctx=ast.Load())
        else:
            view_slice = inner_indices[0]

        # Create new target using row view
        view_target = ast.Subscript(
            value=ast.Name(id=view_name, ctx=ast.Load()),
            slice=view_slice,
            ctx=ast.Store(),
        )

        # Recursively handle the inner assignment (now same-dimension)
        self._handle_slice_assignment(
            view_target, value, view_name, inner_indices, debug_info
        )

        # Close outer loops
        for _ in outer_vars:
            self.builder.end_for()
601

602
    def _handle_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Lower an assignment whose target subscript contains slices.

        Steps:

        1. Missing trailing indices are padded with full slices up to the
           array's ``ndim`` (when ``target_name`` is in ``self.array_info``).
        2. A right-hand side containing a ufunc ``outer`` call is delegated to
           ``_handle_ufunc_outer_slice_assignment``.
        3. If the target has more slice dimensions than the highest-rank
           array in the value, the statement is delegated to
           ``_handle_broadcast_slice_assignment``.
        4. Otherwise one for-loop is opened per slice index, the value is
           rewritten with ``SliceRewriter`` to use the loop variables, a
           single element assignment is emitted, and the loops are closed.

        Args:
            target: The ``ast.Subscript`` being assigned to.
            value: The right-hand side expression node.
            target_name: Name of the target container as produced by the
                expression visitor.
            indices: Sequence of index nodes of the target subscript.
            debug_info: Debug info for emitted elements; a default-constructed
                ``DebugInfo`` is used when ``None``.
        """
        if debug_info is None:
            # DebugInfo is not among this file's visible module-level imports,
            # so the bare name would raise NameError on this path.
            # NOTE(review): assumes ._sdfg exports DebugInfo next to
            # Scalar/Pointer - confirm.
            from ._sdfg import DebugInfo

            debug_info = DebugInfo()

        if target_name in self.array_info:
            ndim = self.array_info[target_name]["ndim"]
            if len(indices) < ndim:
                # Copy before padding so the caller's list is not mutated.
                indices = list(indices)
                for _ in range(ndim - len(indices)):
                    indices.append(ast.Slice(lower=None, upper=None, step=None))

        # Check if the RHS contains a ufunc outer operation
        # If so, we handle it specially to avoid the loop transformation
        # which would destroy the slice shape information
        has_outer, _, _ = contains_ufunc_outer(value)
        if has_outer:
            self._handle_ufunc_outer_slice_assignment(
                target, value, target_name, indices, debug_info
            )
            return

        # Count slice dimensions to determine effective target dimensionality
        # (slice indices produce array dimensions, point indices collapse them)
        target_slice_ndim = sum(1 for idx in indices if isinstance(idx, ast.Slice))
        value_max_ndim = self._get_max_array_ndim_in_expr(value)

        if (
            target_slice_ndim > 0
            and value_max_ndim > 0
            and target_slice_ndim > value_max_ndim
        ):
            # Broadcasting case: use row-by-row approach with reference memlets
            self._handle_broadcast_slice_assignment(
                target,
                value,
                target_name,
                indices,
                target_slice_ndim,
                value_max_ndim,
                debug_info,
            )
            return

        loop_vars = []
        new_target_indices = []

        for i, idx in enumerate(indices):
            if isinstance(idx, ast.Slice):
                loop_var = f"_slice_iter_{len(loop_vars)}_{self._get_unique_id()}"
                loop_vars.append(loop_var)

                if not self.builder.exists(loop_var):
                    self.builder.add_container(
                        loop_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

                start_str = "0"
                if idx.lower:
                    start_str = self._parse_expr(idx.lower)
                    if start_str.startswith("-"):
                        # Negative lower bound: count from the end of the axis.
                        shapes = self.array_info[target_name].get("shapes", [])
                        dim_size = (
                            str(shapes[i])
                            if i < len(shapes)
                            else f"_{target_name}_shape_{i}"
                        )
                        start_str = f"({dim_size} {start_str})"

                stop_str = ""
                if idx.upper and not (
                    isinstance(idx.upper, ast.Constant) and idx.upper.value is None
                ):
                    stop_str = self._parse_expr(idx.upper)
                    if stop_str.startswith("-") or stop_str.startswith("(-"):
                        # Negative upper bound: count from the end of the axis.
                        shapes = self.array_info[target_name].get("shapes", [])
                        dim_size = (
                            str(shapes[i])
                            if i < len(shapes)
                            else f"_{target_name}_shape_{i}"
                        )
                        stop_str = f"({dim_size} {stop_str})"
                else:
                    # No upper bound: run to the end of the axis.
                    shapes = self.array_info[target_name].get("shapes", [])
                    stop_str = (
                        str(shapes[i])
                        if i < len(shapes)
                        else f"_{target_name}_shape_{i}"
                    )

                step_str = "1"
                if idx.step:
                    step_str = self._parse_expr(idx.step)

                # NOTE(review): the trip count is stop - start and is not
                # divided by the step, while the index below advances by the
                # step - confirm this is intended for steps other than 1.
                count_str = f"({stop_str} - {start_str})"

                self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

                # The index expression is carried as the id of a Name node so
                # the expression visitor receives it as text.
                new_target_indices.append(
                    ast.Name(
                        id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load()
                    )
                )
            else:
                # Handle non-slice indices - need to normalize negative indices
                shapes = self.array_info[target_name].get("shapes", [])
                dim_size = shapes[i] if i < len(shapes) else f"_{target_name}_shape_{i}"
                normalized_idx = normalize_negative_index(idx, dim_size)
                new_target_indices.append(normalized_idx)

        # Rewrite a copy of the value so the caller's AST is left untouched.
        rewriter = SliceRewriter(loop_vars, self.array_info, self.expr_visitor)
        new_value = rewriter.visit(copy.deepcopy(value))

        new_target = copy.deepcopy(target)
        if len(new_target_indices) == 1:
            new_target.slice = new_target_indices[0]
        else:
            new_target.slice = ast.Tuple(elts=new_target_indices, ctx=ast.Load())

        target_str = self._parse_expr(new_target)
        value_str = self._parse_expr(new_value)
        self.builder.add_assignment(target_str, value_str, debug_info)

        # Close one loop per slice dimension opened above.
        for _ in loop_vars:
            self.builder.end_for()
730

731
    def _handle_ufunc_outer_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Handle slice assignment where RHS contains a ufunc outer operation.

        Example: path[:] = np.minimum(path[:], np.add.outer(path[:, k], path[k, :]))

        The strategy is:
        1. Evaluate the entire RHS expression, which will create a temporary array
           containing the result of the ufunc outer (potentially wrapped in other ops)
        2. Copy the temporary result to the target slice

        This avoids the loop transformation that would destroy slice shape info.

        Args:
            target: AST node of the assignment target; only parsed as a whole
                when none of ``indices`` is a slice.
            value: AST node of the right-hand side, evaluated once through
                ``self._parse_expr``.
            target_name: Name of the target container. Used to look up its
                ``"shapes"`` entry in ``self.array_info`` and as the
                destination of the copy.
            indices: Per-dimension index nodes of the target subscript.
            debug_info: Debug info attached to every emitted element. A
                default ``DebugInfo()`` is created when omitted.
        """
        if debug_info is None:
            from ._sdfg import DebugInfo

            debug_info = DebugInfo()

        # Evaluating the RHS creates the temporaries for the ufunc outer (and
        # any wrapping operation) and yields the name of the final result.
        result_name = self._parse_expr(value)

        if not any(isinstance(idx, ast.Slice) for idx in indices):
            # No slices on target - a single element assignment is enough.
            target_str = self._parse_expr(target)
            block = self.builder.add_block(debug_info)
            t_src, src_sub = self.expr_visitor._add_read(block, result_name, debug_info)
            t_dst = self.builder.add_access(block, target_str, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )
            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )
            return

        target_shapes = self.array_info.get(target_name, {}).get("shapes", [])

        def dim_size_of(dim):
            # Known shape entry, or the symbolic per-dimension shape name.
            if dim < len(target_shapes):
                return target_shapes[dim]
            return f"_{target_name}_shape_{dim}"

        def from_end(bound, dim):
            # A bound written as a negative expression counts from the end of
            # the dimension, so rewrite it relative to the dimension size.
            if bound.startswith(("-", "(-")):
                return f"({dim_size_of(dim)} + {bound})"
            return bound

        def linearize(parts, extents):
            # Row-major linear index of `parts` in an array whose dimension
            # sizes are `extents` (extents[0] is never needed).
            if not parts:
                return "0"
            if len(parts) == 1:
                return parts[0]
            if len(parts) == 2:
                return f"(({parts[0]}) * ({extents[1]}) + ({parts[1]}))"
            linear = ""
            stride = "1"
            for pos in range(len(parts) - 1, -1, -1):
                term = parts[pos] if stride == "1" else f"(({parts[pos]}) * ({stride}))"
                linear = f"({term} + {linear})" if linear else term
                stride = (
                    f"({stride} * {extents[pos]})"
                    if stride != "1"
                    else str(extents[pos])
                )
            return linear

        loop_vars = []
        slice_extents = []  # iteration count of each slice loop, in order
        new_target_indices = []

        for i, idx in enumerate(indices):
            if not isinstance(idx, ast.Slice):
                # Plain index: only negative values need normalizing.
                new_target_indices.append(
                    normalize_negative_index(idx, dim_size_of(i))
                )
                continue

            loop_var = f"_copy_iter_{len(loop_vars)}_{self._get_unique_id()}"
            loop_vars.append(loop_var)

            if not self.builder.exists(loop_var):
                self.builder.add_container(
                    loop_var, Scalar(PrimitiveType.Int64), False
                )
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            start_str = "0"
            if idx.lower:
                start_str = from_end(self._parse_expr(idx.lower), i)

            has_upper = idx.upper and not (
                isinstance(idx.upper, ast.Constant) and idx.upper.value is None
            )
            if has_upper:
                stop_str = from_end(self._parse_expr(idx.upper), i)
            else:
                stop_str = dim_size_of(i)

            step_str = "1"
            if idx.step:
                step_str = self._parse_expr(idx.step)

            # NOTE(review): the trip count does not account for step_str, so a
            # step other than 1 iterates (stop - start) times - confirm intent.
            count_str = f"({stop_str} - {start_str})"
            slice_extents.append(stop_str if start_str == "0" else count_str)

            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            new_target_indices.append(
                ast.Name(id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load())
            )

        # Create assignment block: target[i,j,...] = result[i,j,...]
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, result_name, debug_info)
        t_dst = self.builder.add_access(block, target_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        # Source index: the result is treated as a flat row-major array with
        # one dimension per slice loop, so its strides come from the slice
        # extents rather than from the target's full dimension sizes (they
        # differ for partial slices and when plain indices are mixed in).
        src_index = linearize(loop_vars, slice_extents)

        # Target index: row-major linear index over the target's own shape.
        target_index_parts = [
            idx.id if isinstance(idx, ast.Name) else self._parse_expr(idx)
            for idx in new_target_indices
        ]
        target_index = linearize(
            target_index_parts,
            [dim_size_of(dim) for dim in range(len(target_index_parts))],
        )

        # Connect memlets
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_index, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", target_index, None, debug_info
        )

        # Close one loop per slice dimension opened above.
        for _ in loop_vars:
            self.builder.end_for()
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc