• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22298043375

23 Feb 2026 08:19AM UTC coverage: 64.63% (-0.1%) from 64.743%
22298043375

push

github

web-flow
Merge pull request #537 from daisytuner/map-optimizations

Adds memory management to frontend

48 of 51 new or added lines in 4 files covered. (94.12%)

62 existing lines in 4 files now uncovered.

23672 of 36627 relevant lines covered (64.63%)

357.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.82
/python/docc/python/ast_parser.py
1
import ast
4✔
2
import copy
4✔
3
import inspect
4✔
4
import textwrap
4✔
5
from docc.sdfg import (
4✔
6
    Scalar,
7
    PrimitiveType,
8
    Pointer,
9
    TaskletCode,
10
    DebugInfo,
11
    Structure,
12
    CMathFunction,
13
    Tensor,
14
)
15
from docc.python.ast_utils import (
4✔
16
    SliceRewriter,
17
    get_debug_info,
18
    contains_ufunc_outer,
19
    normalize_negative_index,
20
)
21
from docc.python.types import (
4✔
22
    sdfg_type_from_type,
23
    element_type_from_sdfg_type,
24
)
25
from docc.python.functions.numpy import NumPyHandler
4✔
26
from docc.python.functions.math import MathHandler
4✔
27
from docc.python.functions.python import PythonHandler
4✔
28
from docc.python.functions.scipy import SciPyHandler
4✔
29
from docc.python.memory import ManagedMemoryHandler
4✔
30

31

32
class ASTParser(ast.NodeVisitor):
4✔
33
    def __init__(
        self,
        builder,
        tensor_table,
        container_table,
        filename="",
        function_name="",
        infer_return_type=False,
        globals_dict=None,
        unique_counter_ref=None,
        structure_member_info=None,
        memory_handler=None,
    ):
        """AST visitor that lowers a Python function body into SDFG constructs.

        Args:
            builder: SDFG builder used to emit containers, blocks, tasklets,
                memlets and control flow (if/else, transitions).
            tensor_table: Maps array names to their Tensor descriptors.
            container_table: Maps scalar/container names to their SDFG types.
            filename: Source file name (debug info).
            function_name: Parsed function's name (debug info).
            infer_return_type: Whether the return type should be inferred.
            globals_dict: Globals of the parsed function; int/float entries
                may be inlined as literal constants (see ``visit_Name``).
            unique_counter_ref: One-element list used as a mutable counter
                shared across inline parsers.
            structure_member_info: Maps structure name ->
                {member name: (member index, member type)}; consumed by
                ``visit_Attribute`` for struct-member reads.
            memory_handler: Shared ``ManagedMemoryHandler`` for hoisted
                allocations; a fresh one is created from ``builder`` when
                not supplied.
        """
        self.builder = builder

        # Lookup tables for variables
        self.tensor_table = tensor_table
        self.container_table = container_table

        # Debug info
        self.filename = filename
        self.function_name = function_name

        # Context
        self.infer_return_type = infer_return_type
        self.globals_dict = globals_dict if globals_dict is not None else {}
        # Default is created per-instance (never a shared mutable default).
        self._unique_counter_ref = (
            unique_counter_ref if unique_counter_ref is not None else [0]
        )
        self._access_cache = {}
        self.structure_member_info = (
            structure_member_info if structure_member_info is not None else {}
        )
        self.captured_return_shapes = {}  # Map param name to shape string list
        self.captured_return_strides = {}  # Map param name to stride string list
        self.shapes_runtime_info = (
            {}
        )  # Map array name to runtime shapes (separate from Tensor)

        # Memory manager for hoisted allocations (shared with inline parsers)
        self.memory_handler = (
            memory_handler
            if memory_handler is not None
            else ManagedMemoryHandler(builder)
        )

        # Initialize handlers - they receive 'self' to access expression visitor methods
        self.numpy_visitor = NumPyHandler(self)
        self.math_handler = MathHandler(self)
        self.python_handler = PythonHandler(self)
        self.scipy_handler = SciPyHandler(self)
84

85
    def visit_Constant(self, node):
        """Render a literal constant as text; booleans become "true"/"false"."""
        value = node.value
        if isinstance(value, bool):
            return "true" if value is True else "false"
        return str(value)
89

90
    def visit_Name(self, node):
4✔
91
        name = node.id
4✔
92
        if name not in self.container_table and self.globals_dict is not None:
4✔
93
            if name in self.globals_dict:
4✔
94
                val = self.globals_dict[name]
4✔
95
                if isinstance(val, (int, float)):
4✔
96
                    return str(val)
4✔
97
        return name
4✔
98

99
    # --- Operator-token visitors -------------------------------------------
    # Each returns the textual operator symbol used when composing the
    # expression strings handed to the SDFG builder.

    def visit_Add(self, node):
        """Binary addition token."""
        return "+"

    def visit_Sub(self, node):
        """Binary subtraction token."""
        return "-"

    def visit_Mult(self, node):
        """Multiplication token."""
        return "*"

    def visit_Div(self, node):
        """True-division token."""
        return "/"

    def visit_FloorDiv(self, node):
        """Floor-division token."""
        return "//"

    def visit_Mod(self, node):
        """Modulo token."""
        return "%"

    def visit_Pow(self, node):
        """Power token."""
        return "**"

    def visit_Eq(self, node):
        """Equality comparison token."""
        return "=="

    def visit_NotEq(self, node):
        """Inequality comparison token."""
        return "!="

    def visit_Lt(self, node):
        """Less-than comparison token."""
        return "<"

    def visit_LtE(self, node):
        """Less-or-equal comparison token."""
        return "<="

    def visit_Gt(self, node):
        """Greater-than comparison token."""
        return ">"

    def visit_GtE(self, node):
        """Greater-or-equal comparison token."""
        return ">="

    def visit_And(self, node):
        """Logical `and`; emitted as `&` (operands are coerced to 0/1 in visit_BoolOp)."""
        return "&"

    def visit_Or(self, node):
        """Logical `or`; emitted as `|` (operands are coerced to 0/1 in visit_BoolOp)."""
        return "|"

    def visit_BitAnd(self, node):
        """Bitwise AND token."""
        return "&"

    def visit_BitOr(self, node):
        """Bitwise OR token."""
        return "|"

    def visit_BitXor(self, node):
        """Bitwise XOR token."""
        return "^"

    def visit_LShift(self, node):
        """Left-shift token."""
        return "<<"

    def visit_RShift(self, node):
        """Right-shift token."""
        return ">>"

    def visit_Not(self, node):
        """Logical negation token (C-style `!`)."""
        return "!"

    def visit_USub(self, node):
        """Unary minus token."""
        return "-"

    def visit_UAdd(self, node):
        """Unary plus token."""
        return "+"

    def visit_Invert(self, node):
        """Bitwise complement token."""
        return "~"
170

171
    def visit_BoolOp(self, node):
        """Lower `and`/`or` into a branch that assigns a boolean temporary.

        Each operand is coerced to a 0/1 test via `(x != 0)` before being
        joined with the operator symbol; the joined condition selects which
        constant ("true"/"false") is assigned to the result container.
        """
        joiner = self.visit(node.op)
        tests = []
        for operand in node.values:
            tests.append(f"({self.visit(operand)} != 0)")
        condition = f" {joiner} ".join(tests)

        result = self.builder.find_new_name()
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        self.builder.begin_if(condition)
        self._add_assign_constant(result, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result, "false", bool_type)
        self.builder.end_if()

        self.container_table[result] = bool_type
        return result
188

189
    def visit_UnaryOp(self, node):
        """Lower a unary operation; returns the name (or literal) of the result.

        Fast path: `-<numeric literal>` is returned directly as a string.
        Array negation is delegated to the NumPy handler; scalar operands are
        lowered into a tasklet writing a fresh temporary container.
        """
        if (
            isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float))
        ):
            # Negative numeric literal: no container needed, emit as text.
            return f"-{node.operand.value}"

        op = self.visit(node.op)
        operand = self.visit(node.operand)

        if operand in self.tensor_table and op == "-":
            return self.numpy_visitor.handle_array_negate(operand)

        assert operand in self.container_table, f"Undefined variable: {operand}"
        tmp_name = self.builder.find_new_name()
        # Result inherits the operand's type.
        dtype = self.container_table[operand]
        self.builder.add_container(tmp_name, dtype, False)
        self.container_table[tmp_name] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)

        if isinstance(node.op, ast.Not):
            # Logical not: XOR with boolean constant `true`.
            t_const = self.builder.add_constant(
                block, "true", Scalar(PrimitiveType.Bool)
            )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        elif op == "-":
            if dtype.primitive_type == PrimitiveType.Int64:
                # Integer negation: 0 - operand.
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_sub, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
            else:
                # Floating-point negation has a dedicated tasklet.
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.fp_neg, ["_in"], ["_out"]
                )
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        elif op == "~":
            # Bitwise complement: XOR with -1 (all bits set).
            t_const = self.builder.add_constant(
                block, "-1", Scalar(PrimitiveType.Int64)
            )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        else:
            # Remaining case (e.g. unary plus): plain copy.
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
256

257
    def visit_BinOp(self, node):
        """Lower a binary operation; returns the name of the result container.

        Dispatch order: `@` goes to the NumPy matmul handler; any array
        operand goes to the NumPy element-wise handler; scalar `**` uses
        the C-math `pow`; scalar `%` is expanded to a Python-style
        (sign-follows-divisor) modulo; everything else maps to a single
        int/fp tasklet.
        """
        if isinstance(node.op, ast.MatMult):
            return self.numpy_visitor.handle_numpy_matmul_op(node.left, node.right)

        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)

        left_is_array = left in self.tensor_table
        right_is_array = right in self.tensor_table

        if left_is_array or right_is_array:
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
            if op in op_map:
                return self.numpy_visitor.handle_array_binary_op(
                    op_map[op], left, right
                )
            else:
                raise NotImplementedError(f"Array operation {op} not supported")

        tmp_name = self.builder.find_new_name()

        # Result type: Int64 only when both operands are int AND the op
        # cannot produce a float ("/" and "**" always yield Double).
        left_is_int = self._is_int(left)
        right_is_int = self._is_int(right)
        dtype = Scalar(PrimitiveType.Double)
        if left_is_int and right_is_int and op not in ["/", "**"]:
            dtype = Scalar(PrimitiveType.Int64)

        if not self.builder.exists(tmp_name):
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype

        # Mixed int/double operands: insert an int->double cast via an
        # assign tasklet into a fresh Double container.
        real_left = left
        real_right = right
        if dtype.primitive_type == PrimitiveType.Double:
            if left_is_int:
                left_cast = self.builder.find_new_name()
                self.builder.add_container(
                    left_cast, Scalar(PrimitiveType.Double), False
                )
                self.container_table[left_cast] = Scalar(PrimitiveType.Double)

                c_block = self.builder.add_block()
                t_src, src_sub = self._add_read(c_block, left)
                t_dst = self.builder.add_access(c_block, left_cast)
                t_task = self.builder.add_tasklet(
                    c_block, TaskletCode.assign, ["_in"], ["_out"]
                )
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")

                real_left = left_cast

            if right_is_int:
                right_cast = self.builder.find_new_name()
                self.builder.add_container(
                    right_cast, Scalar(PrimitiveType.Double), False
                )
                self.container_table[right_cast] = Scalar(PrimitiveType.Double)

                c_block = self.builder.add_block()
                t_src, src_sub = self._add_read(c_block, right)
                t_dst = self.builder.add_access(c_block, right_cast)
                t_task = self.builder.add_tasklet(
                    c_block, TaskletCode.assign, ["_in"], ["_out"]
                )
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")

                real_right = right_cast

        if op == "**":
            # Scalar power goes through the C math library.
            block = self.builder.add_block()
            t_left, left_sub = self._add_read(block, real_left)
            t_right, right_sub = self._add_read(block, real_right)
            t_out = self.builder.add_access(block, tmp_name)

            t_task = self.builder.add_cmath(block, CMathFunction.pow)
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")

            return tmp_name
        elif op == "%":
            # Python-style modulo: ((a rem b) + b) rem b, so the result's
            # sign follows the divisor (C rem follows the dividend).
            block = self.builder.add_block()
            t_left, left_sub = self._add_read(block, real_left)
            t_right, right_sub = self._add_read(block, real_right)
            t_out = self.builder.add_access(block, tmp_name)

            if dtype.primitive_type == PrimitiveType.Int64:
                t_rem1 = self.builder.add_tasklet(
                    block, TaskletCode.int_srem, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
                self.builder.add_memlet(
                    block, t_right, "void", t_rem1, "_in2", right_sub
                )

                rem1_name = self.builder.find_new_name()
                self.builder.add_container(rem1_name, dtype, False)
                t_rem1_out = self.builder.add_access(block, rem1_name)
                self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")

                t_add = self.builder.add_tasklet(
                    block, TaskletCode.int_add, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
                self.builder.add_memlet(
                    block, t_right, "void", t_add, "_in2", right_sub
                )

                add_name = self.builder.find_new_name()
                self.builder.add_container(add_name, dtype, False)
                t_add_out = self.builder.add_access(block, add_name)
                self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")

                t_rem2 = self.builder.add_tasklet(
                    block, TaskletCode.int_srem, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
                self.builder.add_memlet(
                    block, t_right, "void", t_rem2, "_in2", right_sub
                )
                self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")

                return tmp_name
            else:
                # Same ((a rem b) + b) rem b expansion with fp tasklets.
                t_rem1 = self.builder.add_tasklet(
                    block, TaskletCode.fp_rem, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
                self.builder.add_memlet(
                    block, t_right, "void", t_rem1, "_in2", right_sub
                )

                rem1_name = self.builder.find_new_name()
                self.builder.add_container(rem1_name, dtype, False)
                t_rem1_out = self.builder.add_access(block, rem1_name)
                self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")

                t_add = self.builder.add_tasklet(
                    block, TaskletCode.fp_add, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
                self.builder.add_memlet(
                    block, t_right, "void", t_add, "_in2", right_sub
                )

                add_name = self.builder.find_new_name()
                self.builder.add_container(add_name, dtype, False)
                t_add_out = self.builder.add_access(block, add_name)
                self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")

                t_rem2 = self.builder.add_tasklet(
                    block, TaskletCode.fp_rem, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
                self.builder.add_memlet(
                    block, t_right, "void", t_rem2, "_in2", right_sub
                )
                self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")

                return tmp_name

        # Remaining ops map 1:1 onto a single tasklet.
        tasklet_code = None
        if dtype.primitive_type == PrimitiveType.Int64:
            if op == "+":
                tasklet_code = TaskletCode.int_add
            elif op == "-":
                tasklet_code = TaskletCode.int_sub
            elif op == "*":
                tasklet_code = TaskletCode.int_mul
            elif op == "/":
                tasklet_code = TaskletCode.int_sdiv
            elif op == "//":
                tasklet_code = TaskletCode.int_sdiv
            elif op == "&":
                tasklet_code = TaskletCode.int_and
            elif op == "|":
                tasklet_code = TaskletCode.int_or
            elif op == "^":
                tasklet_code = TaskletCode.int_xor
            elif op == "<<":
                tasklet_code = TaskletCode.int_shl
            elif op == ">>":
                tasklet_code = TaskletCode.int_lshr
        else:
            if op == "+":
                tasklet_code = TaskletCode.fp_add
            elif op == "-":
                tasklet_code = TaskletCode.fp_sub
            elif op == "*":
                tasklet_code = TaskletCode.fp_mul
            elif op == "/":
                tasklet_code = TaskletCode.fp_div
            elif op == "//":
                # NOTE(review): fp "//" lowers to fp_div with no flooring step
                # — presumably flooring is handled elsewhere; verify.
                tasklet_code = TaskletCode.fp_div
            else:
                raise NotImplementedError(f"Operation {op} not supported for floats")

        block = self.builder.add_block()
        t_left, left_sub = self._add_read(block, real_left)
        t_right, right_sub = self._add_read(block, real_right)
        t_out = self.builder.add_access(block, tmp_name)

        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        # For indexed array accesses like "arr(i,j)", we need to pass the Tensor type
        # to ensure correct type inference during validation
        left_type = self._get_memlet_type_for_access(real_left, left_sub)
        right_type = self._get_memlet_type_for_access(real_right, right_sub)

        self.builder.add_memlet(
            block, t_left, "void", t_task, "_in1", left_sub, left_type
        )
        self.builder.add_memlet(
            block, t_right, "void", t_task, "_in2", right_sub, right_type
        )
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")

        return tmp_name
480

481
    def visit_Attribute(self, node):
        """Lower an attribute access.

        Handles, in order: `arr.shape` (returns a shape-proxy marker string
        consumed by ``visit_Subscript``), `arr.T` (NumPy transpose handler),
        `math.pi` / `math.e` (materialized as Double constants), and struct
        member reads through a typed pointer (looked up in
        ``structure_member_info``).
        """
        if node.attr == "shape":
            if isinstance(node.value, ast.Name) and node.value.id in self.tensor_table:
                # Deferred: visit_Subscript resolves "_shape_proxy_<name>[i]".
                return f"_shape_proxy_{node.value.id}"

        if node.attr == "T":
            value_name = None
            if isinstance(node.value, ast.Name):
                value_name = node.value.id
            elif isinstance(node.value, ast.Subscript):
                value_name = self.visit(node.value)

            if value_name and value_name in self.tensor_table:
                return self.numpy_visitor.handle_transpose_expr(node)

        if isinstance(node.value, ast.Name) and node.value.id == "math":
            # math.pi / math.e map to the C constants M_PI / M_E.
            val = ""
            if node.attr == "pi":
                val = "M_PI"
            elif node.attr == "e":
                val = "M_E"

            if val:
                tmp_name = self.builder.find_new_name()
                dtype = Scalar(PrimitiveType.Double)
                self.builder.add_container(tmp_name, dtype, False)
                self.container_table[tmp_name] = dtype
                self._add_assign_constant(tmp_name, val, dtype)
                return tmp_name

        if isinstance(node.value, ast.Name):
            obj_name = node.value.id
            attr_name = node.attr

            if obj_name in self.container_table:
                obj_type = self.container_table[obj_name]
                if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
                    pointee_type = obj_type.pointee_type
                    if isinstance(pointee_type, Structure):
                        struct_name = pointee_type.name

                        if (
                            struct_name in self.structure_member_info
                            and attr_name in self.structure_member_info[struct_name]
                        ):
                            member_index, member_type = self.structure_member_info[
                                struct_name
                            ][attr_name]
                        else:
                            raise RuntimeError(
                                f"Member '{attr_name}' not found in structure '{struct_name}'. "
                                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
                            )

                        tmp_name = self.builder.find_new_name()

                        self.builder.add_container(tmp_name, member_type, False)
                        self.container_table[tmp_name] = member_type

                        # Copy the member into a scalar temporary via an
                        # assign tasklet; the subset addresses element 0 of
                        # the pointer and the member's index.
                        block = self.builder.add_block()
                        obj_access = self.builder.add_access(block, obj_name)
                        tmp_access = self.builder.add_access(block, tmp_name)

                        tasklet = self.builder.add_tasklet(
                            block, TaskletCode.assign, ["_in"], ["_out"]
                        )

                        subset = "0," + str(member_index)
                        self.builder.add_memlet(
                            block, obj_access, "", tasklet, "_in", subset
                        )
                        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")

                        return tmp_name

        raise NotImplementedError(f"Attribute access {node.attr} not supported")
557

558
    def visit_Compare(self, node):
        """Lower a single comparison into a boolean temporary.

        Array operands are delegated to the NumPy handler; scalar operands
        branch on the textual condition and record "true"/"false" via
        transitions. Chained comparisons (a < b < c) are rejected.
        """
        lhs = self.visit(node.left)
        if len(node.ops) > 1:
            raise NotImplementedError("Chained comparisons not supported yet")

        cmp_op = self.visit(node.ops[0])
        rhs = self.visit(node.comparators[0])

        lhs_is_tensor = lhs in self.tensor_table
        rhs_is_tensor = rhs in self.tensor_table
        if lhs_is_tensor or rhs_is_tensor:
            return self.numpy_visitor.handle_array_compare(
                lhs, cmp_op, rhs, lhs_is_tensor, rhs_is_tensor
            )

        result = self.builder.find_new_name()
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        self.builder.begin_if(f"{lhs} {cmp_op} {rhs}")
        self.builder.add_transition(result, "true")
        self.builder.begin_else()
        self.builder.add_transition(result, "false")
        self.builder.end_if()

        self.container_table[result] = bool_type
        return result
588

589
    def visit_Subscript(self, node):
        """Lower a subscript expression.

        Handles, in order: shape-proxy lookups (``arr.shape[i]``), tensor
        subscripts (full slices return the array itself, partial slices go
        to the slicing helper, an array index goes to the gather helper,
        plain indices become an element read), and finally a generic
        ``name(index)`` string for non-tensor values.

        Fix over previous revision: the dynamic-shape-index fallback used a
        bare ``except:``, which also swallowed ``KeyboardInterrupt`` /
        ``SystemExit``; it now catches only the conversion errors ``int()``
        can raise.
        """
        value_str = self.visit(node.value)

        if value_str.startswith("_shape_proxy_"):
            # Produced by visit_Attribute for `arr.shape`.
            array_name = value_str[len("_shape_proxy_") :]
            if isinstance(node.slice, ast.Constant):
                idx = node.slice.value
            elif isinstance(node.slice, ast.Index):
                # Legacy AST (< Python 3.9) wraps the index node.
                idx = node.slice.value.value
            else:
                try:
                    idx = int(self.visit(node.slice))
                except (TypeError, ValueError):
                    raise NotImplementedError(
                        "Dynamic shape indexing not fully supported yet"
                    )

            if array_name in self.tensor_table:
                return self.tensor_table[array_name].shape[idx]

            return f"_{array_name}_shape_{idx}"

        if value_str in self.tensor_table:
            tensor = self.tensor_table[value_str]
            ndim = len(tensor.shape)
            shapes = tensor.shape

            if isinstance(node.slice, ast.Tuple):
                indices_nodes = node.slice.elts
            else:
                indices_nodes = [node.slice]

            # `arr[:, :]` (all dims full slices) is just the array itself.
            all_full_slices = True
            for idx in indices_nodes:
                if isinstance(idx, ast.Slice):
                    if idx.lower is not None or idx.upper is not None:
                        all_full_slices = False
                        break
                    # Also check for non-trivial step (step != None and step != 1)
                    if idx.step is not None:
                        # Check if step is a constant 1; if not, it's not a full slice
                        if isinstance(idx.step, ast.Constant) and idx.step.value == 1:
                            pass  # step=1 is equivalent to no step
                        else:
                            all_full_slices = False
                            break
                else:
                    all_full_slices = False
                    break

            if all_full_slices:
                return value_str

            has_slices = any(isinstance(idx, ast.Slice) for idx in indices_nodes)
            if has_slices:
                return self._handle_expression_slicing(
                    node, value_str, indices_nodes, shapes, ndim
                )

            # Fancy indexing with a single array index -> gather.
            if len(indices_nodes) == 1 and self._is_array_index(indices_nodes[0]):
                if self.builder:
                    return self._handle_gather(value_str, indices_nodes[0])

            if isinstance(node.slice, ast.Tuple):
                indices = [self.visit(elt) for elt in node.slice.elts]
            else:
                indices = [self.visit(node.slice)]

            if len(indices) != ndim:
                raise ValueError(
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
                )

            # Rewrite textual negative indices as `shape + idx`.
            normalized_indices = []
            for i, idx_str in enumerate(indices):
                shape_val = shapes[i]
                if isinstance(idx_str, str) and (
                    idx_str.startswith("-") or idx_str.startswith("(-")
                ):
                    normalized_indices.append(f"({shape_val} + {idx_str})")
                else:
                    normalized_indices.append(idx_str)

            subscript_str = ",".join(normalized_indices)
            access_str = f"{value_str}({subscript_str})"

            if isinstance(node.ctx, ast.Load):
                # Reads are materialized into a scalar temporary.
                tmp_name = self.builder.find_new_name()
                self.builder.add_container(tmp_name, tensor.element_type, False)
                self.container_table[tmp_name] = tensor.element_type

                block = self.builder.add_block()
                t_src = self.builder.add_access(block, value_str)
                t_dst = self.builder.add_access(block, tmp_name)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.assign, ["_in"], ["_out"]
                )
                self.builder.add_memlet(
                    block, t_src, "void", t_task, "_in", subscript_str, tensor
                )
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "", tensor.element_type
                )

                return tmp_name

            # Stores return the textual access for the assignment path.
            return access_str

        # Non-tensor value: emit a generic call-style access string.
        slice_val = self.visit(node.slice)
        access_str = f"{value_str}({slice_val})"
        return access_str
700

701
    def visit_AugAssign(self, node):
        """Lower `target op= value` by rewriting it as a plain assignment.

        A tensor-valued name target becomes a full-slice store
        (`target[:, ...] = target op value`) so the element-wise update is
        handled by the slice-assignment machinery; any other target is
        rewritten as `target = target op value` and dispatched to
        ``visit_Assign``.
        """
        combined_rhs = ast.BinOp(left=node.target, op=node.op, right=node.value)

        target_is_tensor = (
            isinstance(node.target, ast.Name) and node.target.id in self.tensor_table
        )
        if not target_is_tensor:
            # Plain scalar/subscript target: target = target op value
            self.visit_Assign(ast.Assign(targets=[node.target], value=combined_rhs))
            return

        # Build a slice expression covering every dimension of the tensor.
        rank = len(self.tensor_table[node.target.id].shape)
        full_dims = [ast.Slice(lower=None, upper=None, step=None) for _ in range(rank)]
        if rank == 1:
            slice_expr = full_dims[0]
        else:
            slice_expr = ast.Tuple(elts=full_dims, ctx=ast.Load())

        store_target = ast.Subscript(
            value=node.target, slice=slice_expr, ctx=ast.Store()
        )
        self.visit_Assign(ast.Assign(targets=[store_target], value=combined_rhs))
731
    def visit_Assign(self, node):
        """Lower a Python assignment statement into SDFG dataflow.

        Handles, in order: chained targets (``a = b = c``), tuple unpacking,
        special-cased numpy/scipy producers (gemm/dot/outer/correlate2d/
        transpose), subscript stores (``a[i] = v`` and slice assignment), and
        finally plain name targets. New containers are created on first
        assignment; pointer-to-pointer assignments become reference memlets,
        scalar assignments become assign tasklets.
        """
        # Handle multiple targets: a = b = c or a, b = expr
        if len(node.targets) > 1:
            tmp_name = self.builder.find_new_name()
            # Assign value to temporary
            val_assign = ast.Assign(
                targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=node.value
            )
            ast.copy_location(val_assign, node)
            self.visit_Assign(val_assign)

            # Assign temporary to targets
            for target in node.targets:
                assign = ast.Assign(
                    targets=[target], value=ast.Name(id=tmp_name, ctx=ast.Load())
                )
                ast.copy_location(assign, node)
                self.visit_Assign(assign)
            return
        target = node.targets[0]

        # Handle tuple unpacking: I, J, K = expr1, expr2, expr3
        if isinstance(target, ast.Tuple):
            if isinstance(node.value, ast.Tuple):
                if len(target.elts) != len(node.value.elts):
                    raise ValueError("Tuple unpacking size mismatch")
                for tgt, val in zip(target.elts, node.value.elts):
                    assign = ast.Assign(targets=[tgt], value=val)
                    ast.copy_location(assign, node)
                    self.visit_Assign(assign)
                return
            else:
                raise NotImplementedError(
                    "Tuple unpacking from non-tuple values not supported"
                )

        # Special cases, where rhs is not just a simple expression but requires
        # special handling by the numpy/scipy visitors.
        if self.numpy_visitor.is_gemm(node.value):
            if self.numpy_visitor.handle_gemm(target, node.value):
                return
            if self.numpy_visitor.handle_dot(target, node.value):
                return
        if self.numpy_visitor.is_outer(node.value):
            if self.numpy_visitor.handle_outer(target, node.value):
                return
        if self.scipy_handler.is_correlate2d(node.value):
            if self.scipy_handler.handle_correlate2d(target, node.value):
                return
        if self.numpy_visitor.is_transpose(node.value):
            if self.numpy_visitor.handle_transpose(target, node.value):
                return

        # Handle subscript assignments: a[i] = val or a[i, j] = val
        if isinstance(target, ast.Subscript):
            debug_info = get_debug_info(node, self.filename, self.function_name)

            target_name = self.visit(target.value)
            if isinstance(target.slice, ast.Tuple):
                indices = target.slice.elts
            else:
                indices = [target.slice]

            # Slice assignments (a[1:3] = ...) go through a dedicated path.
            has_slice = any(isinstance(idx, ast.Slice) for idx in indices)
            if has_slice:
                self._handle_slice_assignment(
                    target, node.value, target_name, indices, debug_info
                )
                return

            # Handle rhs and store in scalar tmp
            rhs_tmp = self.visit(node.value)

            block = self.builder.add_block(debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            t_src, src_sub = self._add_read(block, rhs_tmp, debug_info)
            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )

            # visit(target) yields either "name(subset)" for tensor accesses
            # or a bare name; extract the subset string from the former.
            lhs_expr = self.visit(target)
            if "(" in lhs_expr and lhs_expr.endswith(")"):
                subset = lhs_expr[lhs_expr.find("(") + 1 : -1]
                tensor_dst = self.tensor_table[target_name]

                t_dst = self.builder.add_access(block, target_name, debug_info)
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", subset, tensor_dst, debug_info
                )
            else:
                t_dst = self.builder.add_access(block, target_name, debug_info)
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "", None, debug_info
                )
            return

        # Fallback: lhs is a simple scalar assignment
        if not isinstance(target, ast.Name):
            raise NotImplementedError("Only assignment to variables supported")

        target_name = target.id
        rhs_tmp = self.visit(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if not self.builder.exists(target_name):
            if isinstance(node.value, ast.Constant):
                val = node.value.value
                # NOTE: bool must be tested before int — bool is a subclass of
                # int, so the int branch would otherwise shadow it and type
                # True/False as Int64.
                if isinstance(val, bool):
                    dtype = Scalar(PrimitiveType.Bool)
                elif isinstance(val, int):
                    dtype = Scalar(PrimitiveType.Int64)
                elif isinstance(val, float):
                    dtype = Scalar(PrimitiveType.Double)
                else:
                    raise NotImplementedError(f"Cannot infer type for {val}")

                self.builder.add_container(target_name, dtype, False)
                self.container_table[target_name] = dtype
            else:
                self.builder.add_container(
                    target_name, self.container_table[rhs_tmp], False
                )
                self.container_table[target_name] = self.container_table[rhs_tmp]

        if rhs_tmp in self.tensor_table:
            self.tensor_table[target_name] = self.tensor_table[rhs_tmp]

        # Also copy shapes_runtime_info if available
        if rhs_tmp in self.shapes_runtime_info:
            self.shapes_runtime_info[target_name] = self.shapes_runtime_info[rhs_tmp]

        # Distinguish assignments: scalar -> tasklet, pointer -> reference_memlet
        src_type = self.container_table.get(rhs_tmp)
        dst_type = self.container_table[target_name]
        if src_type and isinstance(src_type, Pointer) and isinstance(dst_type, Pointer):
            block = self.builder.add_block(debug_info)
            t_src = self.builder.add_access(block, rhs_tmp, debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_reference_memlet(
                block, t_src, t_dst, "0", src_type, debug_info
            )
        elif (src_type and isinstance(src_type, Scalar)) or isinstance(
            dst_type, Scalar
        ):
            block = self.builder.add_block(debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            # rhs with no container entry is a literal/expression string
            # and is emitted as a constant node.
            if src_type:
                t_src = self.builder.add_access(block, rhs_tmp, debug_info)
            else:
                t_src = self.builder.add_constant(block, rhs_tmp, dst_type, debug_info)

            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", "", None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )
    def visit_Expr(self, node):
        """Visit a bare expression statement purely for its side effects.

        Any value the expression produces is discarded; only the SDFG nodes
        emitted while visiting it matter.
        """
        self.visit(node.value)
905
    def visit_If(self, node):
        """Lower a Python if/else statement into a structured SDFG branch."""
        predicate = self.visit(node.test)
        dbg = get_debug_info(node, self.filename, self.function_name)

        # The condition is compared against `false`, so any non-false
        # value takes the true branch.
        self.builder.begin_if(f"{predicate} != false", dbg)

        for body_stmt in node.body:
            self.visit(body_stmt)

        if node.orelse:
            self.builder.begin_else(dbg)
            for else_stmt in node.orelse:
                self.visit(else_stmt)

        self.builder.end_if()
    def visit_While(self, node):
        """Lower a while loop as an unconditional loop plus a conditional break.

        The test expression is evaluated inside the loop body so it is
        recomputed on every iteration; when it is false, a break exits
        the loop.
        """
        if node.orelse:
            raise NotImplementedError("while-else is not supported")

        dbg = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_while(dbg)

        # Re-evaluate the condition each iteration and break when it fails.
        predicate = self.visit(node.test)
        self.builder.begin_if(f"{predicate} == false", dbg)
        self.builder.add_break(dbg)
        self.builder.end_if()

        for body_stmt in node.body:
            self.visit(body_stmt)

        self.builder.end_while()
    def visit_Break(self, node):
        """Emit a break edge at the current control-flow position."""
        self.builder.add_break(
            get_debug_info(node, self.filename, self.function_name)
        )
    def visit_Continue(self, node):
        """Emit a continue edge at the current control-flow position."""
        self.builder.add_continue(
            get_debug_info(node, self.filename, self.function_name)
        )
    def visit_For(self, node):
        """Lower a for loop into an SDFG counted loop.

        Supports two iteration forms: ``for i in range(...)`` (with 1–3
        arguments), and ``for x in array`` where ``array`` is a registered
        ndarray — the latter is desugared into a counted loop over the first
        dimension with ``x = array[idx]`` prepended to the body. Anything
        else raises NotImplementedError.
        """
        if node.orelse:
            # Fix: this is a for loop — the old message said "while-else".
            raise NotImplementedError("for-else is not supported")
        if not isinstance(node.target, ast.Name):
            raise NotImplementedError("Only simple for loops supported")

        var = node.target.id
        debug_info = get_debug_info(node, self.filename, self.function_name)

        # Check if iterating over a range() call
        if (
            isinstance(node.iter, ast.Call)
            and isinstance(node.iter.func, ast.Name)
            and node.iter.func.id == "range"
        ):
            args = node.iter.args
            if len(args) == 1:
                start = "0"
                end = self.visit(args[0])
                step = "1"
            elif len(args) == 2:
                start = self.visit(args[0])
                end = self.visit(args[1])
                step = "1"
            elif len(args) == 3:
                start = self.visit(args[0])
                end = self.visit(args[1])

                # Special handling for step to avoid creating tasklets for
                # constant (possibly negated) step expressions.
                step_node = args[2]
                if isinstance(step_node, ast.Constant):
                    step = str(step_node.value)
                elif (
                    isinstance(step_node, ast.UnaryOp)
                    and isinstance(step_node.op, ast.USub)
                    and isinstance(step_node.operand, ast.Constant)
                ):
                    step = f"-{step_node.operand.value}"
                else:
                    step = self.visit(step_node)
            else:
                raise ValueError("Invalid range arguments")

            if not self.builder.exists(var):
                self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
                self.container_table[var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(var, start, end, step, debug_info)

            for stmt in node.body:
                self.visit(stmt)

            self.builder.end_for()
            return

        # Check if iterating over an ndarray (for x in array)
        if isinstance(node.iter, ast.Name):
            iter_name = node.iter.id
            if iter_name in self.tensor_table:
                arr_info = self.tensor_table[iter_name]
                if len(arr_info.shape) == 0:
                    raise NotImplementedError("Cannot iterate over 0-dimensional array")

                # Get the size of the first dimension
                arr_size = arr_info.shape[0]

                # Create a hidden index variable for the loop
                idx_var = self.builder.find_new_name()
                if not self.builder.exists(idx_var):
                    self.builder.add_container(
                        idx_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.container_table[idx_var] = Scalar(PrimitiveType.Int64)

                # Determine the type of the loop variable (element type).
                # For a 1D array it's a scalar; for an ND array it's a view
                # of N-1 dimensions.
                if len(arr_info.shape) == 1:
                    # Element is a scalar - get the element type from the array's type
                    arr_type = self.container_table.get(iter_name)
                    if isinstance(arr_type, Pointer):
                        elem_type = arr_type.pointee_type
                    else:
                        elem_type = Scalar(PrimitiveType.Double)  # Default fallback

                    if not self.builder.exists(var):
                        self.builder.add_container(var, elem_type, False)
                        self.container_table[var] = elem_type
                else:
                    # For multi-dimensional arrays, create a view/slice:
                    # the loop variable becomes a pointer to the sub-array.
                    inner_shapes = arr_info.shape[1:]
                    inner_ndim = len(arr_info.shape) - 1

                    arr_type = self.container_table.get(iter_name)
                    if isinstance(arr_type, Pointer):
                        elem_type = arr_type  # Keep as pointer type for views
                    else:
                        elem_type = Pointer(Scalar(PrimitiveType.Double))

                    if not self.builder.exists(var):
                        self.builder.add_container(var, elem_type, False)
                        self.container_table[var] = elem_type

                    # Register the view in tensor_table
                    self.tensor_table[var] = Tensor(
                        element_type_from_sdfg_type(elem_type), inner_shapes
                    )

                # Begin the for loop
                self.builder.begin_for(idx_var, "0", str(arr_size), "1", debug_info)

                # Generate the assignment: var = array[idx_var]
                # Create an AST node for the assignment and visit it
                assign_node = ast.Assign(
                    targets=[ast.Name(id=var, ctx=ast.Store())],
                    value=ast.Subscript(
                        value=ast.Name(id=iter_name, ctx=ast.Load()),
                        slice=ast.Name(id=idx_var, ctx=ast.Load()),
                        ctx=ast.Load(),
                    ),
                )
                ast.copy_location(assign_node, node)
                self.visit_Assign(assign_node)

                # Visit the loop body
                for stmt in node.body:
                    self.visit(stmt)

                self.builder.end_for()
                return

        raise NotImplementedError(
            f"Only range() loops and iteration over ndarrays supported, got: {ast.dump(node.iter)}"
        )
    def visit_Return(self, node):
        """Lower a return statement by writing values into ``_docc_ret_*`` outputs.

        Scalars are copied into the corresponding return container via an
        assign tasklet; arrays are copied via a synthesized full-slice
        assignment (return arrays are always contiguous, so fresh C-order
        strides are computed). Before the control-flow return, frees for all
        deferred allocations are emitted via the memory handler. On the first
        return, when ``infer_return_type`` is set, the return containers are
        created from the inferred value types.
        """
        if node.value is None:
            debug_info = get_debug_info(node, self.filename, self.function_name)
            # Emit frees for all deferred allocations before returning
            if self.memory_handler.has_allocations():
                self.memory_handler.emit_frees()
            self.builder.add_return("", debug_info)
            return

        if isinstance(node.value, ast.Tuple):
            values = node.value.elts
        else:
            values = [node.value]

        parsed_values = [self.visit(v) for v in values]
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if self.infer_return_type:
            for i, res in enumerate(parsed_values):
                ret_name = f"_docc_ret_{i}"
                if not self.builder.exists(ret_name):
                    dtype = Scalar(PrimitiveType.Double)
                    if res in self.container_table:
                        dtype = self.container_table[res]
                    elif isinstance(values[i], ast.Constant):
                        val = values[i].value
                        # NOTE: bool must be tested before int — bool is a
                        # subclass of int, so the int branch would otherwise
                        # shadow it and type True/False as Int64.
                        if isinstance(val, bool):
                            dtype = Scalar(PrimitiveType.Bool)
                        elif isinstance(val, int):
                            dtype = Scalar(PrimitiveType.Int64)
                        elif isinstance(val, float):
                            dtype = Scalar(PrimitiveType.Double)

                    # Wrap Scalar in Pointer. Keep Arrays/Pointers as is.
                    arg_type = dtype
                    if isinstance(dtype, Scalar):
                        arg_type = Pointer(dtype)

                    self.builder.add_container(ret_name, arg_type, is_argument=True)
                    self.container_table[ret_name] = arg_type

                    if res in self.tensor_table:
                        self.tensor_table[ret_name] = self.tensor_table[res]

            self.infer_return_type = False

        for i, res in enumerate(parsed_values):
            ret_name = f"_docc_ret_{i}"

            is_array_return = False
            if res in self.tensor_table:
                # Only treat as array return if it has dimensions;
                # 0-d arrays (scalars) are handled by scalar assignment.
                if len(self.tensor_table[res].shape) > 0:
                    is_array_return = True
            elif res in self.container_table:
                if isinstance(self.container_table[res], Pointer):
                    is_array_return = True

            # Simple Scalar Assignment
            if not is_array_return:
                block = self.builder.add_block(debug_info)
                t_dst = self.builder.add_access(block, ret_name, debug_info)

                t_src, src_sub = self._add_read(block, res, debug_info)

                t_task = self.builder.add_tasklet(
                    block, TaskletCode.assign, ["_in"], ["_out"], debug_info
                )
                self.builder.add_memlet(
                    block, t_src, "void", t_task, "_in", src_sub, None, debug_info
                )
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "0", None, debug_info
                )

            # Array Assignment (Copy)
            else:
                # Record shape for metadata
                if res in self.tensor_table:
                    # Prefer runtime shapes if available (for indirect access
                    # patterns), fall back to regular shapes otherwise.
                    res_info = self.tensor_table[res]
                    if res in self.shapes_runtime_info:
                        shape = self.shapes_runtime_info[res]
                    else:
                        shape = res_info.shape
                    # Convert to string expressions
                    self.captured_return_shapes[ret_name] = [str(s) for s in shape]

                    # Return arrays are always contiguous - compute fresh strides
                    contiguous_strides = self.numpy_visitor._compute_strides(shape, "C")
                    self.captured_return_strides[ret_name] = [
                        str(s) for s in contiguous_strides
                    ]

                    # Always overwrite tensor_table for return arrays with
                    # contiguous strides (source tensor may have non-standard
                    # strides from views/flip).
                    self.tensor_table[ret_name] = Tensor(
                        res_info.element_type, shape, contiguous_strides
                    )

                # Copy Logic using visit_Assign
                ndim = 1
                if ret_name in self.tensor_table:
                    ndim = len(self.tensor_table[ret_name].shape)

                slice_node = ast.Slice(lower=None, upper=None, step=None)
                if ndim > 1:
                    target_slice = ast.Tuple(elts=[slice_node] * ndim, ctx=ast.Load())
                else:
                    target_slice = slice_node

                target_sub = ast.Subscript(
                    value=ast.Name(id=ret_name, ctx=ast.Load()),
                    slice=target_slice,
                    ctx=ast.Store(),
                )

                # Value node reconstruction
                if isinstance(values[i], ast.Name):
                    val_node = values[i]
                else:
                    val_node = ast.Name(id=res, ctx=ast.Load())

                assign_node = ast.Assign(targets=[target_sub], value=val_node)
                self.visit_Assign(assign_node)

        # Emit frees for all deferred allocations before returning
        if self.memory_handler.has_allocations():
            self.memory_handler.emit_frees()

        # Add control flow return to exit the function/path
        self.builder.add_return("", debug_info)
    def visit_Call(self, node):
        """Dispatch a call expression to the matching handler.

        Resolution order (order matters — first match wins):
        1. ``math.*`` / ``numpy.*`` / ``np.*`` attribute calls → module handlers.
        2. ``arr.astype(...)`` / ``arr.copy()`` on known tensors → numpy handlers.
        3. ``scipy.<submodule>.<func>`` → scipy handler.
        4. ``np.<ufunc>.outer(...)`` → ufunc-outer handler.
        5. Bare-name calls → python builtin handler, then user functions found
           in ``globals_dict`` (inlined via ``_handle_inline_call``).

        Returns whatever the chosen handler returns (typically the name of the
        container holding the result). Raises NotImplementedError when no
        handler matches.
        """
        func_name = ""
        module_name = ""
        submodule_name = ""
        # Attribute call: module.func, arr.method, or pkg.sub.func
        if isinstance(node.func, ast.Attribute):
            if isinstance(node.func.value, ast.Name):
                if node.func.value.id == "math":
                    module_name = "math"
                    func_name = node.func.attr
                elif node.func.value.id in ["numpy", "np"]:
                    module_name = "numpy"
                    func_name = node.func.attr
                else:
                    # Method call on a name; only tensor astype/copy supported.
                    array_name = node.func.value.id
                    method_name = node.func.attr
                    if array_name in self.tensor_table and method_name == "astype":
                        return self.numpy_visitor.handle_numpy_astype(node, array_name)
                    elif array_name in self.tensor_table and method_name == "copy":
                        return self.numpy_visitor.handle_numpy_copy(node, array_name)
            elif isinstance(node.func.value, ast.Attribute):
                # Two-level attribute: scipy.sub.func or np.ufunc.outer
                if (
                    isinstance(node.func.value.value, ast.Name)
                    and node.func.value.value.id == "scipy"
                ):
                    module_name = "scipy"
                    submodule_name = node.func.value.attr
                    func_name = node.func.attr
                elif (
                    isinstance(node.func.value.value, ast.Name)
                    and node.func.value.value.id in ["numpy", "np"]
                    and node.func.attr == "outer"
                ):
                    ufunc_name = node.func.value.attr
                    return self.numpy_visitor.handle_ufunc_outer(node, ufunc_name)

        elif isinstance(node.func, ast.Name):
            func_name = node.func.id

        if module_name == "numpy":
            if self.numpy_visitor.has_handler(func_name):
                return self.numpy_visitor.handle_numpy_call(node, func_name)

        if module_name == "math":
            if self.math_handler.has_handler(func_name):
                return self.math_handler.handle_math_call(node, func_name)

        if module_name == "scipy":
            if self.scipy_handler.has_handler(submodule_name, func_name):
                return self.scipy_handler.handle_scipy_call(
                    node, submodule_name, func_name
                )

        if self.python_handler.has_handler(func_name):
            return self.python_handler.handle_python_call(node, func_name)

        # User-defined function visible in the caller's globals: inline it.
        if func_name in self.globals_dict:
            obj = self.globals_dict[func_name]
            if inspect.isfunction(obj):
                return self._handle_inline_call(node, obj)

        raise NotImplementedError(f"Function call {func_name} not supported")
    def _handle_inline_call(self, node, func_obj):
        """Inline a user-defined Python function at the call site.

        The function's source is re-parsed, all local names are suffixed with
        a unique tag to avoid collisions, returns are rewritten into
        assignments to a fresh result container, parameters and numeric
        closure constants are bound via injected assignments, and the
        rewritten body is lowered by a child ASTParser sharing this parser's
        builder and tables.

        Returns the name of the container holding the inlined call's result.
        Raises NotImplementedError when the source cannot be obtained/parsed
        or the argument count mismatches.
        """
        try:
            source_lines, start_line = inspect.getsourcelines(func_obj)
            source = textwrap.dedent("".join(source_lines))
            tree = ast.parse(source)
            func_def = tree.body[0]
        except Exception as e:
            raise NotImplementedError(
                f"Could not parse function {func_obj.__name__}: {e}"
            )

        # Evaluate call arguments in the caller's context first.
        arg_vars = [self.visit(arg) for arg in node.args]

        if len(arg_vars) != len(func_def.args.args):
            raise NotImplementedError(
                f"Argument count mismatch for {func_obj.__name__}"
            )

        # Unique suffix keeps the inlined function's locals from colliding
        # with the caller's names (and with other inlined calls).
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
        res_name = f"_res{suffix}"

        # Combine globals with closure variables of the inlined function
        combined_globals = dict(self.globals_dict)
        closure_constants = {}  # name -> value for numeric closure vars
        if func_obj.__closure__ is not None and func_obj.__code__.co_freevars:
            for name, cell in zip(func_obj.__code__.co_freevars, func_obj.__closure__):
                val = cell.cell_contents
                combined_globals[name] = val
                # Track numeric constants for injection
                if isinstance(val, (int, float)) and not isinstance(val, bool):
                    closure_constants[name] = val

        # Rewrites the inlined body: suffixes local names, leaves globals and
        # builtins untouched, and turns `return expr` into `res_name = expr`.
        # Note: visit_Return closes over `res_name` from the enclosing scope.
        class VariableRenamer(ast.NodeTransformer):
            BUILTINS = {
                "range",
                "len",
                "int",
                "float",
                "bool",
                "str",
                "list",
                "dict",
                "tuple",
                "set",
                "print",
                "abs",
                "min",
                "max",
                "sum",
                "enumerate",
                "zip",
                "map",
                "filter",
                "sorted",
                "reversed",
                "True",
                "False",
                "None",
            }

            def __init__(self, suffix, globals_dict):
                self.suffix = suffix
                self.globals_dict = globals_dict

            def visit_Name(self, node):
                if node.id in self.globals_dict or node.id in self.BUILTINS:
                    return node
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)

            def visit_Return(self, node):
                if node.value:
                    val = self.visit(node.value)
                    return ast.Assign(
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
                        value=val,
                    )
                return node

        renamer = VariableRenamer(suffix, combined_globals)
        new_body = [renamer.visit(stmt) for stmt in func_def.body]

        param_assignments = []

        # Inject closure constants as assignments
        for name, val in closure_constants.items():
            if isinstance(val, int):
                self.container_table[name] = Scalar(PrimitiveType.Int64)
                self.builder.add_container(name, Scalar(PrimitiveType.Int64), False)
                val_node = ast.Constant(value=val)
            else:
                self.container_table[name] = Scalar(PrimitiveType.Double)
                self.builder.add_container(name, Scalar(PrimitiveType.Double), False)
                val_node = ast.Constant(value=val)
            assign = ast.Assign(
                targets=[ast.Name(id=name, ctx=ast.Store())], value=val_node
            )
            param_assignments.append(assign)

        # Bind each parameter to its evaluated argument. `arg_val` is either a
        # known container name or a literal string (int/float fallback below).
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
            param_name = f"{arg_def.arg}{suffix}"

            if arg_val in self.container_table:
                self.container_table[param_name] = self.container_table[arg_val]
                self.builder.add_container(
                    param_name, self.container_table[arg_val], False
                )
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
            elif self._is_int(arg_val):
                self.container_table[param_name] = Scalar(PrimitiveType.Int64)
                self.builder.add_container(
                    param_name, Scalar(PrimitiveType.Int64), False
                )
                val_node = ast.Constant(value=int(arg_val))
            else:
                try:
                    val = float(arg_val)
                    self.container_table[param_name] = Scalar(PrimitiveType.Double)
                    self.builder.add_container(
                        param_name, Scalar(PrimitiveType.Double), False
                    )
                    val_node = ast.Constant(value=val)
                except ValueError:
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())

            assign = ast.Assign(
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
            )
            param_assignments.append(assign)

        final_body = param_assignments + new_body

        # Create a new parser instance for the inlined function
        # Share memory_handler so hoisted allocations go to main function entry
        parser = ASTParser(
            self.builder,
            self.tensor_table,
            self.container_table,
            globals_dict=combined_globals,
            unique_counter_ref=self._unique_counter_ref,
            memory_handler=self.memory_handler,
        )

        for stmt in final_body:
            parser.visit(stmt)

        return res_name
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit a dataflow block that assigns the literal ``value_str`` (of ``dtype``)
        to the container ``target_name`` through an assign-tasklet."""
        blk = self.builder.add_block()
        const_node = self.builder.add_constant(blk, value_str, dtype)
        dst_node = self.builder.add_access(blk, target_name)
        assign_task = self.builder.add_tasklet(
            blk, TaskletCode.assign, ["_in"], ["_out"]
        )
        # Wire constant -> tasklet -> destination container.
        self.builder.add_memlet(blk, const_node, "void", assign_task, "_in", "")
        self.builder.add_memlet(blk, assign_task, "_out", dst_node, "void", "")
    def _handle_expression_slicing(self, node, value_str, indices_nodes, shapes, ndim):
4✔
1437
        """Handle slicing in expressions (e.g., arr[1:, :, k+1]).
1438

1439
        Uses a zero-copy view when possible (positive step, no indirect access).
1440
        Falls back to copy-based approach for complex cases.
1441
        """
1442
        if not self.builder:
4✔
1443
            raise ValueError("Builder required for expression slicing")
×
1444

1445
        # Try view-based approach first (zero-copy)
1446
        if self._can_use_slice_view(indices_nodes):
4✔
1447
            return self._create_slice_view(value_str, indices_nodes, shapes, ndim)
4✔
1448

1449
        # Fall back to copy-based approach for complex cases
1450
        return self._handle_expression_slicing_copy(
×
1451
            node, value_str, indices_nodes, shapes, ndim
1452
        )
1453

1454
    def _can_use_slice_view(self, indices_nodes):
4✔
1455
        """Check if slicing can be expressed as a zero-copy view.
1456

1457
        Views can be used when:
1458
        - All steps are non-zero constants (positive or negative)
1459
        - No indirect array access in slice parameters
1460

1461
        Returns True if a view can be used, False if a copy is required.
1462
        """
1463
        for idx in indices_nodes:
4✔
1464
            if isinstance(idx, ast.Slice):
4✔
1465
                # Check for zero step (invalid)
1466
                if idx.step is not None:
4✔
1467
                    if isinstance(idx.step, ast.Constant):
4✔
1468
                        if idx.step.value == 0:
4✔
1469
                            return False  # Zero step is invalid
×
1470
                    elif isinstance(idx.step, ast.UnaryOp) and isinstance(
4✔
1471
                        idx.step.op, ast.USub
1472
                    ):
1473
                        # Negative step like -2 is OK
1474
                        pass
4✔
1475
                    elif self._contains_indirect_access(idx.step):
×
1476
                        return False  # Dynamic step requires copy
×
1477

1478
                # Check for indirect access in slice bounds
1479
                if idx.lower is not None and self._contains_indirect_access(idx.lower):
4✔
1480
                    return False
×
1481
                if idx.upper is not None and self._contains_indirect_access(idx.upper):
4✔
1482
                    return False
×
1483
            else:
1484
                # Fixed index: check for indirect access
1485
                if self._contains_indirect_access(idx):
4✔
1486
                    return False
×
1487
        return True
4✔
1488

1489
    def _create_slice_view(self, value_str, indices_nodes, shapes, ndim):
        """Create a zero-copy view for array slicing.

        This creates a new tensor that shares data with the source but has
        adjusted shape, strides, and offset to represent the sliced region.

        For positive step A[start:stop:step, ...] on dimension i:
        - new_shape[i] = ceil((stop - start) / step)
        - new_stride[i] = old_stride[i] * step
        - offset contribution = start * old_stride[i]

        For negative step A[start:stop:step, ...] (e.g., ::-1):
        - Default start = shape - 1 (last element)
        - Default stop = -1 (before first element)
        - new_shape[i] = ceil((start - stop) / abs(step))
        - new_stride[i] = old_stride[i] * step (negative)
        - offset contribution = start * old_stride[i] (points to last element)

        For a fixed index A[k, ...] on dimension i (dimension reduction):
        - offset contribution = k * old_stride[i]
        - dimension is removed from output

        Returns the name of the freshly created view container; its Tensor
        entry (shape/strides/offset as symbolic strings) is registered in
        ``self.tensor_table`` when the result is not a scalar.
        """
        in_tensor = self.tensor_table[value_str]
        in_shape = in_tensor.shape
        dtype = in_tensor.element_type

        # Get input strides (compute if not available)
        in_strides = (
            in_tensor.strides
            if hasattr(in_tensor, "strides") and in_tensor.strides
            else None
        )
        if in_strides is None:
            # Fall back to contiguous C-order strides derived from the shape.
            in_strides = self.numpy_visitor._compute_strides(in_shape, "C")

        # Get base offset from input tensor (views of views accumulate offsets)
        in_offset = getattr(in_tensor, "offset", "0") or "0"

        # Build output shape, strides, and compute offset
        out_shape = []
        out_strides = []
        offset_terms = []
        if in_offset != "0":
            offset_terms.append(str(in_offset))

        for i, idx in enumerate(indices_nodes):
            # Symbolic fallbacks when shape/stride metadata is shorter than ndim.
            shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
            stride_val = in_strides[i] if i < len(in_strides) else "1"

            if isinstance(idx, ast.Slice):
                # Determine step value and sign
                step_str = "1"
                step_is_negative = False
                step_value = 1

                if idx.step is not None:
                    if isinstance(idx.step, ast.Constant):
                        step_value = idx.step.value
                        step_str = str(step_value)
                        step_is_negative = step_value < 0
                    elif isinstance(idx.step, ast.UnaryOp) and isinstance(
                        idx.step.op, ast.USub
                    ):
                        # Handle -N syntax
                        if isinstance(idx.step.operand, ast.Constant):
                            step_value = -idx.step.operand.value
                            step_str = str(step_value)
                            step_is_negative = True
                        else:
                            # NOTE(review): a non-constant negated step leaves
                            # step_is_negative=False, so it is lowered by the
                            # positive-step branch below — confirm such patterns
                            # are rejected before reaching this method.
                            step_str = self.visit(idx.step)
                    else:
                        step_str = self.visit(idx.step)

                if step_is_negative:
                    # Negative step: iterate from end to start
                    # Default start = shape - 1, default stop = -1 (before 0)
                    if idx.lower is not None:
                        start_str = self.visit(idx.lower)
                        if isinstance(start_str, str) and (
                            start_str.startswith("-") or start_str.startswith("(-")
                        ):
                            # Normalize a negative bound relative to the dim size.
                            start_str = f"({shape_val} + {start_str})"
                    else:
                        start_str = f"({shape_val} - 1)"

                    if idx.upper is not None:
                        stop_str = self.visit(idx.upper)
                        if isinstance(stop_str, str) and (
                            stop_str.startswith("-") or stop_str.startswith("(-")
                        ):
                            stop_str = f"({shape_val} + {stop_str})"
                    else:
                        stop_str = "-1"

                    # Shape for negative step: ceil((start - stop) / abs(step))
                    abs_step = abs(step_value)
                    if abs_step == 1:
                        dim_size = f"({start_str} - {stop_str})"
                    else:
                        # NOTE(review): this uses '/' while the positive-step path
                        # uses idiv() — confirm both lower to integer division.
                        dim_size = f"(({start_str} - {stop_str} + {abs_step} - 1) / {abs_step})"
                    out_shape.append(dim_size)

                    # Stride for negative step: old_stride * step (negative)
                    out_strides.append(f"({stride_val} * {step_str})")

                    # Offset: start * old_stride (points to first element to access)
                    offset_terms.append(f"({start_str} * {stride_val})")
                else:
                    # Positive step (original logic)
                    start_str = "0"
                    if idx.lower is not None:
                        start_str = self.visit(idx.lower)
                        if isinstance(start_str, str) and (
                            start_str.startswith("-") or start_str.startswith("(-")
                        ):
                            start_str = f"({shape_val} + {start_str})"

                    stop_str = str(shape_val)
                    if idx.upper is not None:
                        stop_str = self.visit(idx.upper)
                        if isinstance(stop_str, str) and (
                            stop_str.startswith("-") or stop_str.startswith("(-")
                        ):
                            stop_str = f"({shape_val} + {stop_str})"

                    # Compute new shape: ceil((stop - start) / step)
                    if step_str == "1":
                        dim_size = f"({stop_str} - {start_str})"
                    else:
                        dim_size = f"idiv({stop_str} - {start_str} + {step_str} - 1, {step_str})"
                    out_shape.append(dim_size)

                    # Compute new stride: old_stride * step
                    if step_str == "1":
                        out_strides.append(stride_val)
                    else:
                        out_strides.append(f"({stride_val} * {step_str})")

                    # Add offset contribution: start * stride
                    if start_str != "0":
                        offset_terms.append(f"({start_str} * {stride_val})")
            else:
                # Fixed index: dimension is removed, just add offset
                index_str = self.visit(idx)
                if isinstance(index_str, str) and (
                    index_str.startswith("-") or index_str.startswith("(-")
                ):
                    index_str = f"({shape_val} + {index_str})"
                offset_terms.append(f"({index_str} * {stride_val})")

        # Combine offset terms into a single symbolic expression
        if not offset_terms:
            out_offset = "0"
        elif len(offset_terms) == 1:
            out_offset = offset_terms[0]
        else:
            out_offset = " + ".join(offset_terms)

        # Create new pointer container aliasing the source data
        tmp_name = self.builder.find_new_name("_slice_view_")
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type

        # Create output tensor with new shape, strides, and offset
        # Offset is stored in the Tensor (like Tensor.flip() does)
        # Reference memlet just creates the pointer alias with "0" offset
        if out_shape:
            out_tensor = Tensor(dtype, out_shape, out_strides, out_offset)
            self.tensor_table[tmp_name] = out_tensor
        else:
            # Scalar result (all indices were fixed)
            # NOTE(review): tmp_name was already registered as a Pointer above;
            # this re-adds it with the scalar dtype — confirm add_container
            # overwrites rather than erroring on a duplicate name.
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype

        # Create reference memlet (offset is handled by tensor's offset property)
        debug_info = DebugInfo()
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, value_str, debug_info)
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return tmp_name
    def _handle_expression_slicing_copy(
        self, node, value_str, indices_nodes, shapes, ndim
    ):
        """Copy-based slicing for cases that cannot use views.

        This allocates a new array and copies elements using nested loops.
        Used for negative steps or indirect access patterns.

        Two symbolic shape expressions are tracked per sliced dimension:
        ``result_shapes`` (may reference materialized indirect-access temps)
        and ``result_shapes_runtime`` (the original expressions, stored in
        ``self.shapes_runtime_info`` for later runtime evaluation).

        Returns the name of the newly allocated container holding the copy.
        """
        # Default element type; refined from the source container when known.
        dtype = Scalar(PrimitiveType.Double)
        if value_str in self.container_table:
            t = self.container_table[value_str]
            if isinstance(t, Pointer) and t.has_pointee_type():
                dtype = t.pointee_type

        result_shapes = []
        result_shapes_runtime = []
        slice_info = []  # (dim, start, stop, step) for each ranged dimension
        index_info = []  # (dim, index expr) for each fixed (reduced) dimension

        for i, idx in enumerate(indices_nodes):
            shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"

            if isinstance(idx, ast.Slice):
                start_str = "0"
                start_str_runtime = "0"
                if idx.lower is not None:
                    if self._contains_indirect_access(idx.lower):
                        # Materialize a[b[i]]-style bounds into a temp scalar,
                        # keeping the original expression for runtime shapes.
                        start_str, start_str_runtime = (
                            self._materialize_indirect_access(
                                idx.lower, return_original_expr=True
                            )
                        )
                    else:
                        start_str = self.visit(idx.lower)
                        start_str_runtime = start_str
                    if isinstance(start_str, str) and (
                        start_str.startswith("-") or start_str.startswith("(-")
                    ):
                        # Negative bound: wrap around the dimension size.
                        start_str = f"({shape_val} + {start_str})"
                        start_str_runtime = f"({shape_val} + {start_str_runtime})"

                stop_str = str(shape_val)
                stop_str_runtime = str(shape_val)
                if idx.upper is not None:
                    if self._contains_indirect_access(idx.upper):
                        stop_str, stop_str_runtime = self._materialize_indirect_access(
                            idx.upper, return_original_expr=True
                        )
                    else:
                        stop_str = self.visit(idx.upper)
                        stop_str_runtime = stop_str
                    if isinstance(stop_str, str) and (
                        stop_str.startswith("-") or stop_str.startswith("(-")
                    ):
                        stop_str = f"({shape_val} + {stop_str})"
                        stop_str_runtime = f"({shape_val} + {stop_str_runtime})"

                step_str = "1"
                if idx.step is not None:
                    step_str = self.visit(idx.step)

                # Compute dimension size accounting for step: ceil((stop - start) / step)
                # For symbolic expressions, use integer ceiling formula: idiv(n + d - 1, d)
                # NOTE(review): this formula assumes a positive step; the docstring
                # mentions negative steps — confirm how they reach this path.
                if step_str == "1":
                    dim_size = f"({stop_str} - {start_str})"
                    dim_size_runtime = f"({stop_str_runtime} - {start_str_runtime})"
                else:
                    dim_size = (
                        f"idiv({stop_str} - {start_str} + {step_str} - 1, {step_str})"
                    )
                    dim_size_runtime = f"idiv({stop_str_runtime} - {start_str_runtime} + {step_str} - 1, {step_str})"
                result_shapes.append(dim_size)
                result_shapes_runtime.append(dim_size_runtime)
                slice_info.append((i, start_str, stop_str, step_str))
            else:
                # Fixed index: dimension is dropped from the result.
                if self._contains_indirect_access(idx):
                    index_str = self._materialize_indirect_access(idx)
                else:
                    index_str = self.visit(idx)
                if isinstance(index_str, str) and (
                    index_str.startswith("-") or index_str.startswith("(-")
                ):
                    index_str = f"({shape_val} + {index_str})"
                index_info.append((i, index_str))

        tmp_name = self.builder.find_new_name("_slice_tmp_")
        result_ndim = len(result_shapes)

        if result_ndim == 0:
            # All indices fixed -> scalar result, no allocation needed.
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
        else:
            # Total element count = product of per-dimension sizes (symbolic).
            size_str = "1"
            for dim in result_shapes:
                size_str = f"({size_str} * {dim})"

            element_size = self.builder.get_sizeof(dtype)
            total_size = f"({size_str} * {element_size})"

            ptr_type = Pointer(dtype)
            self.builder.add_container(tmp_name, ptr_type, False)
            self.container_table[tmp_name] = ptr_type
            tensor_info = Tensor(dtype, result_shapes)
            self.shapes_runtime_info[tmp_name] = (
                result_shapes_runtime  # Store runtime shapes separately
            )
            self.tensor_table[tmp_name] = tensor_info

            # Heap-allocate the destination buffer.
            debug_info = DebugInfo()
            block_alloc = self.builder.add_block(debug_info)
            t_malloc = self.builder.add_malloc(block_alloc, total_size)
            t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
            self.builder.add_memlet(
                block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
            )

        loop_vars = []
        debug_info = DebugInfo()

        # One counting loop per ranged dimension, iterating 0..count-1.
        for dim_idx, (orig_dim, start_str, stop_str, step_str) in enumerate(slice_info):
            loop_var = self.builder.find_new_name(f"_slice_loop_{dim_idx}_")
            loop_vars.append((loop_var, orig_dim, start_str, step_str))

            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            # Account for step in loop count: ceil((stop - start) / step)
            if step_str == "1":
                count_str = f"({stop_str} - {start_str})"
            else:
                count_str = (
                    f"idiv({stop_str} - {start_str} + {step_str} - 1, {step_str})"
                )
            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)

        # Per-dimension source index expressions; dst uses bare loop counters.
        src_indices = [""] * ndim
        dst_indices = []

        for orig_dim, index_str in index_info:
            src_indices[orig_dim] = index_str

        for loop_var, orig_dim, start_str, step_str in loop_vars:
            # Source index maps the counter back through start/step.
            if step_str == "1":
                src_indices[orig_dim] = f"({start_str} + {loop_var})"
            else:
                src_indices[orig_dim] = f"({start_str} + {loop_var} * {step_str})"
            dst_indices.append(loop_var)

        src_linear = self._compute_linear_index(src_indices, shapes, value_str, ndim)
        if result_ndim > 0:
            dst_linear = self._compute_linear_index(
                dst_indices, result_shapes, tmp_name, result_ndim
            )
        else:
            dst_linear = "0"

        # Innermost body: dst[dst_linear] = src[src_linear] via assign-tasklet.
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, value_str, debug_info)
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_linear, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", dst_linear, None, debug_info
        )

        # Close every opened loop.
        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
1850
        """Compute linear index from multi-dimensional indices."""
1851
        if ndim == 0:
×
1852
            return "0"
×
1853

1854
        linear_index = ""
×
1855
        for i in range(ndim):
×
1856
            term = str(indices[i])
×
1857
            for j in range(i + 1, ndim):
×
1858
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
×
1859
                term = f"(({term}) * {shape_val})"
×
1860

1861
            if i == 0:
×
1862
                linear_index = term
×
1863
            else:
1864
                linear_index = f"({linear_index} + {term})"
×
1865

1866
        return linear_index
×
1867

1868
    def _is_array_index(self, node):
4✔
1869
        """Check if a node represents an array that could be used as an index (gather)."""
1870
        if isinstance(node, ast.Name):
4✔
1871
            return node.id in self.tensor_table
4✔
1872
        return False
4✔
1873

1874
    def _handle_gather(self, value_str, index_node, debug_info=None):
        """Handle gather operation: x[indices] where indices is an array.

        Lowers the gather as a single loop over the 1-D index array: each
        iteration loads indices[i] into a scalar, then copies x[indices[i]]
        into a freshly allocated result buffer of the same length.

        Args:
            value_str: name of the source data container.
            index_node: AST node for the index expression (Name or nested
                expression that evaluates to a registered tensor name).
            debug_info: optional DebugInfo; a fresh one is created if None.

        Returns:
            Name of the newly allocated result container.

        Raises:
            ValueError: if the index expression is not a known tensor.
            NotImplementedError: for index arrays with more than 1 dimension.
        """
        if debug_info is None:
            debug_info = DebugInfo()

        # Resolve the index expression to a tensor name.
        if isinstance(index_node, ast.Name):
            idx_array_name = index_node.id
        else:
            idx_array_name = self.visit(index_node)

        if idx_array_name not in self.tensor_table:
            raise ValueError(f"Gather index must be an array, got {idx_array_name}")

        idx_shapes = self.tensor_table[idx_array_name].shape
        idx_ndim = len(idx_shapes)

        if idx_ndim != 1:
            raise NotImplementedError("Only 1D index arrays supported for gather")

        # Result length equals the index array length (symbolic expression).
        result_shape = idx_shapes[0] if idx_shapes else f"_{idx_array_name}_shape_0"

        # For runtime evaluation, prefer shapes_runtime_info if available
        # This ensures we use expressions that can be evaluated at runtime
        if idx_array_name in self.shapes_runtime_info:
            runtime_shapes = self.shapes_runtime_info[idx_array_name]
            result_shape_runtime = runtime_shapes[0] if runtime_shapes else result_shape
        else:
            result_shape_runtime = result_shape

        # Element type of the gathered data; default Double when unknown.
        dtype = Scalar(PrimitiveType.Double)
        if value_str in self.container_table:
            t = self.container_table[value_str]
            if isinstance(t, Pointer) and t.has_pointee_type():
                dtype = t.pointee_type

        # Element type of the index values; default Int64 when unknown.
        idx_dtype = Scalar(PrimitiveType.Int64)
        if idx_array_name in self.container_table:
            t = self.container_table[idx_array_name]
            if isinstance(t, Pointer) and t.has_pointee_type():
                idx_dtype = t.pointee_type

        tmp_name = self.builder.find_new_name("_gather_")

        element_size = self.builder.get_sizeof(dtype)
        total_size = f"({result_shape} * {element_size})"

        # Register and heap-allocate the result buffer.
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type
        self.tensor_table[tmp_name] = Tensor(dtype, [result_shape])
        # Store runtime evaluable shape for this gather result
        self.shapes_runtime_info[tmp_name] = [result_shape_runtime]

        block_alloc = self.builder.add_block(debug_info)
        t_malloc = self.builder.add_malloc(block_alloc, total_size)
        t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
        self.builder.add_memlet(
            block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
        )

        # Loop counter and scalar temp holding the current index value.
        loop_var = self.builder.find_new_name("_gather_i_")
        self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
        self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

        idx_var = self.builder.find_new_name("_gather_idx_")
        self.builder.add_container(idx_var, idx_dtype, False)
        self.container_table[idx_var] = idx_dtype

        self.builder.begin_for(loop_var, "0", str(result_shape), "1", debug_info)

        # First block in the loop body: idx_var = indices[loop_var].
        block_load_idx = self.builder.add_block(debug_info)
        idx_arr_access = self.builder.add_access(
            block_load_idx, idx_array_name, debug_info
        )
        idx_var_access = self.builder.add_access(block_load_idx, idx_var, debug_info)
        tasklet_load = self.builder.add_tasklet(
            block_load_idx, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            block_load_idx,
            idx_arr_access,
            "void",
            tasklet_load,
            "_in",
            loop_var,
            None,
            debug_info,
        )
        self.builder.add_memlet(
            block_load_idx,
            tasklet_load,
            "_out",
            idx_var_access,
            "void",
            "",
            None,
            debug_info,
        )

        # Second block: result[loop_var] = src[idx_var] (the actual gather).
        block_gather = self.builder.add_block(debug_info)
        src_access = self.builder.add_access(block_gather, value_str, debug_info)
        dst_access = self.builder.add_access(block_gather, tmp_name, debug_info)
        tasklet_gather = self.builder.add_tasklet(
            block_gather, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        self.builder.add_memlet(
            block_gather,
            src_access,
            "void",
            tasklet_gather,
            "_in",
            idx_var,
            None,
            debug_info,
        )
        self.builder.add_memlet(
            block_gather,
            tasklet_gather,
            "_out",
            dst_access,
            "void",
            loop_var,
            None,
            debug_info,
        )

        self.builder.end_for()

        return tmp_name
    def _get_max_array_ndim_in_expr(self, node):
4✔
2006
        """Get the maximum array dimensionality in an expression."""
2007
        max_ndim = 0
4✔
2008

2009
        class NdimVisitor(ast.NodeVisitor):
4✔
2010
            def __init__(self, tensor_table):
4✔
2011
                self.tensor_table = tensor_table
4✔
2012
                self.max_ndim = 0
4✔
2013

2014
            def visit_Name(self, node):
4✔
2015
                if node.id in self.tensor_table:
4✔
2016
                    ndim = len(self.tensor_table[node.id].shape)
4✔
2017
                    self.max_ndim = max(self.max_ndim, ndim)
4✔
2018
                return self.generic_visit(node)
4✔
2019

2020
        visitor = NdimVisitor(self.tensor_table)
4✔
2021
        visitor.visit(node)
4✔
2022
        return visitor.max_ndim
4✔
2023

2024
    def _handle_broadcast_slice_assignment(
        self,
        target,
        materialized_rhs,
        target_name,
        indices,
        target_ndim,
        value_ndim,
        debug_info,
    ):
        """Handle slice assignment with broadcasting (e.g., 2D[:,:] = 1D[:]).

        materialized_rhs is the already-evaluated RHS array name (not AST node).

        Emits ``target_ndim - value_ndim`` outer loops over the broadcast
        (leading) dimensions of the target and ``value_ndim`` inner loops over
        the RHS dimensions, then assigns
        ``target[outer..., inner...] = rhs[inner...]`` in the innermost body.

        NOTE(review): ``target``/``indices`` are accepted but not read here —
        presumably kept for signature parity with the non-broadcast path;
        confirm before removing.
        """
        broadcast_dims = target_ndim - value_ndim
        shapes = self.tensor_table[target_name].shape
        rhs_tensor = self.tensor_table.get(materialized_rhs)
        rhs_shapes = rhs_tensor.shape if rhs_tensor else []

        # Create outer loops for broadcast dimensions
        outer_loop_vars = []
        for i in range(broadcast_dims):
            loop_var = self.builder.find_new_name(f"_bcast_iter_{i}_")
            outer_loop_vars.append(loop_var)

            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            # Symbolic fallback when target shape metadata is shorter than ndim.
            dim_size = shapes[i] if i < len(shapes) else f"_{target_name}_shape_{i}"
            self.builder.begin_for(loop_var, "0", str(dim_size), "1", debug_info)

        # Create inner loops for value dimensions
        inner_loop_vars = []
        for i in range(value_ndim):
            loop_var = self.builder.find_new_name(f"_inner_iter_{i}_")
            inner_loop_vars.append(loop_var)

            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            # Use RHS shape for inner dimension bounds
            dim_size = (
                rhs_shapes[i] if i < len(rhs_shapes) else shapes[broadcast_dims + i]
            )
            self.builder.begin_for(loop_var, "0", str(dim_size), "1", debug_info)

        # Create assignment block: target[outer_vars, inner_vars] = rhs[inner_vars]
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, materialized_rhs, debug_info)
        t_dst = self.builder.add_access(block, target_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        # Source index: just inner loop vars
        src_index = ",".join(inner_loop_vars) if inner_loop_vars else "0"

        # Target index: outer_vars + inner_vars combined
        all_target_vars = outer_loop_vars + inner_loop_vars
        target_index = ",".join(all_target_vars) if all_target_vars else "0"

        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_index, rhs_tensor, debug_info
        )

        tensor_dst = self.tensor_table[target_name]
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", target_index, tensor_dst, debug_info
        )

        # Close all loops (inner first, then outer)
        for _ in inner_loop_vars:
            self.builder.end_for()
        for _ in outer_loop_vars:
            self.builder.end_for()
    def _handle_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Lower an assignment ``target[indices] = value`` where ``indices``
        may contain slices.

        Strategy:
          1. Pad missing trailing dimensions with full slices.
          2. Delegate to the ufunc-``outer`` handler when the RHS contains an
             outer operation (preserves slice shape info).
          3. Materialize the RHS once, up front (NumPy semantics).
          4. Delegate to the broadcasting handler when the target has more
             sliced dimensions than the RHS has array dimensions.
          5. Otherwise open one counted loop per slice dimension, rewrite
             target indices in terms of the loop variables, and emit a single
             assign tasklet inside the loop nest.

        Parameters:
            target: AST Subscript node of the assignment target.
            value: AST node of the RHS expression.
            target_name: name of the target array in ``tensor_table``.
            indices: list of per-dimension index AST nodes for the target.
            debug_info: optional source-location info; defaults to an empty
                ``DebugInfo``.
        """
        if debug_info is None:
            debug_info = DebugInfo()

        # Add missing dimensions
        tensor_info = self.tensor_table[target_name]
        ndim = len(tensor_info.shape)
        if len(indices) < ndim:
            # Copy before extending so the caller's list is not mutated.
            indices = list(indices)
            for _ in range(ndim - len(indices)):
                indices.append(ast.Slice(lower=None, upper=None, step=None))

        # Handle ufunc outer case separately to preserve slice shape info
        has_outer, ufunc_name, outer_node = contains_ufunc_outer(value)
        if has_outer:
            self._handle_ufunc_outer_slice_assignment(
                target, value, target_name, indices, debug_info
            )
            return

        # Count slice dimensions to determine effective target dimensionality
        target_slice_ndim = sum(1 for idx in indices if isinstance(idx, ast.Slice))
        value_max_ndim = self._get_max_array_ndim_in_expr(value)

        # ALWAYS evaluate RHS first (NumPy semantics) - before any loops
        materialized_rhs = self.visit(value)

        if (
            target_slice_ndim > 0
            and value_max_ndim > 0
            and target_slice_ndim > value_max_ndim
        ):
            # Broadcasting case: use row-by-row approach with reference memlets
            self._handle_broadcast_slice_assignment(
                target,
                materialized_rhs,
                target_name,
                indices,
                target_slice_ndim,
                value_max_ndim,
                debug_info,
            )
            return

        loop_vars = []
        new_target_indices = []

        for i, idx in enumerate(indices):
            if isinstance(idx, ast.Slice):
                loop_var = self.builder.find_new_name(f"_slice_iter_{len(loop_vars)}_")
                loop_vars.append(loop_var)

                if not self.builder.exists(loop_var):
                    self.builder.add_container(
                        loop_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

                start_str = "0"
                if idx.lower:
                    start_str = self.visit(idx.lower)
                    # Negative lower bound: rewrite as (dim + start).
                    if start_str.startswith("-"):
                        dim_size = (
                            str(tensor_info.shape[i])
                            if i < len(tensor_info.shape)
                            else f"_{target_name}_shape_{i}"
                        )
                        start_str = f"({dim_size} {start_str})"

                stop_str = ""
                if idx.upper and not (
                    isinstance(idx.upper, ast.Constant) and idx.upper.value is None
                ):
                    stop_str = self.visit(idx.upper)
                    # Negative upper bound: rewrite as (dim + stop).
                    if stop_str.startswith("-") or stop_str.startswith("(-"):
                        dim_size = (
                            str(tensor_info.shape[i])
                            if i < len(tensor_info.shape)
                            else f"_{target_name}_shape_{i}"
                        )
                        stop_str = f"({dim_size} {stop_str})"
                else:
                    # Open-ended slice: stop at the dimension size.
                    stop_str = (
                        str(tensor_info.shape[i])
                        if i < len(tensor_info.shape)
                        else f"_{target_name}_shape_{i}"
                    )

                step_str = "1"
                if idx.step:
                    step_str = self.visit(idx.step)

                # Loop runs 0..count; actual index is start + var * step.
                # NOTE(review): count is (stop - start) without dividing by
                # step — confirm non-unit steps are handled upstream.
                count_str = f"({stop_str} - {start_str})"

                self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)
                new_target_indices.append(
                    ast.Name(
                        id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load()
                    )
                )
            else:
                dim_size = (
                    tensor_info.shape[i]
                    if i < len(tensor_info.shape)
                    else f"_{target_name}_shape_{i}"
                )
                normalized_idx = normalize_negative_index(idx, dim_size)
                # intermediate computations are placed outside the loops
                idx_str = self.visit(normalized_idx)
                new_target_indices.append(ast.Name(id=idx_str, ctx=ast.Load()))

        rewriter = SliceRewriter(loop_vars, self.tensor_table, self)
        # NOTE(review): new_value is never read below; the visit is kept for
        # its side effects on the rewriter/builder — confirm.
        new_value = rewriter.visit(copy.deepcopy(value))

        new_target = copy.deepcopy(target)
        if len(new_target_indices) == 1:
            new_target.slice = new_target_indices[0]
        else:
            new_target.slice = ast.Tuple(elts=new_target_indices, ctx=ast.Load())

        rhs_memlet_type = None
        rhs_indexed_subset = ""
        if materialized_rhs in self.tensor_table:
            rhs_tensor = self.tensor_table[materialized_rhs]
            rhs_ndim = len(rhs_tensor.shape)
            if rhs_ndim > 0 and rhs_ndim == len(loop_vars):
                # RHS is an array matching the slice dimensions - index it with loop vars
                rhs_indexed_subset = ",".join(loop_vars)
                rhs_memlet_type = rhs_tensor

        block = self.builder.add_block(debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        t_src, src_sub = self._add_read(block, materialized_rhs, debug_info)
        # Use indexed subset if RHS is an array that needs indexing
        actual_src_sub = rhs_indexed_subset if rhs_indexed_subset else src_sub
        self.builder.add_memlet(
            block,
            t_src,
            "void",
            t_task,
            "_in",
            actual_src_sub,
            rhs_memlet_type,
            debug_info,
        )

        lhs_expr = self.visit(new_target)
        if "(" in lhs_expr and lhs_expr.endswith(")"):
            # Indexed form "name(subset)": extract the subset string.
            subset = lhs_expr[lhs_expr.find("(") + 1 : -1]
            tensor_dst = self.tensor_table[target_name]

            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", subset, tensor_dst, debug_info
            )
        else:
            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )

        for _ in loop_vars:
            self.builder.end_for()
    def _handle_ufunc_outer_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Handle slice assignment where RHS contains a ufunc outer operation.

        Example: path[:] = np.minimum(path[:], np.add.outer(path[:, k], path[k, :]))

        The strategy is:
        1. Evaluate the entire RHS expression, which will create a temporary array
           containing the result of the ufunc outer (potentially wrapped in other ops)
        2. Copy the temporary result to the target slice

        This avoids the loop transformation that would destroy slice shape info.

        Parameters:
            target: AST node of the assignment target.
            value: AST node of the RHS expression (contains a ufunc outer).
            target_name: name of the target array in ``tensor_table``.
            indices: list of per-dimension index AST nodes for the target.
            debug_info: optional source-location info.
        """
        if debug_info is None:
            # NOTE(review): DebugInfo is already imported at module top; this
            # local import is redundant but harmless.
            from docc.sdfg import DebugInfo

            debug_info = DebugInfo()

        # Evaluate the full RHS expression
        # This will:
        # - Create temp arrays for ufunc outer results
        # - Apply any wrapping operations (np.minimum, etc.)
        # - Return the name of the final result array
        result_name = self.visit(value)

        # Now we need to copy result to target slice
        # Count slice dimensions to determine if we need loops
        target_slice_ndim = sum(1 for idx in indices if isinstance(idx, ast.Slice))

        if target_slice_ndim == 0:
            # No slices on target - just simple assignment
            target_str = self.visit(target)
            block = self.builder.add_block(debug_info)
            t_src, src_sub = self._add_read(block, result_name, debug_info)
            t_dst = self.builder.add_access(block, target_str, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )
            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )
            return

        # We have slices on the target - need to create loops for copying
        # Get target array info
        target_shapes = self.tensor_table[target_name].shape

        loop_vars = []
        new_target_indices = []

        for i, idx in enumerate(indices):
            if isinstance(idx, ast.Slice):
                loop_var = self.builder.find_new_name(f"_copy_iter_{len(loop_vars)}_")
                loop_vars.append(loop_var)

                if not self.builder.exists(loop_var):
                    self.builder.add_container(
                        loop_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

                start_str = "0"
                if idx.lower:
                    start_str = self.visit(idx.lower)

                stop_str = ""
                if idx.upper and not (
                    isinstance(idx.upper, ast.Constant) and idx.upper.value is None
                ):
                    stop_str = self.visit(idx.upper)
                else:
                    # Open-ended slice: stop at the dimension size.
                    stop_str = (
                        target_shapes[i]
                        if i < len(target_shapes)
                        else f"_{target_name}_shape_{i}"
                    )

                step_str = "1"
                if idx.step:
                    step_str = self.visit(idx.step)

                count_str = f"({stop_str} - {start_str})"

                self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

                new_target_indices.append(
                    ast.Name(
                        id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load()
                    )
                )
            else:
                # Handle non-slice indices - need to normalize negative indices
                dim_size = (
                    target_shapes[i]
                    if i < len(target_shapes)
                    else f"_{target_name}_shape_{i}"
                )
                normalized_idx = normalize_negative_index(idx, dim_size)
                # Visit the index NOW before any loops are opened to ensure
                # intermediate computations are placed outside the loops
                idx_str = self.visit(normalized_idx)
                new_target_indices.append(ast.Name(id=idx_str, ctx=ast.Load()))

        # Create assignment block: target[i,j,...] = result[i,j,...]
        block = self.builder.add_block(debug_info)

        # Access nodes
        t_src = self.builder.add_access(block, result_name, debug_info)
        t_dst = self.builder.add_access(block, target_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        # Source index - just use loop vars for flat array from ufunc outer
        # The ufunc outer result is a flat array of size M*N
        if len(loop_vars) == 2:
            # 2D case: result is indexed as i * N + j
            # Get the second dimension size from target shapes
            n_dim = (
                target_shapes[1]
                if len(target_shapes) > 1
                else f"_{target_name}_shape_1"
            )
            src_index = f"(({loop_vars[0]}) * ({n_dim}) + ({loop_vars[1]}))"
        elif len(loop_vars) == 1:
            src_index = loop_vars[0]
        else:
            # General case - compute linear index (row-major, innermost first)
            src_terms = []
            stride = "1"
            for i in range(len(loop_vars) - 1, -1, -1):
                if stride == "1":
                    src_terms.insert(0, loop_vars[i])
                else:
                    src_terms.insert(0, f"({loop_vars[i]} * {stride})")
                if i > 0:
                    dim_size = (
                        target_shapes[i]
                        if i < len(target_shapes)
                        else f"_{target_name}_shape_{i}"
                    )
                    stride = (
                        f"({stride} * {dim_size})" if stride != "1" else str(dim_size)
                    )
            src_index = " + ".join(src_terms) if src_terms else "0"

        # Target index - compute linear index (row-major order)
        # For 2D array with shape (M, N): linear_index = i * N + j
        target_index_parts = []
        for idx in new_target_indices:
            if isinstance(idx, ast.Name):
                target_index_parts.append(idx.id)
            else:
                target_index_parts.append(self.visit(idx))

        # Convert to linear index
        if len(target_index_parts) == 2:
            # 2D case
            n_dim = (
                target_shapes[1]
                if len(target_shapes) > 1
                else f"_{target_name}_shape_1"
            )
            target_index = (
                f"(({target_index_parts[0]}) * ({n_dim}) + ({target_index_parts[1]}))"
            )
        elif len(target_index_parts) == 1:
            target_index = target_index_parts[0]
        else:
            # General case - compute linear index with strides
            stride = "1"
            target_index = "0"
            for i in range(len(target_index_parts) - 1, -1, -1):
                idx_part = target_index_parts[i]
                if stride == "1":
                    term = idx_part
                else:
                    term = f"(({idx_part}) * ({stride}))"

                if target_index == "0":
                    target_index = term
                else:
                    target_index = f"({term} + {target_index})"

                if i > 0:
                    dim_size = (
                        target_shapes[i]
                        if i < len(target_shapes)
                        else f"_{target_name}_shape_{i}"
                    )
                    stride = (
                        f"({stride} * {dim_size})" if stride != "1" else str(dim_size)
                    )

        # Connect memlets
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_index, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", target_index, None, debug_info
        )

        # End loops
        for _ in loop_vars:
            self.builder.end_for()
    def _contains_indirect_access(self, node):
4✔
2484
        """Check if an AST node contains any indirect array access."""
2485
        if isinstance(node, ast.Subscript):
4✔
2486
            if isinstance(node.value, ast.Name):
×
2487
                arr_name = node.value.id
×
2488
                if arr_name in self.tensor_table:
×
2489
                    return True
×
2490
        elif isinstance(node, ast.BinOp):
4✔
2491
            return self._contains_indirect_access(
4✔
2492
                node.left
2493
            ) or self._contains_indirect_access(node.right)
2494
        elif isinstance(node, ast.UnaryOp):
4✔
2495
            return self._contains_indirect_access(node.operand)
4✔
2496
        return False
4✔
2497

2498
    def _materialize_indirect_access(
        self, node, debug_info=None, return_original_expr=False
    ):
        """Materialize an array access into a scalar variable using tasklet+memlets.

        Only an ``ast.Subscript`` over a plain ``ast.Name`` registered in
        ``self.tensor_table`` is materialized; anything else (or a missing
        builder) falls back to a plain ``self.visit``.

        Parameters:
            node: AST node of the access expression.
            debug_info: optional source-location info; defaults to an empty
                ``DebugInfo``.
            return_original_expr: when True, return a pair
                ``(tmp_name, original_expr)`` where ``original_expr`` is the
                indexed access string ``"name(linear_index)"``; otherwise
                return just the temporary's name.
        """

        def fallback():
            # Non-materializable input: evaluate normally.
            expr = self.visit(node)
            return (expr, expr) if return_original_expr else expr

        if not self.builder:
            return fallback()

        if debug_info is None:
            debug_info = DebugInfo()

        if not isinstance(node, ast.Subscript):
            return fallback()

        if not isinstance(node.value, ast.Name):
            return fallback()

        arr_name = node.value.id
        if arr_name not in self.tensor_table:
            return fallback()

        # Element type of the temporary: the pointee type when declared,
        # otherwise Int64 by default.
        dtype = Scalar(PrimitiveType.Int64)
        if arr_name in self.container_table:
            t = self.container_table[arr_name]
            if isinstance(t, Pointer) and t.has_pointee_type():
                dtype = t.pointee_type

        tmp_name = self.builder.find_new_name("_idx_")
        self.builder.add_container(tmp_name, dtype, False)
        self.container_table[tmp_name] = dtype

        ndim = len(self.tensor_table[arr_name].shape)
        shapes = self.tensor_table[arr_name].shape

        if isinstance(node.slice, ast.Tuple):
            indices = [self.visit(elt) for elt in node.slice.elts]
        else:
            indices = [self.visit(node.slice)]

        # The previous if/else here had identical branches; a plain copy is
        # equivalent.
        materialized_indices = list(indices)

        linear_index = self._compute_linear_index(
            materialized_indices, shapes, arr_name, ndim
        )

        # Emit: tmp = arr(linear_index)
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, arr_name, debug_info)
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", linear_index, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", "", None, debug_info
        )

        if return_original_expr:
            original_expr = f"{arr_name}({linear_index})"
            return (tmp_name, original_expr)

        return tmp_name
    def _get_unique_id(self):
4✔
2572
        self._unique_counter_ref[0] += 1
4✔
2573
        return self._unique_counter_ref[0]
4✔
2574

2575
    def _get_memlet_type_for_access(self, expr_str, subset):
4✔
2576
        """Get the Tensor type for an indexed array access expression.
2577

2578
        When accessing an array like "arr(i,j)" with a multi-dimensional subset,
2579
        we need to pass the Tensor type to add_memlet for correct type inference.
2580
        If the expression is a simple scalar variable or constant, returns None.
2581
        """
2582
        if not subset:
4✔
2583
            return None
4✔
2584

2585
        # Check if expr_str is an indexed array access like "arr(i,j)"
2586
        if "(" in expr_str and expr_str.endswith(")"):
4✔
2587
            name = expr_str.split("(")[0]
4✔
2588
            if name in self.tensor_table:
4✔
2589
                return self.tensor_table[name]
4✔
2590

2591
        # Check if expr_str is a simple array name with a non-empty subset from _add_read
2592
        if expr_str in self.tensor_table:
×
2593
            return self.tensor_table[expr_str]
×
2594

2595
        return None
×
2596

2597
    def _element_type(self, name):
        """Resolve the element type of *name*.

        Known containers defer to ``element_type_from_sdfg_type``; anything
        else is treated as a literal constant: Int64 when it parses as an
        integer, Double otherwise.
        """
        if name in self.container_table:
            return element_type_from_sdfg_type(self.container_table[name])
        # Constant literal.
        ptype = PrimitiveType.Int64 if self._is_int(name) else PrimitiveType.Double
        return Scalar(ptype)
    def _is_int(self, operand):
4✔
2607
        try:
4✔
2608
            if operand.lstrip("-").isdigit():
4✔
2609
                return True
4✔
2610
        except ValueError:
×
2611
            pass
×
2612

2613
        name = operand
4✔
2614
        if "(" in operand and operand.endswith(")"):
4✔
2615
            name = operand.split("(")[0]
4✔
2616

2617
        if name in self.container_table:
4✔
2618
            t = self.container_table[name]
4✔
2619

2620
            def is_int_ptype(pt):
4✔
2621
                return pt in [
4✔
2622
                    PrimitiveType.Int64,
2623
                    PrimitiveType.Int32,
2624
                    PrimitiveType.Int8,
2625
                    PrimitiveType.Int16,
2626
                    PrimitiveType.UInt64,
2627
                    PrimitiveType.UInt32,
2628
                    PrimitiveType.UInt8,
2629
                    PrimitiveType.UInt16,
2630
                ]
2631

2632
            if isinstance(t, Scalar):
4✔
2633
                return is_int_ptype(t.primitive_type)
4✔
2634

2635
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
2636
                et = t.element_type
×
2637
                if callable(et):
×
2638
                    et = et()
×
2639
                if isinstance(et, Scalar):
×
2640
                    return is_int_ptype(et.primitive_type)
×
2641

2642
            if type(t).__name__ == "Pointer":
4✔
2643
                if hasattr(t, "pointee_type"):
4✔
2644
                    et = t.pointee_type
4✔
2645
                    if callable(et):
4✔
2646
                        et = et()
×
2647
                    if isinstance(et, Scalar):
4✔
2648
                        return is_int_ptype(et.primitive_type)
4✔
2649
                if hasattr(t, "element_type"):
×
2650
                    et = t.element_type
×
2651
                    if callable(et):
×
2652
                        et = et()
×
2653
                    if isinstance(et, Scalar):
×
2654
                        return is_int_ptype(et.primitive_type)
×
2655

2656
        return False
4✔
2657

2658
    def _add_read(self, block, expr_str, debug_info=None):
        """Add a read access for *expr_str* inside *block*.

        Returns ``(node, subset)`` where *node* is an access node (or a
        constant node for literals) and *subset* is the index string to use
        in the memlet. Handles three cases:
          1. Indexed access ``"name(subset)"`` — access node on the base name.
          2. An existing container — access node; pointers to scalars (or
             zero-dim tensors) get subset ``"0"``.
          3. Anything else — a typed constant node (Int64 for integer
             literals, Bool for ``true``/``false``, Double otherwise).

        Results are memoized per ``(block, expr_str)``. Caching is
        best-effort: unhashable keys raise TypeError and are skipped.
        """

        def remember(result):
            # Best-effort cache write; the four call sites previously each
            # carried their own identical try/except block.
            try:
                self._access_cache[(block, expr_str)] = result
            except TypeError:
                pass
            return result

        # Cache lookup (cached values are tuples, never None).
        try:
            cached = self._access_cache.get((block, expr_str))
        except TypeError:
            cached = None
        if cached is not None:
            return cached

        if debug_info is None:
            debug_info = DebugInfo()

        # Case 1: indexed access "name(subset)".
        if "(" in expr_str and expr_str.endswith(")"):
            name = expr_str.split("(")[0]
            subset = expr_str[expr_str.find("(") + 1 : -1]
            access = self.builder.add_access(block, name, debug_info)
            return remember((access, subset))

        # Case 2: existing container.
        if self.builder.exists(expr_str):
            access = self.builder.add_access(block, expr_str, debug_info)
            subset = ""
            if expr_str in self.container_table:
                sym_type = self.container_table[expr_str]
                if isinstance(sym_type, Pointer):
                    if expr_str in self.tensor_table:
                        # Zero-dim tensor behind a pointer: read element 0.
                        if len(self.tensor_table[expr_str].shape) == 0:
                            subset = "0"
                    else:
                        # Pointer with no tensor info: scalar behind a pointer.
                        subset = "0"
            return remember((access, subset))

        # Case 3: constant literal.
        dtype = Scalar(PrimitiveType.Double)
        if self._is_int(expr_str):
            dtype = Scalar(PrimitiveType.Int64)
        elif expr_str == "true" or expr_str == "false":
            dtype = Scalar(PrimitiveType.Bool)

        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
        return remember((const_node, ""))
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc