• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 28685903378

03 Jul 2026 10:24PM UTC coverage: 62.417% (+0.3%) from 62.147%
28685903378

Pull #832

github

web-flow
Merge 7ab5e18a3 into 3726be1d9
Pull Request #832: activates numpy tests

99 of 112 new or added lines in 2 files covered. (88.39%)

22 existing lines in 2 files now uncovered.

39691 of 63590 relevant lines covered (62.42%)

978.44 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.91
/python/docc/python/ast_parser.py
1
import ast
4✔
2
import copy
4✔
3
import inspect
4✔
4
import textwrap
4✔
5
from docc.sdfg import (
4✔
6
    Scalar,
7
    PrimitiveType,
8
    Pointer,
9
    TaskletCode,
10
    DebugInfo,
11
    Structure,
12
    CMathFunction,
13
    Tensor,
14
)
15
from docc.python.ast_utils import (
4✔
16
    SliceRewriter,
17
    get_debug_info,
18
    contains_ufunc_outer,
19
    normalize_negative_index,
20
)
21
from docc.python.types import (
4✔
22
    sdfg_type_from_type,
23
    element_type_from_sdfg_type,
24
)
25
from docc.python.functions.numpy import NumPyHandler
4✔
26
from docc.python.functions.math import MathHandler
4✔
27
from docc.python.functions.python import PythonHandler
4✔
28
from docc.python.memory import ManagedMemoryHandler
4✔
29

30

31
class ASTParser(ast.NodeVisitor):
4✔
32
    def __init__(
        self,
        builder,
        tensor_table,
        container_table,
        filename="",
        function_name="",
        infer_return_type=False,
        globals_dict=None,
        unique_counter_ref=None,
        structure_member_info=None,
        memory_handler=None,
    ):
        """Set up parser state and the per-domain call handlers.

        Args:
            builder: Object used to emit containers, blocks, tasklets and
                memlets while visiting.
            tensor_table: Lookup table of array variables, keyed by name.
            container_table: Lookup table of known containers, keyed by name.
            filename: Source file name, kept for debug info.
            function_name: Name of the function being parsed, kept for debug
                info.
            infer_return_type: Stored as-is on the instance.
            globals_dict: Global names visible to the parsed function; a new
                empty dict is used when omitted.
            unique_counter_ref: Counter cell; a new ``[0]`` is used when
                omitted.
            structure_member_info: Structure member metadata; a new empty
                dict is used when omitted.
            memory_handler: Memory handler to reuse; a
                ``ManagedMemoryHandler(builder)`` is created when omitted.
        """
        # Fill in fresh defaults only when the caller passed nothing, so that
        # caller-supplied objects (even empty ones) are kept by identity.
        if globals_dict is None:
            globals_dict = {}
        if unique_counter_ref is None:
            unique_counter_ref = [0]
        if structure_member_info is None:
            structure_member_info = {}
        if memory_handler is None:
            memory_handler = ManagedMemoryHandler(builder)

        self.builder = builder

        # Lookup tables for variables
        self.tensor_table = tensor_table
        self.container_table = container_table

        # Debug info
        self.filename = filename
        self.function_name = function_name

        # Context
        self.infer_return_type = infer_return_type
        self.globals_dict = globals_dict
        self._unique_counter_ref = unique_counter_ref
        self._access_cache = {}
        self.structure_member_info = structure_member_info
        # Map param name to shape string list
        self.captured_return_shapes = {}
        # Map param name to stride string list
        self.captured_return_strides = {}
        # Map array name to runtime shapes (separate from Tensor)
        self.shapes_runtime_info = {}

        # Memory manager for hoisted allocations (shared with inline parsers)
        self.memory_handler = memory_handler

        # Handlers are created last: they receive this parser so they can
        # call back into its expression visitor methods.
        self.numpy_visitor = NumPyHandler(self)
        self.math_handler = MathHandler(self)
        self.python_handler = PythonHandler(self)
4✔
82

83
    def visit_Constant(self, node):
        """Return the textual form of a literal.

        Booleans are spelled ``"true"`` / ``"false"``; every other value is
        rendered with ``str``.
        """
        literal = node.value
        if not isinstance(literal, bool):
            return str(literal)
        if literal:
            return "true"
        return "false"
4✔
87

88
    def visit_Name(self, node):
4✔
89
        name = node.id
4✔
90
        if name not in self.container_table and self.globals_dict is not None:
4✔
91
            if name in self.globals_dict:
4✔
92
                val = self.globals_dict[name]
4✔
93
                if isinstance(val, (int, float)):
4✔
94
                    return str(val)
4✔
95
        return name
4✔
96

97
    # --- Operator nodes -------------------------------------------------
    # Each visitor below maps one ast operator node to the token string that
    # the expression visitors (visit_BinOp, visit_Compare, visit_BoolOp,
    # visit_UnaryOp) compare against or splice into condition strings.

    # Arithmetic operators
    def visit_Add(self, node):
        return "+"

    def visit_Sub(self, node):
        return "-"

    def visit_Mult(self, node):
        return "*"

    def visit_Div(self, node):
        return "/"

    def visit_FloorDiv(self, node):
        return "//"

    def visit_Mod(self, node):
        return "%"

    def visit_Pow(self, node):
        return "**"

    # Comparison operators
    def visit_Eq(self, node):
        return "=="

    def visit_NotEq(self, node):
        return "!="

    def visit_Lt(self, node):
        return "<"

    def visit_LtE(self, node):
        return "<="

    def visit_Gt(self, node):
        return ">"

    def visit_GtE(self, node):
        return ">="

    # Boolean operators: these share their tokens with the bitwise forms.
    # visit_BoolOp wraps every operand as "(x != 0)" before joining with
    # the token returned here.
    def visit_And(self, node):
        return "&"

    def visit_Or(self, node):
        return "|"

    # Bitwise operators
    def visit_BitAnd(self, node):
        return "&"

    def visit_BitOr(self, node):
        return "|"

    def visit_BitXor(self, node):
        return "^"

    def visit_LShift(self, node):
        return "<<"

    def visit_RShift(self, node):
        return ">>"

    # Unary operators
    def visit_Not(self, node):
        return "!"

    def visit_USub(self, node):
        return "-"

    def visit_UAdd(self, node):
        return "+"

    def visit_Invert(self, node):
        return "~"
4✔
168

169
    def visit_BoolOp(self, node):
        """Lower ``a and b`` / ``a or b`` into a fresh Bool container.

        Every operand is visited and wrapped as ``(<operand> != 0)``; the
        wrapped operands are joined with the operator token and used as the
        condition of an if/else that stores ``true`` or ``false`` into the
        new container.

        Returns:
            The name of the Bool container holding the result.
        """
        operator = self.visit(node.op)

        # All operands are visited up front, in source order.
        operands = []
        for value in node.values:
            operands.append(f"({self.visit(value)} != 0)")
        condition = f" {operator} ".join(operands)

        result = self.builder.find_new_name()
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        self.builder.begin_if(condition)
        self._add_assign_constant(result, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result, "false", bool_type)
        self.builder.end_if()

        # Registered only after the branches have been emitted.
        self.container_table[result] = bool_type
        return result
4✔
186

187
    def visit_UnaryOp(self, node):
        """Lower a unary operation.

        Returns:
            * ``"-<value>"`` for a negated int/float literal (nothing is
              emitted through the builder),
            * the result of ``numpy_visitor.handle_array_negate`` when the
              operand is a tensor and the operator is ``-``,
            * otherwise the name of a new container, typed like the operand,
              that receives the result.
        """
        operand_node = node.operand

        # Negated numeric literal: fold directly into the text.
        if isinstance(node.op, ast.USub) and isinstance(operand_node, ast.Constant):
            if isinstance(operand_node.value, (int, float)):
                return f"-{operand_node.value}"

        symbol = self.visit(node.op)
        source = self.visit(operand_node)

        if source in self.tensor_table and symbol == "-":
            return self.numpy_visitor.handle_array_negate(source)

        assert source in self.container_table, f"Undefined variable: {source}"
        result = self.builder.find_new_name()
        dtype = self.container_table[source]
        self.builder.add_container(result, dtype, False)
        self.container_table[result] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, source)
        t_dst = self.builder.add_access(block, result)

        # Each branch creates its constant (if any) and its tasklet, then
        # lists the input edges as (source node, connector, subset). The
        # edges are wired afterwards in list order, followed by the output.
        if isinstance(node.op, ast.Not):
            # Logical not: operand XOR true.
            t_const = self.builder.add_constant(
                block, "true", Scalar(PrimitiveType.Bool)
            )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            inputs = [(t_src, "_in1", src_sub), (t_const, "_in2", "")]
        elif symbol == "-":
            if dtype.primitive_type == PrimitiveType.Int64:
                # Integer negation: 0 - operand.
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_sub, ["_in1", "_in2"], ["_out"]
                )
                inputs = [(t_const, "_in1", ""), (t_src, "_in2", src_sub)]
            else:
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.fp_neg, ["_in"], ["_out"]
                )
                inputs = [(t_src, "_in", src_sub)]
        elif symbol == "~":
            # Bitwise invert: operand XOR -1.
            t_const = self.builder.add_constant(
                block, "-1", Scalar(PrimitiveType.Int64)
            )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            inputs = [(t_src, "_in1", src_sub), (t_const, "_in2", "")]
        else:
            # Any remaining operator (e.g. unary plus) is a plain copy.
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            inputs = [(t_src, "_in", src_sub)]

        for origin, connector, subset in inputs:
            self.builder.add_memlet(block, origin, "void", t_task, connector, subset)
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return result
4✔
254

255
    def visit_BinOp(self, node):
        """Lower a binary operation and return the name holding its result.

        Dispatch order:
            1. ``@`` is forwarded to ``numpy_visitor.handle_numpy_matmul_op``.
            2. If either visited operand is a tensor, ``+ - * / **`` are
               forwarded to ``numpy_visitor.handle_array_binary_op``.
            3. Otherwise a scalar temporary is created. It is Int64 when both
               operands are ints (per ``_is_int``) and the operator is
               neither ``/`` nor ``**``; it is Double in every other case,
               and int operands are first copied into Double temporaries.

        Returns:
            The name of the container holding the result (or whatever the
            NumPy handler returns for the array paths).

        Raises:
            NotImplementedError: For an array operator outside
                ``+ - * / **``, or a scalar operator that has no tasklet for
                the chosen result type.
        """
        if isinstance(node.op, ast.MatMult):
            return self.numpy_visitor.handle_numpy_matmul_op(node.left, node.right)

        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)

        if left in self.tensor_table or right in self.tensor_table:
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
            if op not in op_map:
                raise NotImplementedError(f"Array operation {op} not supported")
            return self.numpy_visitor.handle_array_binary_op(op_map[op], left, right)

        tmp_name = self.builder.find_new_name()

        left_is_int = self._is_int(left)
        right_is_int = self._is_int(right)
        # True division and power always produce a Double result.
        if left_is_int and right_is_int and op not in ("/", "**"):
            dtype = Scalar(PrimitiveType.Int64)
        else:
            dtype = Scalar(PrimitiveType.Double)
        is_int_result = dtype.primitive_type == PrimitiveType.Int64

        if not self.builder.exists(tmp_name):
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype

        # For a Double result, int operands are promoted first (left, then
        # right) so the arithmetic tasklet only sees Double inputs.
        real_left, real_right = left, right
        if not is_int_result:
            if left_is_int:
                real_left = self._binop_promote_to_double(left)
            if right_is_int:
                real_right = self._binop_promote_to_double(right)

        if op == "**":
            block = self.builder.add_block()
            t_left, left_sub = self._add_read(block, real_left)
            t_right, right_sub = self._add_read(block, real_right)
            t_out = self.builder.add_access(block, tmp_name)

            t_task = self.builder.add_cmath(
                block, CMathFunction.pow, dtype.primitive_type
            )
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
            return tmp_name

        if op == "%":
            self._binop_emit_modulo(
                tmp_name, dtype, real_left, real_right, is_int_result
            )
            return tmp_name

        # Operator -> TaskletCode attribute name. The attribute is resolved
        # lazily so only the code that is actually needed gets looked up.
        if is_int_result:
            # "/" is absent on purpose: it always yields a Double result.
            # NOTE(review): ">>" uses int_lshr; Python's ">>" on negative
            # ints behaves as an arithmetic shift — confirm this is intended.
            code_names = {
                "+": "int_add",
                "-": "int_sub",
                "*": "int_mul",
                "//": "int_sdiv",
                "&": "int_and",
                "|": "int_or",
                "^": "int_xor",
                "<<": "int_shl",
                ">>": "int_lshr",
            }
            kind = "integers"
        else:
            # NOTE(review): "//" uses fp_div with no visible floor step —
            # confirm whether float floor division is handled elsewhere.
            code_names = {
                "+": "fp_add",
                "-": "fp_sub",
                "*": "fp_mul",
                "/": "fp_div",
                "//": "fp_div",
            }
            kind = "floats"

        if op not in code_names:
            # Previously the integer path fell through with no tasklet code
            # and handed None to the builder; fail loudly on both paths.
            raise NotImplementedError(f"Operation {op} not supported for {kind}")
        tasklet_code = getattr(TaskletCode, code_names[op])

        block = self.builder.add_block()
        t_left, left_sub = self._add_read(block, real_left)
        t_right, right_sub = self._add_read(block, real_right)
        t_out = self.builder.add_access(block, tmp_name)

        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        # For indexed array accesses like "arr(i,j)", we need to pass the
        # Tensor type to ensure correct type inference during validation
        left_type = self._get_memlet_type_for_access(real_left, left_sub)
        right_type = self._get_memlet_type_for_access(real_right, right_sub)

        self.builder.add_memlet(
            block, t_left, "void", t_task, "_in1", left_sub, left_type
        )
        self.builder.add_memlet(
            block, t_right, "void", t_task, "_in2", right_sub, right_type
        )
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")

        return tmp_name

    def _binop_promote_to_double(self, operand):
        """Copy ``operand`` into a new Double container and return its name."""
        cast_name = self.builder.find_new_name()
        double_type = Scalar(PrimitiveType.Double)
        self.builder.add_container(cast_name, double_type, False)
        self.container_table[cast_name] = double_type

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, cast_name)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        return cast_name

    def _binop_emit_modulo(self, result_name, dtype, left, right, is_int):
        """Emit ``((left rem right) + right) rem right`` into ``result_name``.

        The three tasklets are chained inside one block through two
        intermediate containers of type ``dtype``. The integer and float
        variants differ only in the tasklet codes used.
        """
        rem_code = TaskletCode.int_srem if is_int else TaskletCode.fp_rem
        add_code = TaskletCode.int_add if is_int else TaskletCode.fp_add

        block = self.builder.add_block()
        t_left, left_sub = self._add_read(block, left)
        t_right, right_sub = self._add_read(block, right)
        t_out = self.builder.add_access(block, result_name)

        # Stage 1: rem1 = left rem right
        t_rem1 = self.builder.add_tasklet(
            block, rem_code, ["_in1", "_in2"], ["_out"]
        )
        self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_rem1, "_in2", right_sub)

        rem1_name = self.builder.find_new_name()
        self.builder.add_container(rem1_name, dtype, False)
        t_rem1_out = self.builder.add_access(block, rem1_name)
        self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")

        # Stage 2: sum = rem1 + right
        t_add = self.builder.add_tasklet(
            block, add_code, ["_in1", "_in2"], ["_out"]
        )
        self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
        self.builder.add_memlet(block, t_right, "void", t_add, "_in2", right_sub)

        add_name = self.builder.find_new_name()
        self.builder.add_container(add_name, dtype, False)
        t_add_out = self.builder.add_access(block, add_name)
        self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")

        # Stage 3: result = sum rem right
        t_rem2 = self.builder.add_tasklet(
            block, rem_code, ["_in1", "_in2"], ["_out"]
        )
        self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
        self.builder.add_memlet(block, t_right, "void", t_rem2, "_in2", right_sub)
        self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")
4✔
480

481
    def visit_Attribute(self, node):
4✔
482
        if node.attr == "shape":
4✔
483
            if isinstance(node.value, ast.Name) and node.value.id in self.tensor_table:
4✔
484
                return f"_shape_proxy_{node.value.id}"
4✔
485

486
        if node.attr == "T":
4✔
487
            value_name = None
4✔
488
            if isinstance(node.value, ast.Name):
4✔
489
                value_name = node.value.id
4✔
490
            elif isinstance(node.value, ast.Subscript):
×
491
                value_name = self.visit(node.value)
×
492

493
            if value_name and value_name in self.tensor_table:
4✔
494
                return self.numpy_visitor.handle_transpose_expr(node)
4✔
495

496
        if isinstance(node.value, ast.Name) and node.value.id == "math":
4✔
497
            val = ""
4✔
498
            if node.attr == "pi":
4✔
499
                val = "M_PI"
4✔
500
            elif node.attr == "e":
4✔
501
                val = "M_E"
4✔
502

503
            if val:
4✔
504
                tmp_name = self.builder.find_new_name()
4✔
505
                dtype = Scalar(PrimitiveType.Double)
4✔
506
                self.builder.add_container(tmp_name, dtype, False)
4✔
507
                self.container_table[tmp_name] = dtype
4✔
508
                self._add_assign_constant(tmp_name, val, dtype)
4✔
509
                return tmp_name
4✔
510

511
        if isinstance(node.value, ast.Name):
4✔
512
            obj_name = node.value.id
4✔
513
            attr_name = node.attr
4✔
514

515
            if obj_name in self.container_table:
4✔
516
                obj_type = self.container_table[obj_name]
4✔
517
                if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
4✔
518
                    pointee_type = obj_type.pointee_type
4✔
519
                    if isinstance(pointee_type, Structure):
4✔
520
                        struct_name = pointee_type.name
4✔
521

522
                        if (
4✔
523
                            struct_name in self.structure_member_info
524
                            and attr_name in self.structure_member_info[struct_name]
525
                        ):
526
                            member_index, member_type = self.structure_member_info[
4✔
527
                                struct_name
528
                            ][attr_name]
529
                        else:
530
                            raise RuntimeError(
×
531
                                f"Member '{attr_name}' not found in structure '{struct_name}'. "
532
                                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
533
                            )
534

535
                        tmp_name = self.builder.find_new_name()
4✔
536

537
                        self.builder.add_container(tmp_name, member_type, False)
4✔
538
                        self.container_table[tmp_name] = member_type
4✔
539

540
                        block = self.builder.add_block()
4✔
541
                        obj_access = self.builder.add_access(block, obj_name)
4✔
542
                        tmp_access = self.builder.add_access(block, tmp_name)
4✔
543

544
                        tasklet = self.builder.add_tasklet(
4✔
545
                            block, TaskletCode.assign, ["_in"], ["_out"]
546
                        )
547

548
                        subset = "0," + str(member_index)
4✔
549
                        self.builder.add_memlet(
4✔
550
                            block, obj_access, "", tasklet, "_in", subset
551
                        )
552
                        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")
4✔
553

554
                        return tmp_name
4✔
555

556
        raise NotImplementedError(f"Attribute access {node.attr} not supported")
×
557

558
    def visit_Compare(self, node):
4✔
559
        left = self.visit(node.left)
4✔
560
        if len(node.ops) > 1:
4✔
561
            raise NotImplementedError("Chained comparisons not supported yet")
×
562

563
        op = self.visit(node.ops[0])
4✔
564
        right = self.visit(node.comparators[0])
4✔
565

566
        left_is_array = left in self.tensor_table
4✔
567
        right_is_array = right in self.tensor_table
4✔
568

569
        if left_is_array or right_is_array:
4✔
570
            return self.numpy_visitor.handle_array_compare(
4✔
571
                left, op, right, left_is_array, right_is_array
572
            )
573

574
        expr_str = f"{left} {op} {right}"
4✔
575

576
        tmp_name = self.builder.find_new_name()
4✔
577
        dtype = Scalar(PrimitiveType.Bool)
4✔
578
        self.builder.add_container(tmp_name, dtype, False)
4✔
579

580
        self.builder.begin_if(expr_str)
4✔
581
        self.builder.add_transition(tmp_name, "true")
4✔
582
        self.builder.begin_else()
4✔
583
        self.builder.add_transition(tmp_name, "false")
4✔
584
        self.builder.end_if()
4✔
585

586
        self.container_table[tmp_name] = dtype
4✔
587
        return tmp_name
4✔
588

589
    def visit_Subscript(self, node):
4✔
590
        value_str = self.visit(node.value)
4✔
591

592
        if value_str.startswith("_shape_proxy_"):
4✔
593
            array_name = value_str[len("_shape_proxy_") :]
4✔
594
            if isinstance(node.slice, ast.Constant):
4✔
595
                idx = node.slice.value
4✔
596
            elif isinstance(node.slice, ast.Index):
×
597
                idx = node.slice.value.value
×
598
            else:
599
                try:
×
600
                    idx = int(self.visit(node.slice))
×
601
                except:
×
602
                    raise NotImplementedError(
×
603
                        "Dynamic shape indexing not fully supported yet"
604
                    )
605

606
            if array_name in self.tensor_table:
4✔
607
                return self.tensor_table[array_name].shape[idx]
4✔
608

609
            return f"_{array_name}_shape_{idx}"
×
610

611
        if value_str in self.tensor_table:
4✔
612
            tensor = self.tensor_table[value_str]
4✔
613
            ndim = len(tensor.shape)
4✔
614
            shapes = tensor.shape
4✔
615

616
            if isinstance(node.slice, ast.Tuple):
4✔
617
                indices_nodes = node.slice.elts
4✔
618
            else:
619
                indices_nodes = [node.slice]
4✔
620

621
            all_full_slices = True
4✔
622
            for idx in indices_nodes:
4✔
623
                if isinstance(idx, ast.Slice):
4✔
624
                    if idx.lower is not None or idx.upper is not None:
4✔
625
                        all_full_slices = False
4✔
626
                        break
4✔
627
                    # Also check for non-trivial step (step != None and step != 1)
628
                    if idx.step is not None:
4✔
629
                        # Check if step is a constant 1; if not, it's not a full slice
630
                        if isinstance(idx.step, ast.Constant) and idx.step.value == 1:
4✔
631
                            pass  # step=1 is equivalent to no step
×
632
                        else:
633
                            all_full_slices = False
4✔
634
                            break
4✔
635
                else:
636
                    all_full_slices = False
4✔
637
                    break
4✔
638

639
            if all_full_slices:
4✔
640
                return value_str
4✔
641

642
            has_slices = any(isinstance(idx, ast.Slice) for idx in indices_nodes)
4✔
643
            if has_slices:
4✔
644
                return self._handle_expression_slicing(
4✔
645
                    node, value_str, indices_nodes, shapes, ndim
646
                )
647

648
            if len(indices_nodes) == 1 and self._is_array_index(indices_nodes[0]):
4✔
649
                if self.builder:
4✔
650
                    return self._handle_gather(value_str, indices_nodes[0])
4✔
651

652
            if isinstance(node.slice, ast.Tuple):
4✔
653
                indices = [self.visit(elt) for elt in node.slice.elts]
4✔
654
            else:
655
                indices = [self.visit(node.slice)]
4✔
656

657
            if len(indices) != ndim:
4✔
658
                raise ValueError(
×
659
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
660
                )
661

662
            normalized_indices = []
4✔
663
            for i, idx_str in enumerate(indices):
4✔
664
                shape_val = shapes[i]
4✔
665
                if isinstance(idx_str, str) and (
4✔
666
                    idx_str.startswith("-") or idx_str.startswith("(-")
667
                ):
668
                    normalized_indices.append(f"({shape_val} + {idx_str})")
×
669
                else:
670
                    normalized_indices.append(idx_str)
4✔
671

672
            subscript_str = ",".join(normalized_indices)
4✔
673
            access_str = f"{value_str}({subscript_str})"
4✔
674

675
            if isinstance(node.ctx, ast.Load):
4✔
676
                tmp_name = self.builder.find_new_name()
4✔
677
                self.builder.add_container(tmp_name, tensor.element_type, False)
4✔
678
                self.container_table[tmp_name] = tensor.element_type
4✔
679

680
                block = self.builder.add_block()
4✔
681
                t_src = self.builder.add_access(block, value_str)
4✔
682
                t_dst = self.builder.add_access(block, tmp_name)
4✔
683
                t_task = self.builder.add_tasklet(
4✔
684
                    block, TaskletCode.assign, ["_in"], ["_out"]
685
                )
686
                self.builder.add_memlet(
4✔
687
                    block, t_src, "void", t_task, "_in", subscript_str, tensor
688
                )
689
                self.builder.add_memlet(
4✔
690
                    block, t_task, "_out", t_dst, "void", "", tensor.element_type
691
                )
692

693
                return tmp_name
4✔
694

695
            return access_str
4✔
696

697
        slice_val = self.visit(node.slice)
×
698
        access_str = f"{value_str}({slice_val})"
×
699
        return access_str
×
700

701
    def visit_AugAssign(self, node):
        """Lower ``target op= value`` by rewriting it as a plain assignment.

        The statement is turned into ``target = target op value`` and handed
        to ``visit_Assign``. When the target is a bare name registered in
        ``self.tensor_table``, the store side becomes a full-slice subscript
        (``target[:, ...] = target op value``) with one ``:`` per dimension of
        the registered shape.

        The synthesized nodes inherit the source location of ``node`` so that
        debug information derived from them points at the original statement.
        """
        target = node.target
        combined = ast.BinOp(left=target, op=node.op, right=node.value)
        ast.copy_location(combined, node)

        store_target = target
        if isinstance(target, ast.Name) and target.id in self.tensor_table:
            # Convert to slice assignment: target[:] = target op value
            ndim = len(self.tensor_table[target.id].shape)
            slices = [
                ast.Slice(lower=None, upper=None, step=None) for _ in range(ndim)
            ]

            # A single dimension is subscripted with a bare slice; anything
            # else uses a tuple of slices.
            if ndim == 1:
                slice_arg = slices[0]
            else:
                slice_arg = ast.Tuple(elts=slices, ctx=ast.Load())

            store_target = ast.Subscript(
                value=target, slice=slice_arg, ctx=ast.Store()
            )
            ast.copy_location(store_target, node)

        new_node = ast.Assign(targets=[store_target], value=combined)
        ast.copy_location(new_node, node)
        self.visit_Assign(new_node)
4✔
730

731
    def visit_Assign(self, node):
        """Lower an ``ast.Assign`` statement into SDFG blocks.

        The forms below are tried in order:

        * Chained assignment (``a = b = expr``): the value is assigned to a
          fresh temporary, which is then assigned to every target.
        * Tuple unpacking from a tuple literal (``i, j = e1, e2``): split into
          one assignment per element.
        * Values the NumPy handler recognises as gemm/dot, outer or transpose,
          when the corresponding handler reports that it handled the
          statement.
        * Subscript targets (``a[i] = v``, ``a[i, j] = v``); targets that
          contain a slice are delegated to ``_handle_slice_assignment``.
        * Plain names: the container is created on first assignment (typed
          from the constant, or copied from the right-hand side's container),
          then either a reference memlet (pointer to pointer) or an assign
          tasklet (scalar) is emitted.

        Raises:
            ValueError: tuple unpacking with a different number of targets
                and values.
            NotImplementedError: unpacking from a non-tuple value, a target
                that is neither a name nor a subscript, or a constant whose
                type cannot be inferred.
        """
        # Chained assignment: a = b = c
        if len(node.targets) > 1:
            tmp_name = self.builder.find_new_name()
            # Assign value to temporary
            val_assign = ast.Assign(
                targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=node.value
            )
            ast.copy_location(val_assign, node)
            self.visit_Assign(val_assign)

            # Assign temporary to targets
            for target in node.targets:
                assign = ast.Assign(
                    targets=[target], value=ast.Name(id=tmp_name, ctx=ast.Load())
                )
                ast.copy_location(assign, node)
                self.visit_Assign(assign)
            return
        target = node.targets[0]

        # Handle tuple unpacking: I, J, K = expr1, expr2, expr3
        if isinstance(target, ast.Tuple):
            if not isinstance(node.value, ast.Tuple):
                raise NotImplementedError(
                    "Tuple unpacking from non-tuple values not supported"
                )
            if len(target.elts) != len(node.value.elts):
                raise ValueError("Tuple unpacking size mismatch")
            for tgt, val in zip(target.elts, node.value.elts):
                assign = ast.Assign(targets=[tgt], value=val)
                ast.copy_location(assign, node)
                self.visit_Assign(assign)
            return

        # Special cases, where rhs is not just a simple expression but requires special handling
        if self.numpy_visitor.is_gemm(node.value):
            if self.numpy_visitor.handle_gemm(target, node.value):
                return
            if self.numpy_visitor.handle_dot(target, node.value):
                return
        if self.numpy_visitor.is_outer(node.value):
            if self.numpy_visitor.handle_outer(target, node.value):
                return
        if self.numpy_visitor.is_transpose(node.value):
            if self.numpy_visitor.handle_transpose(target, node.value):
                return

        # Handle subscript assignments: a[i] = val or a[i, j] = val
        if isinstance(target, ast.Subscript):
            debug_info = get_debug_info(node, self.filename, self.function_name)

            target_name = self.visit(target.value)
            if isinstance(target.slice, ast.Tuple):
                indices = target.slice.elts
            else:
                indices = [target.slice]

            # Handle slice assignment separately
            if any(isinstance(idx, ast.Slice) for idx in indices):
                self._handle_slice_assignment(
                    target, node.value, target_name, indices, debug_info
                )
                return

            # Handle rhs and store in scalar tmp
            rhs_tmp = self.visit(node.value)

            # Evaluate the LHS (index) expression before creating the store
            # block/tasklet.
            lhs_expr = self.visit(target)

            block = self.builder.add_block(debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            t_src, src_sub = self._add_read(block, rhs_tmp, debug_info)
            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )

            if "(" in lhs_expr and lhs_expr.endswith(")"):
                # The visited target has the form "name(subset)"; the text
                # between the first "(" and the final ")" is the store subset.
                subset = lhs_expr[lhs_expr.find("(") + 1 : -1]
                tensor_dst = self.tensor_table[target_name]

                t_dst = self.builder.add_access(block, target_name, debug_info)
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", subset, tensor_dst, debug_info
                )
            else:
                t_dst = self.builder.add_access(block, target_name, debug_info)
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "", None, debug_info
                )
            return

        # Fallback: lhs is a simple scalar assignment
        if not isinstance(target, ast.Name):
            raise NotImplementedError("Only assignment to variables supported")

        target_name = target.id
        rhs_tmp = self.visit(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if not self.builder.exists(target_name):
            if isinstance(node.value, ast.Constant):
                val = node.value.value
                # bool must be tested before int: bool is a subclass of int,
                # so the reverse order would type True/False as Int64.
                if isinstance(val, bool):
                    dtype = Scalar(PrimitiveType.Bool)
                elif isinstance(val, int):
                    dtype = Scalar(PrimitiveType.Int64)
                elif isinstance(val, float):
                    dtype = Scalar(PrimitiveType.Double)
                else:
                    raise NotImplementedError(f"Cannot infer type for {val}")

                self.builder.add_container(target_name, dtype, False)
                self.container_table[target_name] = dtype
            else:
                self.builder.add_container(
                    target_name, self.container_table[rhs_tmp], False
                )
                self.container_table[target_name] = self.container_table[rhs_tmp]

        if rhs_tmp in self.tensor_table:
            self.tensor_table[target_name] = self.tensor_table[rhs_tmp]

        # Also copy shapes_runtime_info if available
        if rhs_tmp in self.shapes_runtime_info:
            self.shapes_runtime_info[target_name] = self.shapes_runtime_info[rhs_tmp]

        # Distinguish assignments: scalar -> tasklet, pointer -> reference_memlet
        src_type = self.container_table.get(rhs_tmp)
        dst_type = self.container_table[target_name]
        if src_type and isinstance(src_type, Pointer) and isinstance(dst_type, Pointer):
            block = self.builder.add_block(debug_info)
            t_src = self.builder.add_access(block, rhs_tmp, debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_reference_memlet(
                block, t_src, t_dst, "0", src_type, debug_info
            )
        elif (src_type and isinstance(src_type, Scalar)) or isinstance(
            dst_type, Scalar
        ):
            block = self.builder.add_block(debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            if src_type:
                t_src = self.builder.add_access(block, rhs_tmp, debug_info)
            else:
                # The right-hand side has no container, so it is emitted as a
                # constant of the destination's type.
                t_src = self.builder.add_constant(block, rhs_tmp, dst_type, debug_info)

            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", "", None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )
901

902
    def visit_Expr(self, node):
        """Lower an expression statement by visiting the wrapped expression.

        Whatever the dispatched visitor returns is discarded; this method
        itself returns ``None``.
        """
        inner = node.value
        self.visit(inner)
×
904

905
    def visit_If(self, node):
        """Lower an ``if``/``else`` statement into builder if/else regions.

        The test expression is visited first and compared against ``false``
        to form the branch condition. The else region is only opened when the
        statement actually has an ``orelse`` body.
        """
        condition = self.visit(node.test)
        dbg = get_debug_info(node, self.filename, self.function_name)

        self.builder.begin_if(f"{condition} != false", dbg)
        for child in node.body:
            self.visit(child)

        else_branch = node.orelse
        if else_branch:
            self.builder.begin_else(dbg)
            for child in else_branch:
                self.visit(child)

        self.builder.end_if()
4✔
919

920
    def visit_While(self, node):
        """Lower a ``while`` loop into a builder while region.

        The loop condition is visited inside the region so that it is
        re-evaluated on every iteration, and is followed by an
        "if condition is false: break" guard ahead of the loop body.

        Raises:
            NotImplementedError: if the loop has an ``else`` clause.
        """
        if node.orelse:
            raise NotImplementedError("while-else is not supported")

        dbg = get_debug_info(node, self.filename, self.function_name)
        builder = self.builder
        builder.begin_while(dbg)

        # The condition is lowered after begin_while, i.e. within the loop.
        condition = self.visit(node.test)

        # Exit guard: leave the loop as soon as the condition is false.
        builder.begin_if(f"{condition} == false", dbg)
        builder.add_break(dbg)
        builder.end_if()

        for child in node.body:
            self.visit(child)

        builder.end_while()
4✔
939

940
    def visit_Break(self, node):
        """Emit a break at the current builder position for a ``break`` statement."""
        self.builder.add_break(
            get_debug_info(node, self.filename, self.function_name)
        )
4✔
943

944
    def visit_Continue(self, node):
        """Emit a continue at the current builder position for a ``continue`` statement."""
        self.builder.add_continue(
            get_debug_info(node, self.filename, self.function_name)
        )
4✔
947

948
    def visit_For(self, node):
        """Lower a ``for`` loop into a builder for region.

        Two iteration forms are supported:

        * ``for i in range(...)`` with one to three arguments. The loop
          variable is created as an Int64 scalar if it does not exist yet.
        * ``for x in arr`` where ``arr`` is a name in ``self.tensor_table``.
          A hidden Int64 index runs over the first dimension and
          ``x = arr[idx]`` is emitted at the top of every iteration.

        Raises:
            NotImplementedError: for a ``for``-``else`` loop, a loop target
                that is not a plain name, iteration over a 0-dimensional
                array, or any other kind of iterable.
            ValueError: if ``range`` is called with zero or more than three
                arguments.
        """
        if node.orelse:
            raise NotImplementedError("for-else is not supported")
        if not isinstance(node.target, ast.Name):
            raise NotImplementedError("Only simple for loops supported")

        var = node.target.id
        debug_info = get_debug_info(node, self.filename, self.function_name)

        # Check if iterating over a range() call
        if (
            isinstance(node.iter, ast.Call)
            and isinstance(node.iter.func, ast.Name)
            and node.iter.func.id == "range"
        ):
            args = node.iter.args
            if len(args) == 1:
                start = "0"
                end = self.visit(args[0])
                step = "1"
            elif len(args) == 2:
                start = self.visit(args[0])
                end = self.visit(args[1])
                step = "1"
            elif len(args) == 3:
                start = self.visit(args[0])
                end = self.visit(args[1])

                # Special handling for step to avoid creating tasklets for constants
                step_node = args[2]
                if isinstance(step_node, ast.Constant):
                    step = str(step_node.value)
                elif (
                    isinstance(step_node, ast.UnaryOp)
                    and isinstance(step_node.op, ast.USub)
                    and isinstance(step_node.operand, ast.Constant)
                ):
                    step = f"-{step_node.operand.value}"
                else:
                    step = self.visit(step_node)
            else:
                raise ValueError("Invalid range arguments")

            if not self.builder.exists(var):
                self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
                self.container_table[var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(var, start, end, step, debug_info)

            for stmt in node.body:
                self.visit(stmt)

            self.builder.end_for()
            return

        # Check if iterating over an ndarray (for x in array)
        if isinstance(node.iter, ast.Name):
            iter_name = node.iter.id
            if iter_name in self.tensor_table:
                arr_info = self.tensor_table[iter_name]
                if len(arr_info.shape) == 0:
                    raise NotImplementedError("Cannot iterate over 0-dimensional array")

                # Get the size of the first dimension
                arr_size = arr_info.shape[0]

                # Create a hidden index variable for the loop
                idx_var = self.builder.find_new_name()
                if not self.builder.exists(idx_var):
                    self.builder.add_container(
                        idx_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.container_table[idx_var] = Scalar(PrimitiveType.Int64)

                # Determine the type of the loop variable (element type)
                # For a 1D array, it's a scalar; for ND array, it's a view of N-1 dimensions
                if len(arr_info.shape) == 1:
                    # Element is a scalar - get the element type from the array's type
                    arr_type = self.container_table.get(iter_name)
                    if isinstance(arr_type, Pointer):
                        elem_type = arr_type.pointee_type
                    else:
                        elem_type = Scalar(PrimitiveType.Double)  # Default fallback

                    if not self.builder.exists(var):
                        self.builder.add_container(var, elem_type, False)
                        self.container_table[var] = elem_type
                else:
                    # For multi-dimensional arrays, create a view/slice
                    # The loop variable becomes a pointer to the sub-array
                    inner_shapes = arr_info.shape[1:]

                    arr_type = self.container_table.get(iter_name)
                    if isinstance(arr_type, Pointer):
                        elem_type = arr_type  # Keep as pointer type for views
                    else:
                        elem_type = Pointer(Scalar(PrimitiveType.Double))

                    if not self.builder.exists(var):
                        self.builder.add_container(var, elem_type, False)
                        self.container_table[var] = elem_type

                    # Register the view in tensor_table
                    self.tensor_table[var] = Tensor(
                        element_type_from_sdfg_type(elem_type), inner_shapes
                    )

                # Begin the for loop
                self.builder.begin_for(idx_var, "0", str(arr_size), "1", debug_info)

                # Generate the assignment: var = array[idx_var]
                # Create an AST node for the assignment and visit it
                assign_node = ast.Assign(
                    targets=[ast.Name(id=var, ctx=ast.Store())],
                    value=ast.Subscript(
                        value=ast.Name(id=iter_name, ctx=ast.Load()),
                        slice=ast.Name(id=idx_var, ctx=ast.Load()),
                        ctx=ast.Load(),
                    ),
                )
                ast.copy_location(assign_node, node)
                self.visit_Assign(assign_node)

                # Visit the loop body
                for stmt in node.body:
                    self.visit(stmt)

                self.builder.end_for()
                return

        raise NotImplementedError(
            f"Only range() loops and iteration over ndarrays supported, got: {ast.dump(node.iter)}"
        )
1082

1083
    def visit_Return(self, node):
        """Lower a ``return`` statement.

        A bare ``return`` emits any deferred frees followed by a return.

        Otherwise each returned value (the elements of a tuple, or the single
        value) is visited and written to the output container
        ``_docc_ret_<i>``:

        * When ``self.infer_return_type`` is set, missing output containers
          are created as arguments, typed from the value's container or from
          the constant's Python type (Double if neither applies). Scalars are
          wrapped in a ``Pointer``. The flag is cleared afterwards.
        * Values with a non-empty tensor shape, or with a ``Pointer``
          container, are copied through a synthesized full-slice assignment.
          For tensors, the return shape and freshly computed C-order strides
          are recorded in ``captured_return_shapes`` and
          ``captured_return_strides``.
        * Everything else is stored with a single assign tasklet.

        Deferred frees are emitted before the final return.
        """
        if node.value is None:
            debug_info = get_debug_info(node, self.filename, self.function_name)
            # Emit frees for all deferred allocations before returning
            if self.memory_handler.has_allocations():
                self.memory_handler.emit_frees()
            self.builder.add_return("", debug_info)
            return

        if isinstance(node.value, ast.Tuple):
            values = node.value.elts
        else:
            values = [node.value]

        parsed_values = [self.visit(v) for v in values]
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if self.infer_return_type:
            for i, res in enumerate(parsed_values):
                ret_name = f"_docc_ret_{i}"
                if not self.builder.exists(ret_name):
                    dtype = Scalar(PrimitiveType.Double)
                    if res in self.container_table:
                        dtype = self.container_table[res]
                    elif isinstance(values[i], ast.Constant):
                        val = values[i].value
                        # bool must be tested before int: bool is a subclass
                        # of int, so the reverse order would never reach the
                        # Bool case.
                        if isinstance(val, bool):
                            dtype = Scalar(PrimitiveType.Bool)
                        elif isinstance(val, int):
                            dtype = Scalar(PrimitiveType.Int64)
                        elif isinstance(val, float):
                            dtype = Scalar(PrimitiveType.Double)

                    # Wrap Scalar in Pointer. Keep Arrays/Pointers as is.
                    arg_type = dtype
                    if isinstance(dtype, Scalar):
                        arg_type = Pointer(dtype)

                    self.builder.add_container(ret_name, arg_type, is_argument=True)
                    self.container_table[ret_name] = arg_type

                    if res in self.tensor_table:
                        self.tensor_table[ret_name] = self.tensor_table[res]

            self.infer_return_type = False

        for i, res in enumerate(parsed_values):
            ret_name = f"_docc_ret_{i}"

            is_array_return = False
            if res in self.tensor_table:
                # Only treat as array return if it has dimensions
                # 0-d arrays (scalars) should be handled by scalar assignment
                if len(self.tensor_table[res].shape) > 0:
                    is_array_return = True
            elif res in self.container_table:
                if isinstance(self.container_table[res], Pointer):
                    is_array_return = True

            # Simple Scalar Assignment
            if not is_array_return:
                block = self.builder.add_block(debug_info)
                t_dst = self.builder.add_access(block, ret_name, debug_info)

                t_src, src_sub = self._add_read(block, res, debug_info)

                t_task = self.builder.add_tasklet(
                    block, TaskletCode.assign, ["_in"], ["_out"], debug_info
                )
                self.builder.add_memlet(
                    block, t_src, "void", t_task, "_in", src_sub, None, debug_info
                )
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "0", None, debug_info
                )

            # Array Assignment (Copy)
            else:
                # Record shape for metadata
                if res in self.tensor_table:
                    # Prefer runtime shapes if available (for indirect access patterns)
                    # Fall back to regular shapes otherwise
                    res_info = self.tensor_table[res]
                    if res in self.shapes_runtime_info:
                        shape = self.shapes_runtime_info[res]
                    else:
                        shape = res_info.shape
                    # Convert to string expressions
                    self.captured_return_shapes[ret_name] = [str(s) for s in shape]

                    # Return arrays are always contiguous - compute fresh strides
                    contiguous_strides = self.numpy_visitor._compute_strides(shape, "C")
                    self.captured_return_strides[ret_name] = [
                        str(s) for s in contiguous_strides
                    ]

                    # Always overwrite tensor_table for return arrays with contiguous strides
                    # (source tensor may have non-standard strides from views/flip)
                    self.tensor_table[ret_name] = Tensor(
                        res_info.element_type, shape, contiguous_strides
                    )

                # Copy Logic using visit_Assign
                ndim = 1
                if ret_name in self.tensor_table:
                    ndim = len(self.tensor_table[ret_name].shape)

                if ndim > 1:
                    # One independent Slice node per dimension.
                    target_slice = ast.Tuple(
                        elts=[
                            ast.Slice(lower=None, upper=None, step=None)
                            for _ in range(ndim)
                        ],
                        ctx=ast.Load(),
                    )
                else:
                    target_slice = ast.Slice(lower=None, upper=None, step=None)

                target_sub = ast.Subscript(
                    value=ast.Name(id=ret_name, ctx=ast.Load()),
                    slice=target_slice,
                    ctx=ast.Store(),
                )

                # Value node reconstruction
                if isinstance(values[i], ast.Name):
                    val_node = values[i]
                else:
                    val_node = ast.Name(id=res, ctx=ast.Load())

                assign_node = ast.Assign(targets=[target_sub], value=val_node)
                # Attribute the synthesized copy to the return statement.
                ast.copy_location(assign_node, node)
                self.visit_Assign(assign_node)

        # Emit frees for all deferred allocations before returning
        if self.memory_handler.has_allocations():
            self.memory_handler.emit_frees()

        # Add control flow return to exit the function/path
        self.builder.add_return("", debug_info)
4✔
1218

1219
    def visit_Call(self, node):
        """Dispatch a call expression to the handler responsible for it.

        Resolution order:

        1. ``<name>.astype(...)`` / ``<name>.copy(...)`` where ``<name>`` is
           in ``self.tensor_table``.
        2. ``np.<ufunc>.outer(...)`` / ``numpy.<ufunc>.outer(...)``.
        3. ``np.*`` / ``numpy.*`` functions known to the NumPy handler.
        4. ``math.*`` functions known to the math handler.
        5. Functions known to the Python handler.
        6. Plain Python functions found in ``self.globals_dict``, which are
           inlined.

        Returns:
            Whatever the selected handler returns.

        Raises:
            NotImplementedError: if no handler accepts the call. The message
                names the callee; for method/attribute calls that resolved to
                no function name, the unparsed callee expression is used.
        """
        func_name = ""
        module_name = ""
        callee = node.func

        if isinstance(callee, ast.Attribute):
            owner = callee.value
            if isinstance(owner, ast.Name):
                if owner.id == "math":
                    module_name = "math"
                    func_name = callee.attr
                elif owner.id in ["numpy", "np"]:
                    module_name = "numpy"
                    func_name = callee.attr
                else:
                    array_name = owner.id
                    method_name = callee.attr
                    if array_name in self.tensor_table and method_name == "astype":
                        return self.numpy_visitor.handle_numpy_astype(node, array_name)
                    elif array_name in self.tensor_table and method_name == "copy":
                        return self.numpy_visitor.handle_numpy_copy(node, array_name)
            elif isinstance(owner, ast.Attribute):
                if (
                    isinstance(owner.value, ast.Name)
                    and owner.value.id in ["numpy", "np"]
                    and callee.attr == "outer"
                ):
                    ufunc_name = owner.attr
                    return self.numpy_visitor.handle_ufunc_outer(node, ufunc_name)

        elif isinstance(callee, ast.Name):
            func_name = callee.id

        if module_name == "numpy":
            if self.numpy_visitor.has_handler(func_name):
                return self.numpy_visitor.handle_numpy_call(node, func_name)

        if module_name == "math":
            if self.math_handler.has_handler(func_name):
                return self.math_handler.handle_math_call(node, func_name)

        if self.python_handler.has_handler(func_name):
            return self.python_handler.handle_python_call(node, func_name)

        if func_name in self.globals_dict:
            obj = self.globals_dict[func_name]
            if inspect.isfunction(obj):
                return self._handle_inline_call(node, obj)

        # func_name is empty for unsupported method/attribute calls; fall back
        # to the callee's source form so the error names what was called.
        described = func_name or ast.unparse(callee)
        raise NotImplementedError(f"Function call {described} not supported")
×
1267

1268
    def _handle_inline_call(self, node, func_obj):
        """Inline a call to a plain Python function at the current position.

        The callee's source is re-parsed and its local names are suffixed
        with ``_<function name>_<unique id>`` so they cannot collide with the
        caller's names. Names found in the combined globals (the parser's
        globals plus the callee's closure variables) and a fixed set of
        builtin names are left untouched. ``return <expr>`` statements become
        assignments to a result variable, numeric closure variables are
        injected as constant assignments, and positional arguments are
        assigned to the renamed parameters. The resulting statements are
        visited by a nested ``ASTParser`` that shares this parser's builder,
        tables and memory handler.

        Args:
            node: the ``ast.Call`` being inlined; only ``node.args`` is used.
            func_obj: the Python function object to inline.

        Returns:
            The name of the result variable (``_res<suffix>``).

        Raises:
            NotImplementedError: if the callee's source cannot be obtained or
                parsed, or if the number of positional arguments differs from
                the callee's number of parameters.
        """
        try:
            source_lines, _ = inspect.getsourcelines(func_obj)
            source = textwrap.dedent("".join(source_lines))
            tree = ast.parse(source)
            func_def = tree.body[0]
        except Exception as e:
            # Chain the original failure so the root cause stays visible.
            raise NotImplementedError(
                f"Could not parse function {func_obj.__name__}: {e}"
            ) from e

        arg_vars = [self.visit(arg) for arg in node.args]

        if len(arg_vars) != len(func_def.args.args):
            raise NotImplementedError(
                f"Argument count mismatch for {func_obj.__name__}"
            )

        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
        res_name = f"_res{suffix}"

        # Combine globals with closure variables of the inlined function
        combined_globals = dict(self.globals_dict)
        closure_constants = {}  # name -> value for numeric closure vars
        if func_obj.__closure__ is not None and func_obj.__code__.co_freevars:
            for name, cell in zip(func_obj.__code__.co_freevars, func_obj.__closure__):
                val = cell.cell_contents
                combined_globals[name] = val
                # Track numeric constants for injection
                if isinstance(val, (int, float)) and not isinstance(val, bool):
                    closure_constants[name] = val

        class VariableRenamer(ast.NodeTransformer):
            """Suffix callee-local names and turn returns into assignments."""

            BUILTINS = {
                "range",
                "len",
                "int",
                "float",
                "bool",
                "str",
                "list",
                "dict",
                "tuple",
                "set",
                "print",
                "abs",
                "min",
                "max",
                "sum",
                "enumerate",
                "zip",
                "map",
                "filter",
                "sorted",
                "reversed",
                "True",
                "False",
                "None",
            }

            def __init__(self, suffix, globals_dict):
                self.suffix = suffix
                self.globals_dict = globals_dict

            def visit_Name(self, node):
                if node.id in self.globals_dict or node.id in self.BUILTINS:
                    return node
                renamed = ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
                return ast.copy_location(renamed, node)

            def visit_Return(self, node):
                if node.value:
                    val = self.visit(node.value)
                    assign = ast.Assign(
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
                        value=val,
                    )
                    return ast.copy_location(assign, node)
                return node

        renamer = VariableRenamer(suffix, combined_globals)
        new_body = [renamer.visit(stmt) for stmt in func_def.body]

        param_assignments = []

        # Inject closure constants as assignments
        for name, val in closure_constants.items():
            if isinstance(val, int):
                const_type = Scalar(PrimitiveType.Int64)
            else:
                const_type = Scalar(PrimitiveType.Double)
            # Closure names carry no unique suffix, so the same name comes
            # back whenever this function is inlined again; only create the
            # container the first time.
            if not self.builder.exists(name):
                self.container_table[name] = const_type
                self.builder.add_container(name, const_type, False)
            assign = ast.Assign(
                targets=[ast.Name(id=name, ctx=ast.Store())],
                value=ast.Constant(value=val),
            )
            param_assignments.append(assign)

        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
            param_name = f"{arg_def.arg}{suffix}"

            if arg_val in self.container_table:
                self.container_table[param_name] = self.container_table[arg_val]
                self.builder.add_container(
                    param_name, self.container_table[arg_val], False
                )
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
            elif self._is_int(arg_val):
                self.container_table[param_name] = Scalar(PrimitiveType.Int64)
                self.builder.add_container(
                    param_name, Scalar(PrimitiveType.Int64), False
                )
                val_node = ast.Constant(value=int(arg_val))
            else:
                try:
                    val = float(arg_val)
                    self.container_table[param_name] = Scalar(PrimitiveType.Double)
                    self.builder.add_container(
                        param_name, Scalar(PrimitiveType.Double), False
                    )
                    val_node = ast.Constant(value=val)
                except ValueError:
                    # Not a numeric literal: pass the expression through by
                    # name and let the assignment lowering deal with it.
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())

            assign = ast.Assign(
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
            )
            param_assignments.append(assign)

        final_body = param_assignments + new_body

        # Create a new parser instance for the inlined function
        # Share memory_handler so hoisted allocations go to main function entry
        parser = ASTParser(
            self.builder,
            self.tensor_table,
            self.container_table,
            globals_dict=combined_globals,
            unique_counter_ref=self._unique_counter_ref,
            memory_handler=self.memory_handler,
        )

        for stmt in final_body:
            parser.visit(stmt)

        return res_name
4✔
1414

1415
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit a block that assigns a constant to the container ``target_name``.

        The new block contains a constant node (``value_str`` typed as
        ``dtype``), an access node for ``target_name`` and an ``assign``
        tasklet connecting the two.
        """
        sdfg = self.builder
        state = sdfg.add_block()
        const_node = sdfg.add_constant(state, value_str, dtype)
        target_node = sdfg.add_access(state, target_name)
        assign = sdfg.add_tasklet(state, TaskletCode.assign, ["_in"], ["_out"])
        # constant -> tasklet input, tasklet output -> target container
        sdfg.add_memlet(state, const_node, "void", assign, "_in", "")
        sdfg.add_memlet(state, assign, "_out", target_node, "void", "")
1422

1423
    def _handle_expression_slicing(self, node, value_str, indices_nodes, shapes, ndim):
4✔
1424
        """Handle slicing in expressions (e.g., arr[1:, :, k+1]).
1425

1426
        Uses a zero-copy view when possible (positive step, no indirect access).
1427
        Falls back to copy-based approach for complex cases.
1428
        """
1429
        if not self.builder:
4✔
1430
            raise ValueError("Builder required for expression slicing")
×
1431

1432
        # Try view-based approach first (zero-copy)
1433
        if self._can_use_slice_view(indices_nodes):
4✔
1434
            return self._create_slice_view(value_str, indices_nodes, shapes, ndim)
4✔
1435

1436
        # Fall back to copy-based approach for complex cases
1437
        return self._handle_expression_slicing_copy(
4✔
1438
            node, value_str, indices_nodes, shapes, ndim
1439
        )
1440

1441
    def _can_use_slice_view(self, indices_nodes):
4✔
1442
        """Check if slicing can be expressed as a zero-copy view.
1443

1444
        Views can be used when:
1445
        - All steps are non-zero constants (positive or negative)
1446
        - No indirect array access in slice parameters
1447

1448
        Returns True if a view can be used, False if a copy is required.
1449
        """
1450
        for idx in indices_nodes:
4✔
1451
            if isinstance(idx, ast.Slice):
4✔
1452
                # Check for zero step (invalid)
1453
                if idx.step is not None:
4✔
1454
                    if isinstance(idx.step, ast.Constant):
4✔
1455
                        if idx.step.value == 0:
4✔
1456
                            return False  # Zero step is invalid
×
1457
                    elif isinstance(idx.step, ast.UnaryOp) and isinstance(
4✔
1458
                        idx.step.op, ast.USub
1459
                    ):
1460
                        # Negative step like -2 is OK
1461
                        pass
4✔
1462
                    elif self._contains_indirect_access(idx.step):
×
1463
                        return False  # Dynamic step requires copy
×
1464

1465
                # Check for indirect access in slice bounds
1466
                if idx.lower is not None and self._contains_indirect_access(idx.lower):
4✔
1467
                    return False
4✔
1468
                if idx.upper is not None and self._contains_indirect_access(idx.upper):
4✔
1469
                    return False
×
1470
            else:
1471
                # Fixed index: check for indirect access
1472
                if self._contains_indirect_access(idx):
4✔
1473
                    return False
×
1474
        return True
4✔
1475

1476
    def _create_slice_view(self, value_str, indices_nodes, shapes, ndim):
        """Create a zero-copy view for array slicing.

        This creates a new pointer container that references the source array
        and registers a tensor for it whose shape, strides, and offset
        describe the sliced region. All quantities are symbolic expression
        strings.

        For positive step A[start:stop:step, ...] on dimension i:
        - new_shape[i] = ceil((stop - start) / step)
        - new_stride[i] = old_stride[i] * step
        - offset contribution = start * old_stride[i]

        For negative constant step A[start:stop:step, ...] (e.g., ::-1):
        - Default start = shape - 1 (last element)
        - Default stop = -1 (before first element)
        - new_shape[i] = ceil((start - stop) / abs(step))
        - new_stride[i] = old_stride[i] * step (negative)
        - offset contribution = start * old_stride[i]

        For a fixed index A[k, ...] on dimension i (dimension reduction):
        - offset contribution = k * old_stride[i]
        - dimension is removed from output

        Ceiling divisions are always emitted as ``idiv(n + d - 1, d)`` so the
        positive- and negative-step extents use the same integer formula.

        Args:
            value_str: name of the source array (key of ``tensor_table``).
            indices_nodes: AST nodes of the subscript, one per dimension.
            shapes: shape expressions of the source array; missing entries
                fall back to ``_<value_str>_shape_<i>``.
            ndim: dimensionality of the source (not used by this method).

        Returns:
            Name of the newly created view container.
        """
        in_tensor = self.tensor_table[value_str]
        in_shape = in_tensor.shape
        dtype = in_tensor.element_type

        # Get input strides (compute C-order strides if the tensor has none)
        in_strides = getattr(in_tensor, "strides", None) or None
        if in_strides is None:
            in_strides = self.numpy_visitor._compute_strides(in_shape, "C")

        # Get base offset from input tensor
        in_offset = getattr(in_tensor, "offset", "0") or "0"

        def wrap_negative(expr, extent):
            # Expressions that textually start with a minus sign are treated
            # as counting from the end of the dimension.
            if isinstance(expr, str) and expr.startswith(("-", "(-")):
                return f"({extent} + {expr})"
            return expr

        # Build output shape, strides, and compute offset
        out_shape = []
        out_strides = []
        offset_terms = []
        if in_offset != "0":
            offset_terms.append(str(in_offset))

        for i, idx in enumerate(indices_nodes):
            shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
            stride_val = in_strides[i] if i < len(in_strides) else "1"

            if not isinstance(idx, ast.Slice):
                # Fixed index: dimension is removed, just add offset
                index_str = wrap_negative(self.visit(idx), shape_val)
                offset_terms.append(f"({index_str} * {stride_val})")
                continue

            # Determine step value and sign
            step_str = "1"
            step_is_negative = False
            step_value = 1

            step = idx.step
            if step is not None:
                if isinstance(step, ast.Constant):
                    step_value = step.value
                    step_str = str(step_value)
                    step_is_negative = step_value < 0
                elif (
                    isinstance(step, ast.UnaryOp)
                    and isinstance(step.op, ast.USub)
                    and isinstance(step.operand, ast.Constant)
                ):
                    # Handle -N syntax
                    step_value = -step.operand.value
                    step_str = str(step_value)
                    step_is_negative = True
                else:
                    # Non-literal step: sign unknown, handled as positive
                    step_str = self.visit(step)

            if step_is_negative:
                # Negative step: iterate from end to start
                # Default start = shape - 1, default stop = -1 (before 0)
                if idx.lower is not None:
                    start_str = wrap_negative(self.visit(idx.lower), shape_val)
                else:
                    start_str = f"({shape_val} - 1)"

                if idx.upper is not None:
                    stop_str = wrap_negative(self.visit(idx.upper), shape_val)
                else:
                    stop_str = "-1"

                # Shape for negative step: ceil((start - stop) / abs(step)),
                # using the same idiv ceiling form as the positive-step path.
                abs_step = abs(step_value)
                if abs_step == 1:
                    dim_size = f"({start_str} - {stop_str})"
                else:
                    dim_size = (
                        f"idiv({start_str} - {stop_str} + {abs_step} - 1, {abs_step})"
                    )
                out_shape.append(dim_size)

                # Stride for negative step: old_stride * step (negative)
                out_strides.append(f"({stride_val} * {step_str})")

                # Offset: start * old_stride (points to first element to access)
                offset_terms.append(f"({start_str} * {stride_val})")
            else:
                # Positive step
                start_str = "0"
                if idx.lower is not None:
                    start_str = wrap_negative(self.visit(idx.lower), shape_val)

                stop_str = str(shape_val)
                if idx.upper is not None:
                    stop_str = wrap_negative(self.visit(idx.upper), shape_val)

                if step_str == "1":
                    # Unit step: plain difference, stride unchanged
                    dim_size = f"({stop_str} - {start_str})"
                    out_strides.append(stride_val)
                else:
                    # Compute new shape: ceil((stop - start) / step)
                    dim_size = (
                        f"idiv({stop_str} - {start_str} + {step_str} - 1, {step_str})"
                    )
                    # Compute new stride: old_stride * step
                    out_strides.append(f"({stride_val} * {step_str})")
                out_shape.append(dim_size)

                # Add offset contribution: start * stride
                if start_str != "0":
                    offset_terms.append(f"({start_str} * {stride_val})")

        # Combine offset terms
        out_offset = " + ".join(offset_terms) if offset_terms else "0"

        # Create new pointer container
        tmp_name = self.builder.find_new_name("_slice_view_")
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type

        # Create output tensor with new shape, strides, and offset.
        # The offset is stored in the Tensor; the reference memlet below
        # is created with a "0" subset.
        if out_shape:
            out_tensor = Tensor(dtype, out_shape, out_strides, out_offset)
            self.tensor_table[tmp_name] = out_tensor
        else:
            # Scalar result (all indices were fixed)
            # NOTE(review): tmp_name was already registered above as a pointer
            # container; this registers the same name again with the element
            # type. Confirm that the builder accepts this before relying on
            # this branch.
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype

        # Create reference memlet from the source array to the new container
        debug_info = DebugInfo()
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, value_str, debug_info)
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return tmp_name
1659

1660
    def _handle_expression_slicing_copy(
        self, node, value_str, indices_nodes, shapes, ndim
    ):
        """Copy-based slicing for cases that cannot use views.

        A temporary container is created for the result (a malloc'ed array
        when at least one index is a slice, otherwise a scalar container) and
        filled element by element: one generated loop per sliced dimension
        and a single assign tasklet in the innermost body.

        Slice bounds and fixed indices that contain indirect array access are
        materialized through ``_materialize_indirect_access``; everything else
        is obtained with ``self.visit``. Expressions that textually start with
        a minus sign are rewritten as ``(extent + expr)``.

        Returns:
            Name of the temporary container holding the sliced data.
        """

        def looks_negative(expr):
            return isinstance(expr, str) and expr.startswith(("-", "(-"))

        def length_expr(start, stop, step):
            # ceil((stop - start) / step) written as idiv(n + d - 1, d)
            if step == "1":
                return f"({stop} - {start})"
            return f"idiv({stop} - {start} + {step} - 1, {step})"

        def resolve_bound(bound, default, extent):
            # Returns the (symbolic, runtime) expression pair of a slice bound.
            if bound is None:
                return default, default
            if self._contains_indirect_access(bound):
                symbolic, runtime = self._materialize_indirect_access(
                    bound, return_original_expr=True
                )
            else:
                symbolic = self.visit(bound)
                runtime = symbolic
            if looks_negative(symbolic):
                symbolic = f"({extent} + {symbolic})"
                runtime = f"({extent} + {runtime})"
            return symbolic, runtime

        # Element type: pointee of the source container, Double by default.
        elem_type = Scalar(PrimitiveType.Double)
        if value_str in self.container_table:
            src_type = self.container_table[value_str]
            if isinstance(src_type, Pointer) and src_type.has_pointee_type():
                elem_type = src_type.pointee_type

        result_shapes = []
        result_shapes_runtime = []
        slice_info = []  # (source dim, start, stop, step) per sliced dimension
        index_info = []  # (source dim, index) per fixed dimension

        for axis, idx in enumerate(indices_nodes):
            extent = (
                shapes[axis] if axis < len(shapes) else f"_{value_str}_shape_{axis}"
            )

            if isinstance(idx, ast.Slice):
                lo, lo_rt = resolve_bound(idx.lower, "0", extent)
                hi, hi_rt = resolve_bound(idx.upper, str(extent), extent)
                stride = "1" if idx.step is None else self.visit(idx.step)

                result_shapes.append(length_expr(lo, hi, stride))
                result_shapes_runtime.append(length_expr(lo_rt, hi_rt, stride))
                slice_info.append((axis, lo, hi, stride))
                continue

            if self._contains_indirect_access(idx):
                fixed = self._materialize_indirect_access(idx)
            else:
                fixed = self.visit(idx)
            if looks_negative(fixed):
                fixed = f"({extent} + {fixed})"
            index_info.append((axis, fixed))

        tmp_name = self.builder.find_new_name("_slice_tmp_")
        result_ndim = len(result_shapes)

        if result_ndim == 0:
            # Every index was fixed: the result is a single element.
            self.builder.add_container(tmp_name, elem_type, False)
            self.container_table[tmp_name] = elem_type
        else:
            count_expr = "1"
            for dim_len in result_shapes:
                count_expr = f"({count_expr} * {dim_len})"
            byte_size = f"({count_expr} * {self.builder.get_sizeof(elem_type)})"

            ptr_type = Pointer(elem_type)
            self.builder.add_container(tmp_name, ptr_type, False)
            self.container_table[tmp_name] = ptr_type
            result_tensor = Tensor(elem_type, result_shapes)
            # Runtime shapes are kept separately from the tensor's shapes.
            self.shapes_runtime_info[tmp_name] = result_shapes_runtime
            self.tensor_table[tmp_name] = result_tensor

            alloc_debug = DebugInfo()
            alloc_block = self.builder.add_block(alloc_debug)
            malloc_node = self.builder.add_malloc(alloc_block, byte_size)
            ptr_node = self.builder.add_access(alloc_block, tmp_name, alloc_debug)
            self.builder.add_memlet(
                alloc_block,
                malloc_node,
                "_ret",
                ptr_node,
                "void",
                "",
                ptr_type,
                alloc_debug,
            )

        loop_vars = []
        debug_info = DebugInfo()

        # Open one zero-based loop per sliced dimension.
        for position, (orig_dim, lo, hi, stride) in enumerate(slice_info):
            iterator = self.builder.find_new_name(f"_slice_loop_{position}_")
            loop_vars.append((iterator, orig_dim, lo, stride))

            if not self.builder.exists(iterator):
                self.builder.add_container(iterator, Scalar(PrimitiveType.Int64), False)
                self.container_table[iterator] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(
                iterator, "0", length_expr(lo, hi, stride), "1", debug_info
            )

        # Source index per original dimension; destination index per loop.
        src_indices = [""] * ndim
        for orig_dim, fixed in index_info:
            src_indices[orig_dim] = fixed

        dst_indices = []
        for iterator, orig_dim, lo, stride in loop_vars:
            if stride == "1":
                src_indices[orig_dim] = f"({lo} + {iterator})"
            else:
                src_indices[orig_dim] = f"({lo} + {iterator} * {stride})"
            dst_indices.append(iterator)

        src_linear = self._compute_linear_index(src_indices, shapes, value_str, ndim)
        if result_ndim > 0:
            dst_linear = self._compute_linear_index(
                dst_indices, result_shapes, tmp_name, result_ndim
            )
        else:
            dst_linear = "0"

        # Innermost body: copy one element from source to temporary.
        copy_block = self.builder.add_block(debug_info)
        src_node = self.builder.add_access(copy_block, value_str, debug_info)
        dst_node = self.builder.add_access(copy_block, tmp_name, debug_info)
        copy_task = self.builder.add_tasklet(
            copy_block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        self.builder.add_memlet(
            copy_block, src_node, "void", copy_task, "_in", src_linear, None, debug_info
        )
        self.builder.add_memlet(
            copy_block, copy_task, "_out", dst_node, "void", dst_linear, None, debug_info
        )

        # Close every loop that was opened above.
        for _ in range(len(loop_vars)):
            self.builder.end_for()

        return tmp_name
1835

1836
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
1837
        """Compute linear index from multi-dimensional indices."""
1838
        if ndim == 0:
4✔
1839
            return "0"
×
1840

1841
        linear_index = ""
4✔
1842
        for i in range(ndim):
4✔
1843
            term = str(indices[i])
4✔
1844
            for j in range(i + 1, ndim):
4✔
1845
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
×
1846
                term = f"(({term}) * {shape_val})"
×
1847

1848
            if i == 0:
4✔
1849
                linear_index = term
4✔
1850
            else:
1851
                linear_index = f"({linear_index} + {term})"
×
1852

1853
        return linear_index
4✔
1854

1855
    def _is_array_index(self, node):
4✔
1856
        """Check if a node represents an array that could be used as an index (gather)."""
1857
        if isinstance(node, ast.Name):
4✔
1858
            return node.id in self.tensor_table
4✔
1859
        return False
4✔
1860

1861
    def _handle_gather(self, value_str, index_node, debug_info=None):
        """Lower ``x[indices]`` where ``indices`` is an index array (gather).

        A result array with one element per entry of the index array is
        allocated, then a loop loads each index into a scalar and copies the
        addressed element of ``value_str`` into the result.

        Args:
            value_str: name of the array being indexed.
            index_node: AST node of the index array.
            debug_info: debug info attached to generated nodes; a fresh
                ``DebugInfo`` is used when omitted.

        Returns:
            Name of the container holding the gathered elements.

        Raises:
            ValueError: if the index does not resolve to a known tensor.
            NotImplementedError: if the index array is not one-dimensional.
        """
        if debug_info is None:
            debug_info = DebugInfo()

        def element_type_of(name, fallback):
            # Pointee type of a registered pointer container, else fallback.
            if name in self.container_table:
                registered = self.container_table[name]
                if isinstance(registered, Pointer) and registered.has_pointee_type():
                    return registered.pointee_type
            return fallback

        idx_array_name = (
            index_node.id
            if isinstance(index_node, ast.Name)
            else self.visit(index_node)
        )

        if idx_array_name not in self.tensor_table:
            raise ValueError(f"Gather index must be an array, got {idx_array_name}")

        idx_shapes = self.tensor_table[idx_array_name].shape
        if len(idx_shapes) != 1:
            raise NotImplementedError("Only 1D index arrays supported for gather")

        result_shape = idx_shapes[0] if idx_shapes else f"_{idx_array_name}_shape_0"

        # Prefer the separately tracked runtime shape of the index array
        # when one has been recorded.
        result_shape_runtime = result_shape
        if idx_array_name in self.shapes_runtime_info:
            runtime_shapes = self.shapes_runtime_info[idx_array_name]
            if runtime_shapes:
                result_shape_runtime = runtime_shapes[0]

        dtype = element_type_of(value_str, Scalar(PrimitiveType.Double))
        idx_dtype = element_type_of(idx_array_name, Scalar(PrimitiveType.Int64))

        tmp_name = self.builder.find_new_name("_gather_")
        total_size = f"({result_shape} * {self.builder.get_sizeof(dtype)})"

        # Register the result array and its shape information.
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type
        self.tensor_table[tmp_name] = Tensor(dtype, [result_shape])
        self.shapes_runtime_info[tmp_name] = [result_shape_runtime]

        # Allocate the result buffer.
        alloc_block = self.builder.add_block(debug_info)
        malloc_node = self.builder.add_malloc(alloc_block, total_size)
        ptr_node = self.builder.add_access(alloc_block, tmp_name, debug_info)
        self.builder.add_memlet(
            alloc_block, malloc_node, "_ret", ptr_node, "void", "", ptr_type, debug_info
        )

        # Loop counter and scalar that holds the current index value.
        loop_var = self.builder.find_new_name("_gather_i_")
        self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
        self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

        idx_var = self.builder.find_new_name("_gather_idx_")
        self.builder.add_container(idx_var, idx_dtype, False)
        self.container_table[idx_var] = idx_dtype

        self.builder.begin_for(loop_var, "0", str(result_shape), "1", debug_info)

        # idx_var = idx_array[loop_var]
        load_block = self.builder.add_block(debug_info)
        idx_array_node = self.builder.add_access(load_block, idx_array_name, debug_info)
        idx_var_node = self.builder.add_access(load_block, idx_var, debug_info)
        load_task = self.builder.add_tasklet(
            load_block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            load_block, idx_array_node, "void", load_task, "_in", loop_var, None, debug_info
        )
        self.builder.add_memlet(
            load_block, load_task, "_out", idx_var_node, "void", "", None, debug_info
        )

        # result[loop_var] = value[idx_var]
        gather_block = self.builder.add_block(debug_info)
        value_node = self.builder.add_access(gather_block, value_str, debug_info)
        result_node = self.builder.add_access(gather_block, tmp_name, debug_info)
        gather_task = self.builder.add_tasklet(
            gather_block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            gather_block, value_node, "void", gather_task, "_in", idx_var, None, debug_info
        )
        self.builder.add_memlet(
            gather_block, gather_task, "_out", result_node, "void", loop_var, None, debug_info
        )

        self.builder.end_for()

        return tmp_name
1991

1992
    def _get_max_array_ndim_in_expr(self, node):
4✔
1993
        """Get the maximum array dimensionality in an expression."""
1994
        max_ndim = 0
4✔
1995

1996
        class NdimVisitor(ast.NodeVisitor):
4✔
1997
            def __init__(self, tensor_table):
4✔
1998
                self.tensor_table = tensor_table
4✔
1999
                self.max_ndim = 0
4✔
2000

2001
            def visit_Name(self, node):
4✔
2002
                if node.id in self.tensor_table:
4✔
2003
                    ndim = len(self.tensor_table[node.id].shape)
4✔
2004
                    self.max_ndim = max(self.max_ndim, ndim)
4✔
2005
                return self.generic_visit(node)
4✔
2006

2007
        visitor = NdimVisitor(self.tensor_table)
4✔
2008
        visitor.visit(node)
4✔
2009
        return visitor.max_ndim
4✔
2010

2011
    def _handle_broadcast_slice_assignment(
        self,
        target,
        materialized_rhs,
        target_name,
        indices,
        target_ndim,
        value_ndim,
        debug_info,
    ):
        """Handle slice assignment with broadcasting (e.g., 2D[:,:] = 1D[:]).

        ``materialized_rhs`` is the name of the already-evaluated RHS array,
        not an AST node. The leading ``target_ndim - value_ndim`` dimensions
        of the target are iterated by outer loops, the RHS dimensions by inner
        loops, and the innermost body assigns
        ``target[outer..., inner...] = rhs[inner...]``.

        ``target`` and ``indices`` are accepted but not read by this method.
        """
        broadcast_dims = target_ndim - value_ndim
        shapes = self.tensor_table[target_name].shape
        rhs_tensor = self.tensor_table.get(materialized_rhs)
        rhs_shapes = rhs_tensor.shape if rhs_tensor else []

        def declare_iterator(hint):
            # Reserve a fresh name and register it as an Int64 loop counter.
            iterator = self.builder.find_new_name(hint)
            return iterator

        def ensure_counter(iterator):
            if not self.builder.exists(iterator):
                self.builder.add_container(iterator, Scalar(PrimitiveType.Int64), False)
                self.container_table[iterator] = Scalar(PrimitiveType.Int64)

        # Outer loops: one per broadcast (leading target) dimension.
        outer_loop_vars = []
        for dim in range(broadcast_dims):
            iterator = declare_iterator(f"_bcast_iter_{dim}_")
            outer_loop_vars.append(iterator)
            ensure_counter(iterator)

            if dim < len(shapes):
                bound = shapes[dim]
            else:
                bound = f"_{target_name}_shape_{dim}"
            self.builder.begin_for(iterator, "0", str(bound), "1", debug_info)

        # Inner loops: one per RHS dimension, bounded by the RHS shape when
        # known, else by the matching trailing target dimension.
        inner_loop_vars = []
        for dim in range(value_ndim):
            iterator = declare_iterator(f"_inner_iter_{dim}_")
            inner_loop_vars.append(iterator)
            ensure_counter(iterator)

            if dim < len(rhs_shapes):
                bound = rhs_shapes[dim]
            else:
                bound = shapes[broadcast_dims + dim]
            self.builder.begin_for(iterator, "0", str(bound), "1", debug_info)

        # Assignment block: target[outer_vars, inner_vars] = rhs[inner_vars]
        assign_block = self.builder.add_block(debug_info)
        rhs_node = self.builder.add_access(assign_block, materialized_rhs, debug_info)
        target_node = self.builder.add_access(assign_block, target_name, debug_info)
        assign_task = self.builder.add_tasklet(
            assign_block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        # The source is indexed by the inner iterators only, the target by
        # outer followed by inner iterators.
        src_index = ",".join(inner_loop_vars) if inner_loop_vars else "0"
        target_iterators = outer_loop_vars + inner_loop_vars
        target_index = ",".join(target_iterators) if target_iterators else "0"

        self.builder.add_memlet(
            assign_block,
            rhs_node,
            "void",
            assign_task,
            "_in",
            src_index,
            rhs_tensor,
            debug_info,
        )

        tensor_dst = self.tensor_table[target_name]
        self.builder.add_memlet(
            assign_block,
            assign_task,
            "_out",
            target_node,
            "void",
            target_index,
            tensor_dst,
            debug_info,
        )

        # Close all loops (inner first, then outer)
        for _ in range(len(inner_loop_vars) + len(outer_loop_vars)):
            self.builder.end_for()
2088

2089
    def _handle_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Emit the loops and tasklet for ``target_name[indices] = value``.

        Each ``ast.Slice`` index is turned into a counted ``for`` loop on the
        builder, every other index is normalised with
        ``normalize_negative_index`` and visited once.  The right-hand side is
        visited before any loop is opened; the loop body is a single assign
        tasklet that copies the evaluated right-hand side into the addressed
        target element.

        Two cases are delegated: a right-hand side containing a ufunc
        ``outer`` call goes to ``_handle_ufunc_outer_slice_assignment``, and a
        right-hand side of lower rank than the number of sliced target axes
        goes to ``_handle_broadcast_slice_assignment``.

        Args:
            target: AST node of the assignment target; a deep copy of it gets
                its ``slice`` attribute replaced by the per-element indices.
            value: AST node of the right-hand side expression.
            target_name: Key of the target array in ``self.tensor_table``.
            indices: Sequence of AST index nodes, one per indexed axis.
                Missing trailing axes are filled with full slices.
            debug_info: Debug info attached to every emitted node; a default
                ``DebugInfo()`` is used when omitted.
        """
        if debug_info is None:
            debug_info = DebugInfo()

        dst_tensor = self.tensor_table[target_name]

        def dim_extent(axis):
            # Recorded extent of the axis when the tensor has one, otherwise
            # the symbolic shape name for that axis.
            if axis < len(dst_tensor.shape):
                return dst_tensor.shape[axis]
            return f"_{target_name}_shape_{axis}"

        # Axes that were not indexed at all are treated as full slices.
        missing = len(dst_tensor.shape) - len(indices)
        if missing > 0:
            indices = list(indices)
            indices.extend(
                ast.Slice(lower=None, upper=None, step=None) for _ in range(missing)
            )

        # ufunc outer keeps its own lowering so slice shape info is preserved.
        has_outer, _ufunc_name, _outer_node = contains_ufunc_outer(value)
        if has_outer:
            self._handle_ufunc_outer_slice_assignment(
                target, value, target_name, indices, debug_info
            )
            return

        sliced_axes = sum(isinstance(idx, ast.Slice) for idx in indices)
        rhs_rank = self._get_max_array_ndim_in_expr(value)

        # The right-hand side is always evaluated first, outside of any loop.
        rhs_name = self.visit(value)

        if sliced_axes > 0 and rhs_rank > 0 and sliced_axes > rhs_rank:
            # More sliced target axes than RHS axes: broadcasting path.
            self._handle_broadcast_slice_assignment(
                target,
                rhs_name,
                target_name,
                indices,
                sliced_axes,
                rhs_rank,
                debug_info,
            )
            return

        loop_vars = []
        element_indices = []

        for axis, idx in enumerate(indices):
            if not isinstance(idx, ast.Slice):
                # Plain index: normalise a negative value, then evaluate it.
                fixed_idx = normalize_negative_index(idx, dim_extent(axis))
                fixed_str = self.visit(fixed_idx)
                element_indices.append(ast.Name(id=fixed_str, ctx=ast.Load()))
                continue

            loop_var = self.builder.find_new_name(f"_slice_iter_{len(loop_vars)}_")
            loop_vars.append(loop_var)
            if not self.builder.exists(loop_var):
                self.builder.add_container(
                    loop_var, Scalar(PrimitiveType.Int64), False
                )
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            # Lower bound: defaults to 0; a leading "-" is taken as an offset
            # from the end of the axis.
            lower = "0"
            if idx.lower:
                lower = self.visit(idx.lower)
                if lower.startswith("-"):
                    extent = str(dim_extent(axis))
                    lower = f"({extent} {lower})"

            # Upper bound: defaults to the axis extent when absent or an
            # explicit ``None`` constant.
            upper_node = idx.upper
            open_ended = not upper_node or (
                isinstance(upper_node, ast.Constant) and upper_node.value is None
            )
            if open_ended:
                upper = str(dim_extent(axis))
            else:
                upper = self.visit(upper_node)
                # NOTE(review): the "(-" form yields "(extent (-k))" with no
                # operator between the two parts — confirm this parses as
                # intended.
                if upper.startswith(("-", "(-")):
                    extent = str(dim_extent(axis))
                    upper = f"({extent} {upper})"

            stride = self.visit(idx.step) if idx.step else "1"

            # NOTE(review): the trip count is (upper - lower) regardless of
            # the step — confirm this is intended for steps other than 1.
            trip_count = f"({upper} - {lower})"
            self.builder.begin_for(loop_var, "0", trip_count, "1", debug_info)
            self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            element_indices.append(
                ast.Name(id=f"{lower} + {loop_var} * {stride}", ctx=ast.Load())
            )

        # The rewritten expression is not used below; the call is kept because
        # the rewriter may have side effects — TODO confirm.
        SliceRewriter(loop_vars, self.tensor_table, self).visit(copy.deepcopy(value))

        # Target subscript addressed by the per-element indices.
        indexed_target = copy.deepcopy(target)
        if len(element_indices) == 1:
            indexed_target.slice = element_indices[0]
        else:
            indexed_target.slice = ast.Tuple(elts=element_indices, ctx=ast.Load())

        # An array RHS whose rank equals the number of loops is read element
        # by element with the loop variables.
        src_type = None
        src_subset_override = ""
        if rhs_name in self.tensor_table:
            rhs_tensor = self.tensor_table[rhs_name]
            rhs_rank_actual = len(rhs_tensor.shape)
            if rhs_rank_actual > 0 and rhs_rank_actual == len(loop_vars):
                src_subset_override = ",".join(loop_vars)
                src_type = rhs_tensor

        block = self.builder.add_block(debug_info)
        tasklet = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        src_access, default_src_subset = self._add_read(block, rhs_name, debug_info)
        self.builder.add_memlet(
            block,
            src_access,
            "void",
            tasklet,
            "_in",
            src_subset_override or default_src_subset,
            src_type,
            debug_info,
        )

        # The visited target comes back as "name(subset)"; only the subset
        # between the first "(" and the closing ")" goes into the memlet.
        lhs_expr = self.visit(indexed_target)
        if "(" in lhs_expr and lhs_expr.endswith(")"):
            dst_subset = lhs_expr[lhs_expr.find("(") + 1 : -1]
            dst_type = self.tensor_table[target_name]
        else:
            dst_subset = ""
            dst_type = None

        dst_access = self.builder.add_access(block, target_name, debug_info)
        self.builder.add_memlet(
            block, tasklet, "_out", dst_access, "void", dst_subset, dst_type, debug_info
        )

        # One end_for per loop opened above.
        for _ in range(len(loop_vars)):
            self.builder.end_for()
2258

2259
    def _handle_ufunc_outer_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Handle slice assignment where RHS contains a ufunc outer operation.

        Example: path[:] = np.minimum(path[:], np.add.outer(path[:, k], path[k, :]))

        The strategy is:
        1. Evaluate the entire RHS expression, which creates a temporary array
           holding the result of the ufunc outer (potentially wrapped in
           other operations).
        2. Copy the temporary result into the target, one element per
           iteration of a loop nest built from the target's slice indices.

        This avoids the loop transformation that would destroy slice shape
        info.

        Args:
            target: AST node of the assignment target.
            value: AST node of the right-hand side expression.
            target_name: Key of the target array in ``self.tensor_table``.
            indices: Sequence of AST index nodes, one per indexed axis.
            debug_info: Debug info attached to every emitted node; a default
                ``DebugInfo()`` is used when omitted.
        """
        if debug_info is None:
            debug_info = DebugInfo()

        # Visiting the RHS emits whatever is needed to compute it and returns
        # the name of the container holding the final result.
        result_name = self.visit(value)

        if not any(isinstance(idx, ast.Slice) for idx in indices):
            # No slices on the target: a single assignment is enough.
            target_str = self.visit(target)
            block = self.builder.add_block(debug_info)
            t_src, src_sub = self._add_read(block, result_name, debug_info)
            t_dst = self.builder.add_access(block, target_str, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )
            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )
            return

        target_shapes = self.tensor_table[target_name].shape

        def dim_extent(axis):
            # Recorded extent of the axis when available, otherwise the
            # symbolic shape name for that axis.
            if axis < len(target_shapes):
                return target_shapes[axis]
            return f"_{target_name}_shape_{axis}"

        loop_vars = []
        # Per-axis index expressions of the target element, as strings.
        target_index_parts = []

        for axis, idx in enumerate(indices):
            if not isinstance(idx, ast.Slice):
                # Plain index: normalise a negative value, then evaluate it.
                normalized_idx = normalize_negative_index(idx, dim_extent(axis))
                target_index_parts.append(self.visit(normalized_idx))
                continue

            loop_var = self.builder.find_new_name(f"_copy_iter_{len(loop_vars)}_")
            loop_vars.append(loop_var)
            if not self.builder.exists(loop_var):
                self.builder.add_container(
                    loop_var, Scalar(PrimitiveType.Int64), False
                )
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            lower = self.visit(idx.lower) if idx.lower else "0"

            upper_node = idx.upper
            open_ended = not upper_node or (
                isinstance(upper_node, ast.Constant) and upper_node.value is None
            )
            upper = dim_extent(axis) if open_ended else self.visit(upper_node)

            stride = self.visit(idx.step) if idx.step else "1"

            # NOTE(review): the trip count is (upper - lower) regardless of
            # the step, and negative bounds are not normalised here —
            # confirm both are intended.
            trip_count = f"({upper} - {lower})"
            self.builder.begin_for(loop_var, "0", trip_count, "1", debug_info)
            self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            target_index_parts.append(f"{lower} + {loop_var} * {stride}")

        # Assignment block: target[i, j, ...] = result[i, j, ...]
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, result_name, debug_info)
        t_dst = self.builder.add_access(block, target_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        # Source index: the result is addressed as a flat row-major array
        # using the loop variables and the target's extents.
        # NOTE(review): this uses the target extents as the result's strides;
        # confirm that holds for partial slices.
        if len(loop_vars) == 2:
            src_index = f"(({loop_vars[0]}) * ({dim_extent(1)}) + ({loop_vars[1]}))"
        elif len(loop_vars) == 1:
            src_index = loop_vars[0]
        else:
            # General case: accumulate strides from the innermost axis out.
            src_terms = []
            stride = "1"
            for axis in reversed(range(len(loop_vars))):
                var = loop_vars[axis]
                src_terms.append(var if stride == "1" else f"({var} * {stride})")
                if axis > 0:
                    extent = dim_extent(axis)
                    stride = (
                        str(extent) if stride == "1" else f"({stride} * {extent})"
                    )
            src_terms.reverse()
            src_index = " + ".join(src_terms) if src_terms else "0"

        # Target index: row-major linear index over all target axes.
        # For a 2D array with shape (M, N): linear_index = i * N + j
        if len(target_index_parts) == 2:
            target_index = (
                f"(({target_index_parts[0]}) * ({dim_extent(1)})"
                f" + ({target_index_parts[1]}))"
            )
        elif len(target_index_parts) == 1:
            target_index = target_index_parts[0]
        else:
            stride = "1"
            target_index = "0"
            for axis in reversed(range(len(target_index_parts))):
                part = target_index_parts[axis]
                term = part if stride == "1" else f"(({part}) * ({stride}))"
                if target_index == "0":
                    target_index = term
                else:
                    target_index = f"({term} + {target_index})"
                if axis > 0:
                    extent = dim_extent(axis)
                    stride = (
                        str(extent) if stride == "1" else f"({stride} * {extent})"
                    )

        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_index, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", target_index, None, debug_info
        )

        # One end_for per loop opened above.
        for _ in range(len(loop_vars)):
            self.builder.end_for()
2469

2470
    def _contains_indirect_access(self, node):
4✔
2471
        """Check if an AST node contains any indirect array access."""
2472
        if isinstance(node, ast.Subscript):
4✔
2473
            if isinstance(node.value, ast.Name):
4✔
2474
                arr_name = node.value.id
4✔
2475
                if arr_name in self.tensor_table:
4✔
2476
                    return True
4✔
2477
        elif isinstance(node, ast.BinOp):
4✔
2478
            return self._contains_indirect_access(
4✔
2479
                node.left
2480
            ) or self._contains_indirect_access(node.right)
2481
        elif isinstance(node, ast.UnaryOp):
4✔
2482
            return self._contains_indirect_access(node.operand)
4✔
2483
        return False
4✔
2484

2485
    def _materialize_indirect_access(
4✔
2486
        self, node, debug_info=None, return_original_expr=False
2487
    ):
2488
        """Materialize an array access into a scalar variable using tasklet+memlets."""
2489
        if not self.builder:
4✔
2490
            expr = self.visit(node)
×
2491
            return (expr, expr) if return_original_expr else expr
×
2492

2493
        if debug_info is None:
4✔
2494
            debug_info = DebugInfo()
4✔
2495

2496
        if not isinstance(node, ast.Subscript):
4✔
2497
            expr = self.visit(node)
×
2498
            return (expr, expr) if return_original_expr else expr
×
2499

2500
        if not isinstance(node.value, ast.Name):
4✔
2501
            expr = self.visit(node)
×
2502
            return (expr, expr) if return_original_expr else expr
×
2503

2504
        arr_name = node.value.id
4✔
2505
        if arr_name not in self.tensor_table:
4✔
2506
            expr = self.visit(node)
×
2507
            return (expr, expr) if return_original_expr else expr
×
2508

2509
        dtype = Scalar(PrimitiveType.Int64)
4✔
2510
        if arr_name in self.container_table:
4✔
2511
            t = self.container_table[arr_name]
4✔
2512
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
2513
                dtype = t.pointee_type
4✔
2514

2515
        tmp_name = self.builder.find_new_name("_idx_")
4✔
2516
        self.builder.add_container(tmp_name, dtype, False)
4✔
2517
        self.container_table[tmp_name] = dtype
4✔
2518

2519
        ndim = len(self.tensor_table[arr_name].shape)
4✔
2520
        shapes = self.tensor_table[arr_name].shape
4✔
2521

2522
        if isinstance(node.slice, ast.Tuple):
4✔
2523
            indices = [self.visit(elt) for elt in node.slice.elts]
×
2524
        else:
2525
            indices = [self.visit(node.slice)]
4✔
2526

2527
        materialized_indices = []
4✔
2528
        for idx_str in indices:
4✔
2529
            if "(" in idx_str and idx_str.endswith(")"):
4✔
2530
                materialized_indices.append(idx_str)
×
2531
            else:
2532
                materialized_indices.append(idx_str)
4✔
2533

2534
        linear_index = self._compute_linear_index(
4✔
2535
            materialized_indices, shapes, arr_name, ndim
2536
        )
2537

2538
        block = self.builder.add_block(debug_info)
4✔
2539
        t_src = self.builder.add_access(block, arr_name, debug_info)
4✔
2540
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
2541
        t_task = self.builder.add_tasklet(
4✔
2542
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
2543
        )
2544

2545
        self.builder.add_memlet(
4✔
2546
            block, t_src, "void", t_task, "_in", linear_index, None, debug_info
2547
        )
2548
        self.builder.add_memlet(
4✔
2549
            block, t_task, "_out", t_dst, "void", "", None, debug_info
2550
        )
2551

2552
        if return_original_expr:
4✔
2553
            original_expr = f"{arr_name}({linear_index})"
4✔
2554
            return (tmp_name, original_expr)
4✔
2555

2556
        return tmp_name
×
2557

2558
    def _get_unique_id(self):
4✔
2559
        self._unique_counter_ref[0] += 1
4✔
2560
        return self._unique_counter_ref[0]
4✔
2561

2562
    def _get_memlet_type_for_access(self, expr_str, subset):
4✔
2563
        """Get the Tensor type for an indexed array access expression.
2564

2565
        When accessing an array like "arr(i,j)" with a multi-dimensional subset,
2566
        we need to pass the Tensor type to add_memlet for correct type inference.
2567
        If the expression is a simple scalar variable or constant, returns None.
2568
        """
2569
        if not subset:
4✔
2570
            return None
4✔
2571

2572
        # Check if expr_str is an indexed array access like "arr(i,j)"
2573
        if "(" in expr_str and expr_str.endswith(")"):
×
2574
            name = expr_str.split("(")[0]
×
2575
            if name in self.tensor_table:
×
2576
                return self.tensor_table[name]
×
2577

2578
        # Check if expr_str is a simple array name with a non-empty subset from _add_read
2579
        if expr_str in self.tensor_table:
×
2580
            return self.tensor_table[expr_str]
×
2581

2582
        return None
×
2583

2584
    def _element_type(self, name):
        """Return the element type for a container name or a constant.

        Names registered in ``self.container_table`` are resolved through
        ``element_type_from_sdfg_type``.  Anything else is treated as a
        constant: ``Int64`` when ``_is_int`` accepts it, ``Double`` otherwise.
        """
        if name in self.container_table:
            return element_type_from_sdfg_type(self.container_table[name])

        # Not a known container, so it is a constant.
        if self._is_int(name):
            return Scalar(PrimitiveType.Int64)
        return Scalar(PrimitiveType.Double)
2592

2593
    def _is_int(self, operand):
4✔
2594
        try:
4✔
2595
            if operand.lstrip("-").isdigit():
4✔
2596
                return True
4✔
2597
        except ValueError:
×
2598
            pass
×
2599

2600
        name = operand
4✔
2601
        if "(" in operand and operand.endswith(")"):
4✔
2602
            name = operand.split("(")[0]
×
2603

2604
        if name in self.container_table:
4✔
2605
            t = self.container_table[name]
4✔
2606

2607
            def is_int_ptype(pt):
4✔
2608
                return pt in [
4✔
2609
                    PrimitiveType.Int64,
2610
                    PrimitiveType.Int32,
2611
                    PrimitiveType.Int8,
2612
                    PrimitiveType.Int16,
2613
                    PrimitiveType.UInt64,
2614
                    PrimitiveType.UInt32,
2615
                    PrimitiveType.UInt8,
2616
                    PrimitiveType.UInt16,
2617
                ]
2618

2619
            if isinstance(t, Scalar):
4✔
2620
                return is_int_ptype(t.primitive_type)
4✔
2621

2622
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
×
2623
                et = t.element_type
×
2624
                if callable(et):
×
2625
                    et = et()
×
2626
                if isinstance(et, Scalar):
×
2627
                    return is_int_ptype(et.primitive_type)
×
2628

2629
            if type(t).__name__ == "Pointer":
×
2630
                if hasattr(t, "pointee_type"):
×
2631
                    et = t.pointee_type
×
2632
                    if callable(et):
×
2633
                        et = et()
×
2634
                    if isinstance(et, Scalar):
×
2635
                        return is_int_ptype(et.primitive_type)
×
2636
                if hasattr(t, "element_type"):
×
2637
                    et = t.element_type
×
2638
                    if callable(et):
×
2639
                        et = et()
×
2640
                    if isinstance(et, Scalar):
×
2641
                        return is_int_ptype(et.primitive_type)
×
2642

2643
        return False
4✔
2644

2645
    def _add_read(self, block, expr_str, debug_info=None):
4✔
2646
        try:
4✔
2647
            if (block, expr_str) in self._access_cache:
4✔
2648
                return self._access_cache[(block, expr_str)]
×
2649
        except TypeError:
×
2650
            pass
×
2651

2652
        if debug_info is None:
4✔
2653
            debug_info = DebugInfo()
4✔
2654

2655
        if "(" in expr_str and expr_str.endswith(")"):
4✔
2656
            name = expr_str.split("(")[0]
×
2657
            subset = expr_str[expr_str.find("(") + 1 : -1]
×
2658
            access = self.builder.add_access(block, name, debug_info)
×
2659
            try:
×
2660
                self._access_cache[(block, expr_str)] = (access, subset)
×
2661
            except TypeError:
×
2662
                pass
×
2663
            return access, subset
×
2664

2665
        if self.builder.exists(expr_str):
4✔
2666
            access = self.builder.add_access(block, expr_str, debug_info)
4✔
2667
            subset = ""
4✔
2668
            if expr_str in self.container_table:
4✔
2669
                sym_type = self.container_table[expr_str]
4✔
2670
                if isinstance(sym_type, Pointer):
4✔
2671
                    if expr_str in self.tensor_table:
4✔
2672
                        ndim = len(self.tensor_table[expr_str].shape)
4✔
2673
                        if ndim == 0:
4✔
2674
                            subset = "0"
×
2675
                    else:
2676
                        subset = "0"
×
2677
            try:
4✔
2678
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
2679
            except TypeError:
×
2680
                pass
×
2681
            return access, subset
4✔
2682

2683
        dtype = Scalar(PrimitiveType.Double)
4✔
2684
        if self._is_int(expr_str):
4✔
2685
            dtype = Scalar(PrimitiveType.Int64)
4✔
2686
        elif expr_str == "true" or expr_str == "false":
4✔
2687
            dtype = Scalar(PrimitiveType.Bool)
×
2688

2689
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
4✔
2690
        try:
4✔
2691
            self._access_cache[(block, expr_str)] = (const_node, "")
4✔
2692
        except TypeError:
×
2693
            pass
×
2694
        return const_node, ""
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc