• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.82
/python/docc/python/ast_parser.py
1
import ast
4✔
2
import copy
4✔
3
import inspect
4✔
4
import textwrap
4✔
5
from docc.sdfg import (
4✔
6
    Scalar,
7
    PrimitiveType,
8
    Pointer,
9
    TaskletCode,
10
    DebugInfo,
11
    Structure,
12
    CMathFunction,
13
)
14
from docc.python.ast_utils import (
4✔
15
    SliceRewriter,
16
    get_debug_info,
17
    contains_ufunc_outer,
18
    normalize_negative_index,
19
)
20
from docc.python.types import (
4✔
21
    sdfg_type_from_type,
22
    element_type_from_sdfg_type,
23
)
24
from docc.python.functions.numpy import NumPyHandler
4✔
25
from docc.python.functions.math import MathHandler
4✔
26
from docc.python.functions.python import PythonHandler
4✔
27
from docc.python.functions.scipy import SciPyHandler
4✔
28

29

30
class ASTParser(ast.NodeVisitor):
4✔
31
    def __init__(
        self,
        builder,
        array_info=None,
        symbol_table=None,
        filename="",
        function_name="",
        infer_return_type=False,
        globals_dict=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Store the parser state and create the per-library call handlers.

        Optional container arguments left as ``None`` are replaced by a fresh
        empty container (``[0]`` for ``unique_counter_ref``).  Objects supplied
        by the caller are stored as-is, not copied.
        """
        # Plain values are kept exactly as passed in.
        self.builder = builder
        self.filename = filename
        self.function_name = function_name
        self.infer_return_type = infer_return_type

        # Containers: default to a new, private instance when not supplied.
        self.array_info = {} if array_info is None else array_info
        self.symbol_table = {} if symbol_table is None else symbol_table
        self.globals_dict = {} if globals_dict is None else globals_dict
        self._unique_counter_ref = (
            [0] if unique_counter_ref is None else unique_counter_ref
        )
        self.structure_member_info = (
            {} if structure_member_info is None else structure_member_info
        )

        self._access_cache = {}
        # Maps a parameter name to a list of shape strings.
        self.captured_return_shapes = {}

        # The handlers are given this parser so they can call back into its
        # expression visitor methods; they are created after all state above.
        self.numpy_visitor = NumPyHandler(self)
        self.math_handler = MathHandler(self)
        self.python_handler = PythonHandler(self)
        self.scipy_handler = SciPyHandler(self)
4✔
64

65
    def visit_Constant(self, node):
        """Return the literal as expression text.

        Booleans are spelled ``true`` / ``false``; every other value is
        rendered with ``str``.
        """
        value = node.value
        if not isinstance(value, bool):
            return str(value)
        return "true" if value else "false"
4✔
69

70
    def visit_Name(self, node):
4✔
71
        name = node.id
4✔
72
        if name not in self.symbol_table and self.globals_dict is not None:
4✔
73
            if name in self.globals_dict:
4✔
74
                val = self.globals_dict[name]
4✔
75
                if isinstance(val, (int, float)):
4✔
76
                    return str(val)
4✔
77
        return name
4✔
78

79
    def visit_Add(self, node):
        """Return ``+``, the operator text for an ``ast.Add`` node."""
        return "+"
4✔
81

82
    def visit_Sub(self, node):
        """Return ``-``, the operator text for an ``ast.Sub`` node."""
        return "-"
4✔
84

85
    def visit_Mult(self, node):
        """Return ``*``, the operator text for an ``ast.Mult`` node."""
        return "*"
4✔
87

88
    def visit_Div(self, node):
        """Return ``/``, the operator text for an ``ast.Div`` node."""
        return "/"
4✔
90

91
    def visit_FloorDiv(self, node):
        """Return ``//``, the operator text for an ``ast.FloorDiv`` node."""
        return "//"
4✔
93

94
    def visit_Mod(self, node):
        """Return ``%``, the operator text for an ``ast.Mod`` node."""
        return "%"
4✔
96

97
    def visit_Pow(self, node):
        """Return ``**``, the operator text for an ``ast.Pow`` node."""
        return "**"
4✔
99

100
    def visit_Eq(self, node):
        """Return ``==``, the operator text for an ``ast.Eq`` node."""
        return "=="
4✔
102

103
    def visit_NotEq(self, node):
        """Return ``!=``, the operator text for an ``ast.NotEq`` node."""
        return "!="
×
105

106
    def visit_Lt(self, node):
        """Return ``<``, the operator text for an ``ast.Lt`` node."""
        return "<"
4✔
108

109
    def visit_LtE(self, node):
        """Return ``<=``, the operator text for an ``ast.LtE`` node."""
        return "<="
×
111

112
    def visit_Gt(self, node):
        """Return ``>``, the operator text for an ``ast.Gt`` node."""
        return ">"
4✔
114

115
    def visit_GtE(self, node):
        """Return ``>=``, the operator text for an ``ast.GtE`` node."""
        return ">="
4✔
117

118
    def visit_And(self, node):
        """Return ``&``, the text used to join the operands of a boolean ``and``."""
        return "&"
4✔
120

121
    def visit_Or(self, node):
        """Return ``|``, the text used to join the operands of a boolean ``or``."""
        return "|"
4✔
123

124
    def visit_BitAnd(self, node):
        """Return ``&``, the operator text for an ``ast.BitAnd`` node."""
        return "&"
4✔
126

127
    def visit_BitOr(self, node):
        """Return ``|``, the operator text for an ``ast.BitOr`` node."""
        return "|"
4✔
129

130
    def visit_BitXor(self, node):
        """Return ``^``, the operator text for an ``ast.BitXor`` node."""
        return "^"
4✔
132

133
    def visit_LShift(self, node):
        """Return ``<<``, the operator text for an ``ast.LShift`` node."""
        return "<<"
×
135

136
    def visit_RShift(self, node):
        """Return ``>>``, the operator text for an ``ast.RShift`` node."""
        return ">>"
×
138

139
    def visit_Not(self, node):
        """Return ``!``, the operator text for an ``ast.Not`` node."""
        return "!"
4✔
141

142
    def visit_USub(self, node):
        """Return ``-``, the operator text for an ``ast.USub`` node."""
        return "-"
4✔
144

145
    def visit_UAdd(self, node):
        """Return ``+``, the operator text for an ``ast.UAdd`` node."""
        return "+"
×
147

148
    def visit_Invert(self, node):
        """Return ``~``, the operator text for an ``ast.Invert`` node."""
        return "~"
4✔
150

151
    def visit_BoolOp(self, node):
        """Lower an ``and`` / ``or`` expression into a boolean temporary.

        Each operand is visited and wrapped as ``(<operand> != 0)``; the
        wrapped operands are joined with the operator text.  The resulting
        condition drives an if/else that assigns ``true`` or ``false`` to a
        new ``Bool`` container, whose name is returned.
        """
        joiner = self.visit(node.op)
        tests = []
        for operand_node in node.values:
            tests.append(f"({self.visit(operand_node)} != 0)")
        condition = f" {joiner} ".join(tests)

        result_name = self.builder.find_new_name()
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result_name, bool_type, False)

        # Materialise the condition as an explicit true/false assignment.
        self.builder.begin_if(condition)
        self._add_assign_constant(result_name, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result_name, "false", bool_type)
        self.builder.end_if()

        self.symbol_table[result_name] = bool_type
        return result_name
4✔
168

169
    def visit_UnaryOp(self, node):
        """Lower a unary expression and return the text naming its result.

        * A negated numeric literal is folded straight into text (``-5``).
          A boolean literal is converted to its integer value first, so
          ``-True`` yields ``-1`` rather than the invalid text ``-True``.
        * Negating a name registered in ``self.array_info`` is delegated to
          ``self.numpy_visitor.handle_array_negate``.
        * Everything else is written into a new scalar container through a
          single block, and that container's name is returned.

        The result type is the operand's entry in ``self.symbol_table`` (the
        pointee type for a ``Pointer`` that has one).  Otherwise it is
        ``Int64`` when ``self._is_int`` accepts the operand, ``Bool`` for
        ``not``, and ``Double`` in all remaining cases.
        """
        if (
            isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float))
        ):
            literal = node.operand.value
            # bool passes the (int, float) test above; use its numeric value.
            if isinstance(literal, bool):
                literal = int(literal)
            return f"-{literal}"

        op = self.visit(node.op)
        operand = self.visit(node.operand)

        if operand in self.array_info and op == "-":
            return self.numpy_visitor.handle_array_negate(operand)

        tmp_name = self.builder.find_new_name()
        dtype = Scalar(PrimitiveType.Double)
        if operand in self.symbol_table:
            dtype = self.symbol_table[operand]
            if isinstance(dtype, Pointer) and dtype.has_pointee_type():
                dtype = dtype.pointee_type
        elif self._is_int(operand):
            dtype = Scalar(PrimitiveType.Int64)
        elif isinstance(node.op, ast.Not):
            dtype = Scalar(PrimitiveType.Bool)

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)

        if isinstance(node.op, ast.Not):
            # not x  ->  x XOR true
            # NOTE(review): for an operand that is not Bool this presumably
            # flips only one bit instead of testing truthiness - confirm that
            # operands reaching this branch are always boolean.
            t_const = self.builder.add_constant(
                block, "true", Scalar(PrimitiveType.Bool)
            )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif op == "-":
            if dtype.primitive_type == PrimitiveType.Int64:
                # Integer negation is emitted as 0 - x.
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_sub, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
            else:
                # Every type other than Int64 takes the floating-point path.
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.fp_neg, ["_in"], ["_out"]
                )
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif op == "~":
            # ~x  ->  x XOR -1; the constant is always declared as Int64.
            t_const = self.builder.add_constant(
                block, "-1", Scalar(PrimitiveType.Int64)
            )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        else:
            # Remaining operators (unary plus) simply copy the operand.
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
4✔
247

248
    def visit_BinOp(self, node):
4✔
249
        if isinstance(node.op, ast.MatMult):
4✔
250
            return self.numpy_visitor.handle_numpy_matmul_op(node.left, node.right)
4✔
251

252
        left = self.visit(node.left)
4✔
253
        op = self.visit(node.op)
4✔
254
        right = self.visit(node.right)
4✔
255

256
        left_is_array = left in self.array_info
4✔
257
        right_is_array = right in self.array_info
4✔
258

259
        if left_is_array or right_is_array:
4✔
260
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
4✔
261
            if op in op_map:
4✔
262
                return self.numpy_visitor.handle_array_binary_op(
4✔
263
                    op_map[op], left, right
264
                )
265
            else:
NEW
266
                raise NotImplementedError(f"Array operation {op} not supported")
×
267

268
        tmp_name = self.builder.find_new_name()
4✔
269

270
        dtype = Scalar(PrimitiveType.Double)
4✔
271

272
        left_is_int = self._is_int(left)
4✔
273
        right_is_int = self._is_int(right)
4✔
274

275
        if left_is_int and right_is_int and op not in ["/", "**"]:
4✔
276
            dtype = Scalar(PrimitiveType.Int64)
4✔
277

278
        if not self.builder.exists(tmp_name):
4✔
279
            self.builder.add_container(tmp_name, dtype, False)
4✔
280
            self.symbol_table[tmp_name] = dtype
4✔
281

282
        real_left = left
4✔
283
        real_right = right
4✔
284

285
        if dtype.primitive_type == PrimitiveType.Double:
4✔
286
            if left_is_int:
4✔
287
                left_cast = self.builder.find_new_name()
4✔
288
                self.builder.add_container(
4✔
289
                    left_cast, Scalar(PrimitiveType.Double), False
290
                )
291
                self.symbol_table[left_cast] = Scalar(PrimitiveType.Double)
4✔
292

293
                c_block = self.builder.add_block()
4✔
294
                t_src, src_sub = self._add_read(c_block, left)
4✔
295
                t_dst = self.builder.add_access(c_block, left_cast)
4✔
296
                t_task = self.builder.add_tasklet(
4✔
297
                    c_block, TaskletCode.assign, ["_in"], ["_out"]
298
                )
299
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
4✔
300
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
4✔
301

302
                real_left = left_cast
4✔
303

304
            if right_is_int:
4✔
305
                right_cast = self.builder.find_new_name()
4✔
306
                self.builder.add_container(
4✔
307
                    right_cast, Scalar(PrimitiveType.Double), False
308
                )
309
                self.symbol_table[right_cast] = Scalar(PrimitiveType.Double)
4✔
310

311
                c_block = self.builder.add_block()
4✔
312
                t_src, src_sub = self._add_read(c_block, right)
4✔
313
                t_dst = self.builder.add_access(c_block, right_cast)
4✔
314
                t_task = self.builder.add_tasklet(
4✔
315
                    c_block, TaskletCode.assign, ["_in"], ["_out"]
316
                )
317
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
4✔
318
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
4✔
319

320
                real_right = right_cast
4✔
321

322
        if op == "**":
4✔
323
            block = self.builder.add_block()
4✔
324
            t_left, left_sub = self._add_read(block, real_left)
4✔
325
            t_right, right_sub = self._add_read(block, real_right)
4✔
326
            t_out = self.builder.add_access(block, tmp_name)
4✔
327

328
            t_task = self.builder.add_cmath(block, CMathFunction.pow)
4✔
329
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
4✔
330
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
4✔
331
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
4✔
332

333
            return tmp_name
4✔
334
        elif op == "%":
4✔
335
            block = self.builder.add_block()
4✔
336
            t_left, left_sub = self._add_read(block, real_left)
4✔
337
            t_right, right_sub = self._add_read(block, real_right)
4✔
338
            t_out = self.builder.add_access(block, tmp_name)
4✔
339

340
            if dtype.primitive_type == PrimitiveType.Int64:
4✔
341
                t_rem1 = self.builder.add_tasklet(
4✔
342
                    block, TaskletCode.int_srem, ["_in1", "_in2"], ["_out"]
343
                )
344
                self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
4✔
345
                self.builder.add_memlet(
4✔
346
                    block, t_right, "void", t_rem1, "_in2", right_sub
347
                )
348

349
                rem1_name = self.builder.find_new_name()
4✔
350
                self.builder.add_container(rem1_name, dtype, False)
4✔
351
                t_rem1_out = self.builder.add_access(block, rem1_name)
4✔
352
                self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")
4✔
353

354
                t_add = self.builder.add_tasklet(
4✔
355
                    block, TaskletCode.int_add, ["_in1", "_in2"], ["_out"]
356
                )
357
                self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
4✔
358
                self.builder.add_memlet(
4✔
359
                    block, t_right, "void", t_add, "_in2", right_sub
360
                )
361

362
                add_name = self.builder.find_new_name()
4✔
363
                self.builder.add_container(add_name, dtype, False)
4✔
364
                t_add_out = self.builder.add_access(block, add_name)
4✔
365
                self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")
4✔
366

367
                t_rem2 = self.builder.add_tasklet(
4✔
368
                    block, TaskletCode.int_srem, ["_in1", "_in2"], ["_out"]
369
                )
370
                self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
4✔
371
                self.builder.add_memlet(
4✔
372
                    block, t_right, "void", t_rem2, "_in2", right_sub
373
                )
374
                self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")
4✔
375

376
                return tmp_name
4✔
377
            else:
378
                t_rem1 = self.builder.add_tasklet(
4✔
379
                    block, TaskletCode.fp_rem, ["_in1", "_in2"], ["_out"]
380
                )
381
                self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
4✔
382
                self.builder.add_memlet(
4✔
383
                    block, t_right, "void", t_rem1, "_in2", right_sub
384
                )
385

386
                rem1_name = self.builder.find_new_name()
4✔
387
                self.builder.add_container(rem1_name, dtype, False)
4✔
388
                t_rem1_out = self.builder.add_access(block, rem1_name)
4✔
389
                self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")
4✔
390

391
                t_add = self.builder.add_tasklet(
4✔
392
                    block, TaskletCode.fp_add, ["_in1", "_in2"], ["_out"]
393
                )
394
                self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
4✔
395
                self.builder.add_memlet(
4✔
396
                    block, t_right, "void", t_add, "_in2", right_sub
397
                )
398

399
                add_name = self.builder.find_new_name()
4✔
400
                self.builder.add_container(add_name, dtype, False)
4✔
401
                t_add_out = self.builder.add_access(block, add_name)
4✔
402
                self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")
4✔
403

404
                t_rem2 = self.builder.add_tasklet(
4✔
405
                    block, TaskletCode.fp_rem, ["_in1", "_in2"], ["_out"]
406
                )
407
                self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
4✔
408
                self.builder.add_memlet(
4✔
409
                    block, t_right, "void", t_rem2, "_in2", right_sub
410
                )
411
                self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")
4✔
412

413
                return tmp_name
4✔
414

415
        tasklet_code = None
4✔
416
        if dtype.primitive_type == PrimitiveType.Int64:
4✔
417
            if op == "+":
4✔
418
                tasklet_code = TaskletCode.int_add
4✔
419
            elif op == "-":
4✔
420
                tasklet_code = TaskletCode.int_sub
4✔
421
            elif op == "*":
4✔
422
                tasklet_code = TaskletCode.int_mul
4✔
423
            elif op == "/":
4✔
NEW
424
                tasklet_code = TaskletCode.int_sdiv
×
425
            elif op == "//":
4✔
426
                tasklet_code = TaskletCode.int_sdiv
4✔
427
            elif op == "&":
4✔
428
                tasklet_code = TaskletCode.int_and
4✔
429
            elif op == "|":
4✔
430
                tasklet_code = TaskletCode.int_or
4✔
431
            elif op == "^":
4✔
432
                tasklet_code = TaskletCode.int_xor
4✔
NEW
433
            elif op == "<<":
×
NEW
434
                tasklet_code = TaskletCode.int_shl
×
NEW
435
            elif op == ">>":
×
NEW
436
                tasklet_code = TaskletCode.int_lshr
×
437
        else:
438
            if op == "+":
4✔
439
                tasklet_code = TaskletCode.fp_add
4✔
440
            elif op == "-":
4✔
441
                tasklet_code = TaskletCode.fp_sub
4✔
442
            elif op == "*":
4✔
443
                tasklet_code = TaskletCode.fp_mul
4✔
444
            elif op == "/":
4✔
445
                tasklet_code = TaskletCode.fp_div
4✔
NEW
446
            elif op == "//":
×
NEW
447
                tasklet_code = TaskletCode.fp_div
×
448
            else:
NEW
449
                raise NotImplementedError(f"Operation {op} not supported for floats")
×
450

451
        block = self.builder.add_block()
4✔
452
        t_left, left_sub = self._add_read(block, real_left)
4✔
453
        t_right, right_sub = self._add_read(block, real_right)
4✔
454
        t_out = self.builder.add_access(block, tmp_name)
4✔
455

456
        t_task = self.builder.add_tasklet(
4✔
457
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
458
        )
459

460
        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
4✔
461
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
4✔
462
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
4✔
463

464
        return tmp_name
4✔
465

466
    def visit_Attribute(self, node):
4✔
467
        if node.attr == "shape":
4✔
468
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
4✔
469
                return f"_shape_proxy_{node.value.id}"
4✔
470

471
        if node.attr == "T":
4✔
NEW
472
            value_name = None
×
NEW
473
            if isinstance(node.value, ast.Name):
×
NEW
474
                value_name = node.value.id
×
NEW
475
            elif isinstance(node.value, ast.Subscript):
×
NEW
476
                value_name = self.visit(node.value)
×
477

NEW
478
            if value_name and value_name in self.array_info:
×
NEW
479
                return self.numpy_visitor.handle_transpose_expr(node)
×
480

481
        if isinstance(node.value, ast.Name) and node.value.id == "math":
4✔
482
            val = ""
4✔
483
            if node.attr == "pi":
4✔
484
                val = "M_PI"
4✔
485
            elif node.attr == "e":
4✔
486
                val = "M_E"
4✔
487

488
            if val:
4✔
489
                tmp_name = self.builder.find_new_name()
4✔
490
                dtype = Scalar(PrimitiveType.Double)
4✔
491
                self.builder.add_container(tmp_name, dtype, False)
4✔
492
                self.symbol_table[tmp_name] = dtype
4✔
493
                self._add_assign_constant(tmp_name, val, dtype)
4✔
494
                return tmp_name
4✔
495

496
        if isinstance(node.value, ast.Name):
4✔
497
            obj_name = node.value.id
4✔
498
            attr_name = node.attr
4✔
499

500
            if obj_name in self.symbol_table:
4✔
501
                obj_type = self.symbol_table[obj_name]
4✔
502
                if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
4✔
503
                    pointee_type = obj_type.pointee_type
4✔
504
                    if isinstance(pointee_type, Structure):
4✔
505
                        struct_name = pointee_type.name
4✔
506

507
                        if (
4✔
508
                            struct_name in self.structure_member_info
509
                            and attr_name in self.structure_member_info[struct_name]
510
                        ):
511
                            member_index, member_type = self.structure_member_info[
4✔
512
                                struct_name
513
                            ][attr_name]
514
                        else:
NEW
515
                            raise RuntimeError(
×
516
                                f"Member '{attr_name}' not found in structure '{struct_name}'. "
517
                                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
518
                            )
519

520
                        tmp_name = self.builder.find_new_name()
4✔
521

522
                        self.builder.add_container(tmp_name, member_type, False)
4✔
523
                        self.symbol_table[tmp_name] = member_type
4✔
524

525
                        block = self.builder.add_block()
4✔
526
                        obj_access = self.builder.add_access(block, obj_name)
4✔
527
                        tmp_access = self.builder.add_access(block, tmp_name)
4✔
528

529
                        tasklet = self.builder.add_tasklet(
4✔
530
                            block, TaskletCode.assign, ["_in"], ["_out"]
531
                        )
532

533
                        subset = "0," + str(member_index)
4✔
534
                        self.builder.add_memlet(
4✔
535
                            block, obj_access, "", tasklet, "_in", subset
536
                        )
537
                        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")
4✔
538

539
                        return tmp_name
4✔
540

NEW
541
        raise NotImplementedError(f"Attribute access {node.attr} not supported")
×
542

543
    def visit_Compare(self, node):
4✔
544
        left = self.visit(node.left)
4✔
545
        if len(node.ops) > 1:
4✔
NEW
546
            raise NotImplementedError("Chained comparisons not supported yet")
×
547

548
        op = self.visit(node.ops[0])
4✔
549
        right = self.visit(node.comparators[0])
4✔
550

551
        left_is_array = left in self.array_info
4✔
552
        right_is_array = right in self.array_info
4✔
553

554
        if left_is_array or right_is_array:
4✔
555
            return self.numpy_visitor.handle_array_compare(
4✔
556
                left, op, right, left_is_array, right_is_array
557
            )
558

559
        expr_str = f"{left} {op} {right}"
4✔
560

561
        tmp_name = self.builder.find_new_name()
4✔
562
        dtype = Scalar(PrimitiveType.Bool)
4✔
563
        self.builder.add_container(tmp_name, dtype, False)
4✔
564

565
        self.builder.begin_if(expr_str)
4✔
566
        self.builder.add_transition(tmp_name, "true")
4✔
567
        self.builder.begin_else()
4✔
568
        self.builder.add_transition(tmp_name, "false")
4✔
569
        self.builder.end_if()
4✔
570

571
        self.symbol_table[tmp_name] = dtype
4✔
572
        return tmp_name
4✔
573

574
    def visit_Subscript(self, node):
4✔
575
        value_str = self.visit(node.value)
4✔
576

577
        if value_str.startswith("_shape_proxy_"):
4✔
578
            array_name = value_str[len("_shape_proxy_") :]
4✔
579
            if isinstance(node.slice, ast.Constant):
4✔
580
                idx = node.slice.value
4✔
NEW
581
            elif isinstance(node.slice, ast.Index):
×
NEW
582
                idx = node.slice.value.value
×
583
            else:
NEW
584
                try:
×
NEW
585
                    idx = int(self.visit(node.slice))
×
NEW
586
                except:
×
NEW
587
                    raise NotImplementedError(
×
588
                        "Dynamic shape indexing not fully supported yet"
589
                    )
590

591
            if (
4✔
592
                array_name in self.array_info
593
                and "shapes" in self.array_info[array_name]
594
            ):
595
                return self.array_info[array_name]["shapes"][idx]
4✔
596

NEW
597
            return f"_{array_name}_shape_{idx}"
×
598

599
        if value_str in self.array_info:
4✔
600
            ndim = self.array_info[value_str]["ndim"]
4✔
601
            shapes = self.array_info[value_str].get("shapes", [])
4✔
602

603
            if isinstance(node.slice, ast.Tuple):
4✔
604
                indices_nodes = node.slice.elts
4✔
605
            else:
606
                indices_nodes = [node.slice]
4✔
607

608
            all_full_slices = True
4✔
609
            for idx in indices_nodes:
4✔
610
                if isinstance(idx, ast.Slice):
4✔
611
                    if idx.lower is not None or idx.upper is not None:
4✔
612
                        all_full_slices = False
4✔
613
                        break
4✔
614
                else:
615
                    all_full_slices = False
4✔
616
                    break
4✔
617

618
            if all_full_slices:
4✔
619
                return value_str
4✔
620

621
            has_slices = any(isinstance(idx, ast.Slice) for idx in indices_nodes)
4✔
622
            if has_slices:
4✔
623
                return self._handle_expression_slicing(
4✔
624
                    node, value_str, indices_nodes, shapes, ndim
625
                )
626

627
            if len(indices_nodes) == 1 and self._is_array_index(indices_nodes[0]):
4✔
628
                if self.builder:
4✔
629
                    return self._handle_gather(value_str, indices_nodes[0])
4✔
630

631
            if isinstance(node.slice, ast.Tuple):
4✔
632
                indices = [self.visit(elt) for elt in node.slice.elts]
4✔
633
            else:
634
                indices = [self.visit(node.slice)]
4✔
635

636
            if len(indices) != ndim:
4✔
NEW
637
                raise ValueError(
×
638
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
639
                )
640

641
            normalized_indices = []
4✔
642
            for i, idx_str in enumerate(indices):
4✔
643
                shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
4✔
644
                if isinstance(idx_str, str) and (
4✔
645
                    idx_str.startswith("-") or idx_str.startswith("(-")
646
                ):
NEW
647
                    normalized_indices.append(f"({shape_val} + {idx_str})")
×
648
                else:
649
                    normalized_indices.append(idx_str)
4✔
650

651
            linear_index = ""
4✔
652
            for i in range(ndim):
4✔
653
                term = normalized_indices[i]
4✔
654
                for j in range(i + 1, ndim):
4✔
655
                    shape_val = shapes[j] if j < len(shapes) else None
4✔
656
                    shape_sym = (
4✔
657
                        shape_val
658
                        if shape_val is not None
659
                        else f"_{value_str}_shape_{j}"
660
                    )
661
                    term = f"(({term}) * {shape_sym})"
4✔
662

663
                if i == 0:
4✔
664
                    linear_index = term
4✔
665
                else:
666
                    linear_index = f"({linear_index} + {term})"
4✔
667

668
            access_str = f"{value_str}({linear_index})"
4✔
669

670
            if self.builder and isinstance(node.ctx, ast.Load):
4✔
671
                dtype = Scalar(PrimitiveType.Double)
4✔
672
                if value_str in self.symbol_table:
4✔
673
                    t = self.symbol_table[value_str]
4✔
674
                    if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
NEW
675
                        et = t.element_type
×
NEW
676
                        if callable(et):
×
NEW
677
                            et = et()
×
NEW
678
                        dtype = et
×
679
                    elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
4✔
680
                        et = t.pointee_type
4✔
681
                        if callable(et):
4✔
NEW
682
                            et = et()
×
683
                        dtype = et
4✔
684

685
                tmp_name = self.builder.find_new_name()
4✔
686
                self.builder.add_container(tmp_name, dtype, False)
4✔
687

688
                block = self.builder.add_block()
4✔
689
                t_src = self.builder.add_access(block, value_str)
4✔
690
                t_dst = self.builder.add_access(block, tmp_name)
4✔
691
                t_task = self.builder.add_tasklet(
4✔
692
                    block, TaskletCode.assign, ["_in"], ["_out"]
693
                )
694

695
                self.builder.add_memlet(
4✔
696
                    block, t_src, "void", t_task, "_in", linear_index
697
                )
698
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
4✔
699

700
                self.symbol_table[tmp_name] = dtype
4✔
701
                return tmp_name
4✔
702

703
            return access_str
4✔
704

NEW
705
        slice_val = self.visit(node.slice)
×
NEW
706
        access_str = f"{value_str}({slice_val})"
×
707

NEW
708
        if (
×
709
            self.builder
710
            and isinstance(node.ctx, ast.Load)
711
            and value_str in self.array_info
712
        ):
NEW
713
            tmp_name = self.builder.find_new_name()
×
NEW
714
            self.builder.add_container(tmp_name, Scalar(PrimitiveType.Double), False)
×
NEW
715
            self.builder.add_assignment(tmp_name, access_str)
×
NEW
716
            self.symbol_table[tmp_name] = Scalar(PrimitiveType.Double)
×
NEW
717
            return tmp_name
×
718

NEW
719
        return access_str
×
720

721
    def visit_AugAssign(self, node):
        """Lower ``target op= value`` by rewriting it into a plain assignment.

        The statement is turned into ``target = target op value`` and handed to
        ``visit_Assign``. When ``target`` is a bare name registered in
        ``self.array_info`` the assignment target becomes a full slice of the
        array (``target[:]`` or ``target[:, :, ...]``, one slice per
        dimension) so the update is lowered as a slice assignment.

        Args:
            node: The ``ast.AugAssign`` node to lower.
        """
        target = node.target

        if isinstance(target, ast.Name) and target.id in self.array_info:
            # Convert to slice assignment: target[:] = target op value
            ndim = self.array_info[target.id]["ndim"]
            slices = [
                ast.Slice(lower=None, upper=None, step=None) for _ in range(ndim)
            ]
            if ndim == 1:
                slice_arg = slices[0]
            else:
                slice_arg = ast.Tuple(elts=slices, ctx=ast.Load())
            assign_target = ast.Subscript(
                value=target, slice=slice_arg, ctx=ast.Store()
            )
            ast.copy_location(assign_target, node)
        else:
            assign_target = target

        binop = ast.BinOp(left=target, op=node.op, right=node.value)
        new_node = ast.Assign(targets=[assign_target], value=binop)

        # Carry the source position of the original statement over to the
        # synthesized nodes so that location-based debug info is not lost.
        ast.copy_location(binop, node)
        ast.copy_location(new_node, node)

        self.visit_Assign(new_node)
4✔
750

751
    def visit_Assign(self, node):
        """Lower an ``ast.Assign`` statement into builder calls.

        The following forms are handled, in this order:

        * Chained assignment (``a = b = expr``): the value is assigned once to
          a fresh temporary, which is then assigned to every target.
        * Tuple-to-tuple unpacking (``a, b = x, y``): split into one
          assignment per element. If a name used in the targets also occurs
          in the values (e.g. ``a, b = b, a``), every value is first stored in
          a temporary so all right-hand sides are evaluated before any target
          is written, as Python semantics require.
        * Values claimed by the NumPy/SciPy handlers (gemm/dot, outer,
          correlate2d, transpose); if the handler reports success nothing
          else is emitted.
        * Subscript targets: targets containing a slice are delegated to
          ``_handle_slice_assignment``; element targets get an assign tasklet.
        * Name targets: the container is created on first assignment, then a
          reference memlet (pointer to pointer) or an assign tasklet (scalar)
          is emitted.

        Args:
            node: The ``ast.Assign`` node to lower.

        Raises:
            ValueError: If tuple unpacking sizes differ.
            NotImplementedError: For tuple unpacking from a non-tuple value,
                for unsupported target kinds, or when the type of a constant
                cannot be inferred.
        """
        if len(node.targets) > 1:
            tmp_name = self.builder.find_new_name()
            # Assign value to temporary
            val_assign = ast.Assign(
                targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=node.value
            )
            ast.copy_location(val_assign, node)
            self.visit_Assign(val_assign)

            # Assign temporary to targets
            for target in node.targets:
                assign = ast.Assign(
                    targets=[target], value=ast.Name(id=tmp_name, ctx=ast.Load())
                )
                ast.copy_location(assign, node)
                self.visit_Assign(assign)
            return

        target = node.targets[0]

        # Handle tuple unpacking: I, J, K = expr1, expr2, expr3
        if isinstance(target, ast.Tuple):
            if not isinstance(node.value, ast.Tuple):
                raise NotImplementedError(
                    "Tuple unpacking from non-tuple values not supported"
                )
            # Unpacking tuple to tuple: a, b, c = x, y, z
            if len(target.elts) != len(node.value.elts):
                raise ValueError("Tuple unpacking size mismatch")

            values = list(node.value.elts)

            # Element-wise lowering is only equivalent to Python's semantics
            # when no target is read by a value. Otherwise (``a, b = b, a``)
            # stage every value in a temporary before writing any target.
            target_names = {
                n.id
                for tgt in target.elts
                for n in ast.walk(tgt)
                if isinstance(n, ast.Name)
            }
            value_names = {
                n.id
                for val in values
                for n in ast.walk(val)
                if isinstance(n, ast.Name)
            }
            if target_names & value_names:
                staged = []
                for val in values:
                    tmp_name = self.builder.find_new_name()
                    stage = ast.Assign(
                        targets=[ast.Name(id=tmp_name, ctx=ast.Store())], value=val
                    )
                    ast.copy_location(stage, node)
                    self.visit_Assign(stage)
                    staged.append(ast.Name(id=tmp_name, ctx=ast.Load()))
                values = staged

            for tgt, val in zip(target.elts, values):
                assign = ast.Assign(targets=[tgt], value=val)
                ast.copy_location(assign, node)
                self.visit_Assign(assign)
            return

        # Special case: linear algebra functions (handled by NumPyHandler)
        if self.numpy_visitor.is_gemm(node.value):
            if self.numpy_visitor.handle_gemm(target, node.value):
                return
            if self.numpy_visitor.handle_dot(target, node.value):
                return

        # Special case: outer product (handled by NumPyHandler)
        if self.numpy_visitor.is_outer(node.value):
            if self.numpy_visitor.handle_outer(target, node.value):
                return

        # Special case: convolution (scipy.signal.correlate2d)
        if self.scipy_handler.is_correlate2d(node.value):
            if self.scipy_handler.handle_correlate2d(target, node.value):
                return

        # Special case: Transpose (handled by NumPyHandler)
        if self.numpy_visitor.is_transpose(node.value):
            if self.numpy_visitor.handle_transpose(target, node.value):
                return

        # Special case: assignment to a subscript (element or slice)
        if isinstance(target, ast.Subscript):
            target_name = self.visit(target.value)

            if isinstance(target.slice, ast.Tuple):
                indices = target.slice.elts
            else:
                indices = [target.slice]

            if any(isinstance(idx, ast.Slice) for idx in indices):
                debug_info = get_debug_info(node, self.filename, self.function_name)
                self._handle_slice_assignment(
                    target, node.value, target_name, indices, debug_info
                )
                return

            target_name_full = self.visit(target)
            value_str = self.visit(node.value)
            debug_info = get_debug_info(node, self.filename, self.function_name)

            block = self.builder.add_block(debug_info)
            t_src, src_sub = self._add_read(block, value_str, debug_info)

            # The visited target has the form "name(subset)"; split it into
            # the container name and the subset used on the output memlet.
            if "(" in target_name_full and target_name_full.endswith(")"):
                name = target_name_full.split("(")[0]
                subset = target_name_full[target_name_full.find("(") + 1 : -1]
                t_dst = self.builder.add_access(block, name, debug_info)
                dst_sub = subset
            else:
                t_dst = self.builder.add_access(block, target_name_full, debug_info)
                dst_sub = ""

            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", dst_sub, None, debug_info
            )
            return

        # Variable assignments
        if not isinstance(target, ast.Name):
            raise NotImplementedError("Only assignment to variables supported")

        target_name = target.id
        value_str = self.visit(node.value)
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if not self.builder.exists(target_name):
            if isinstance(node.value, ast.Constant):
                val = node.value.value
                # bool is a subclass of int, so it has to be tested first;
                # otherwise True/False would be typed as Int64.
                if isinstance(val, bool):
                    dtype = Scalar(PrimitiveType.Bool)
                elif isinstance(val, int):
                    dtype = Scalar(PrimitiveType.Int64)
                elif isinstance(val, float):
                    dtype = Scalar(PrimitiveType.Double)
                else:
                    raise NotImplementedError(f"Cannot infer type for {val}")

                self.builder.add_container(target_name, dtype, False)
                self.symbol_table[target_name] = dtype
            else:
                assert (
                    value_str in self.symbol_table
                ), f"Cannot infer type of '{target_name}': unknown value '{value_str}'"
                self.builder.add_container(
                    target_name, self.symbol_table[value_str], False
                )
                self.symbol_table[target_name] = self.symbol_table[value_str]

        # The target inherits the array metadata of an array-valued source.
        if value_str in self.array_info:
            self.array_info[target_name] = self.array_info[value_str]

        # Distinguish assignments: scalar -> tasklet, pointer -> reference_memlet
        src_type = self.symbol_table.get(value_str)
        dst_type = self.symbol_table[target_name]
        if src_type and isinstance(src_type, Pointer) and isinstance(dst_type, Pointer):
            block = self.builder.add_block(debug_info)
            t_src = self.builder.add_access(block, value_str, debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            self.builder.add_reference_memlet(
                block, t_src, t_dst, "0", src_type, debug_info
            )
            return
        elif (src_type and isinstance(src_type, Scalar)) or isinstance(
            dst_type, Scalar
        ):
            block = self.builder.add_block(debug_info)
            t_dst = self.builder.add_access(block, target_name, debug_info)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
            )

            # A value without a symbol-table entry is emitted as a constant
            # of the destination type.
            if src_type:
                t_src = self.builder.add_access(block, value_str, debug_info)
            else:
                t_src = self.builder.add_constant(
                    block, value_str, dst_type, debug_info
                )

            self.builder.add_memlet(
                block, t_src, "void", t_task, "_in", "", None, debug_info
            )
            self.builder.add_memlet(
                block, t_task, "_out", t_dst, "void", "", None, debug_info
            )

            return
4✔
928

929
    def visit_Expr(self, node):
        """Lower an expression statement by visiting the wrapped expression.

        The visit result is discarded; the method returns ``None``.
        """
        expression = node.value
        self.visit(expression)
×
931

932
    def visit_If(self, node):
        """Lower an ``if``/``else`` statement.

        The test is visited first and its result is compared against
        ``false`` to form the branch condition. The ``else`` branch is only
        opened when the statement has one.
        """
        condition = self.visit(node.test)
        location = get_debug_info(node, self.filename, self.function_name)

        # Then-branch
        self.builder.begin_if(f"{condition} != false", location)
        for statement in node.body:
            self.visit(statement)

        # Optional else-branch
        else_body = node.orelse
        if else_body:
            self.builder.begin_else(location)
            for statement in else_body:
                self.visit(statement)

        self.builder.end_if()
4✔
946

947
    def visit_While(self, node):
        """Lower a ``while`` loop as a builder loop with a leading break guard.

        Raises:
            NotImplementedError: If the loop has an ``else`` clause.
        """
        if node.orelse:
            raise NotImplementedError("while-else is not supported")

        location = get_debug_info(node, self.filename, self.function_name)
        self.builder.begin_while(location)

        # The test is visited after the loop has been opened so that it is
        # re-evaluated on every iteration.
        loop_test = self.visit(node.test)

        # Exit guard: leave the loop as soon as the test is false.
        self.builder.begin_if(f"{loop_test} == false", location)
        self.builder.add_break(location)
        self.builder.end_if()

        for statement in node.body:
            self.visit(statement)

        self.builder.end_while()
4✔
966

967
    def visit_Break(self, node):
        """Emit a break at the source location of ``node``."""
        location = get_debug_info(node, self.filename, self.function_name)
        self.builder.add_break(location)
4✔
970

971
    def visit_Continue(self, node):
        """Emit a continue at the source location of ``node``."""
        location = get_debug_info(node, self.filename, self.function_name)
        self.builder.add_continue(location)
4✔
974

975
    def visit_For(self, node):
        """Lower a ``for`` loop.

        Two iteration forms are supported:

        * ``for i in range(...)`` with one to three arguments; the loop
          variable is created as an Int64 scalar if it does not exist yet.
        * ``for x in arr`` where ``arr`` is a name registered in
          ``self.array_info``: a hidden Int64 index runs over the first
          dimension and ``x = arr[idx]`` is emitted at the top of the body.

        Args:
            node: The ``ast.For`` node to lower.

        Raises:
            NotImplementedError: For a ``for``-``else`` loop, a loop target
                that is not a plain name, iteration over a 0-dimensional
                array, or any other kind of iterable.
            ValueError: If ``range`` is called with zero or more than three
                arguments.
        """
        if node.orelse:
            raise NotImplementedError("for-else is not supported")
        if not isinstance(node.target, ast.Name):
            raise NotImplementedError("Only simple for loops supported")

        var = node.target.id
        debug_info = get_debug_info(node, self.filename, self.function_name)

        # Check if iterating over a range() call
        if (
            isinstance(node.iter, ast.Call)
            and isinstance(node.iter.func, ast.Name)
            and node.iter.func.id == "range"
        ):
            args = node.iter.args
            if len(args) == 1:
                start = "0"
                end = self.visit(args[0])
                step = "1"
            elif len(args) == 2:
                start = self.visit(args[0])
                end = self.visit(args[1])
                step = "1"
            elif len(args) == 3:
                start = self.visit(args[0])
                end = self.visit(args[1])

                # Special handling for step to avoid creating tasklets for constants
                step_node = args[2]
                if isinstance(step_node, ast.Constant):
                    step = str(step_node.value)
                elif (
                    isinstance(step_node, ast.UnaryOp)
                    and isinstance(step_node.op, ast.USub)
                    and isinstance(step_node.operand, ast.Constant)
                ):
                    step = f"-{step_node.operand.value}"
                else:
                    step = self.visit(step_node)
            else:
                raise ValueError("Invalid range arguments")

            if not self.builder.exists(var):
                self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(var, start, end, step, debug_info)

            for stmt in node.body:
                self.visit(stmt)

            self.builder.end_for()
            return

        # Check if iterating over an ndarray (for x in array)
        if isinstance(node.iter, ast.Name):
            iter_name = node.iter.id
            if iter_name in self.array_info:
                arr_info = self.array_info[iter_name]
                if arr_info["ndim"] < 1:
                    raise NotImplementedError("Cannot iterate over 0-dimensional array")

                # Get the size of the first dimension
                arr_size = arr_info["shapes"][0]

                # Create a hidden index variable for the loop
                idx_var = self.builder.find_new_name()
                if not self.builder.exists(idx_var):
                    self.builder.add_container(
                        idx_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.symbol_table[idx_var] = Scalar(PrimitiveType.Int64)

                # Determine the type of the loop variable (element type)
                # For a 1D array, it's a scalar; for ND array, it's a view of N-1 dimensions
                if arr_info["ndim"] == 1:
                    # Element is a scalar - get the element type from the array's type
                    arr_type = self.symbol_table.get(iter_name)
                    if isinstance(arr_type, Pointer):
                        elem_type = arr_type.pointee_type
                        # pointee_type may be exposed as a method instead of
                        # an attribute; visit_Subscript handles both, so do
                        # the same here.
                        if callable(elem_type):
                            elem_type = elem_type()
                    else:
                        elem_type = Scalar(PrimitiveType.Double)  # Default fallback

                    if not self.builder.exists(var):
                        self.builder.add_container(var, elem_type, False)
                        self.symbol_table[var] = elem_type
                else:
                    # For multi-dimensional arrays, create a view/slice
                    # The loop variable becomes a pointer to the sub-array
                    inner_shapes = arr_info["shapes"][1:]
                    inner_ndim = arr_info["ndim"] - 1

                    arr_type = self.symbol_table.get(iter_name)
                    if isinstance(arr_type, Pointer):
                        elem_type = arr_type  # Keep as pointer type for views
                    else:
                        elem_type = Pointer(Scalar(PrimitiveType.Double))

                    if not self.builder.exists(var):
                        self.builder.add_container(var, elem_type, False)
                        self.symbol_table[var] = elem_type

                    # Register the view in array_info
                    self.array_info[var] = {"ndim": inner_ndim, "shapes": inner_shapes}

                # Begin the for loop
                self.builder.begin_for(idx_var, "0", str(arr_size), "1", debug_info)

                # Generate the assignment: var = array[idx_var]
                # Create an AST node for the assignment and visit it
                assign_node = ast.Assign(
                    targets=[ast.Name(id=var, ctx=ast.Store())],
                    value=ast.Subscript(
                        value=ast.Name(id=iter_name, ctx=ast.Load()),
                        slice=ast.Name(id=idx_var, ctx=ast.Load()),
                        ctx=ast.Load(),
                    ),
                )
                ast.copy_location(assign_node, node)
                self.visit_Assign(assign_node)

                # Visit the loop body
                for stmt in node.body:
                    self.visit(stmt)

                self.builder.end_for()
                return

        raise NotImplementedError(
            f"Only range() loops and iteration over ndarrays supported, got: {ast.dump(node.iter)}"
        )
1107

1108
    def visit_Return(self, node):
        """Lower a ``return`` statement.

        Every returned value ``i`` is written to the container
        ``_docc_ret_{i}`` and a control-flow return is emitted afterwards.
        While ``self.infer_return_type`` is set, missing ``_docc_ret_{i}``
        containers are created as arguments (scalars wrapped in a pointer)
        with a type taken from the symbol table or inferred from a constant;
        the flag is cleared after the first such return statement.

        Values with at least one dimension in ``self.array_info``, or values
        of pointer type, are copied through a synthesized slice assignment
        ``_docc_ret_{i}[:] = value``; everything else is written by an assign
        tasklet. For array returns the shape expressions are recorded in
        ``self.captured_return_shapes``.

        Args:
            node: The ``ast.Return`` node to lower.
        """
        if node.value is None:
            debug_info = get_debug_info(node, self.filename, self.function_name)
            self.builder.add_return("", debug_info)
            return

        if isinstance(node.value, ast.Tuple):
            values = node.value.elts
        else:
            values = [node.value]

        parsed_values = [self.visit(v) for v in values]
        debug_info = get_debug_info(node, self.filename, self.function_name)

        if self.infer_return_type:
            for i, res in enumerate(parsed_values):
                ret_name = f"_docc_ret_{i}"
                if not self.builder.exists(ret_name):
                    dtype = Scalar(PrimitiveType.Double)
                    if res in self.symbol_table:
                        dtype = self.symbol_table[res]
                    elif isinstance(values[i], ast.Constant):
                        val = values[i].value
                        # bool is a subclass of int, so it has to be tested
                        # first; otherwise True/False would be typed as Int64.
                        if isinstance(val, bool):
                            dtype = Scalar(PrimitiveType.Bool)
                        elif isinstance(val, int):
                            dtype = Scalar(PrimitiveType.Int64)
                        elif isinstance(val, float):
                            dtype = Scalar(PrimitiveType.Double)

                    # Wrap Scalar in Pointer. Keep Arrays/Pointers as is.
                    arg_type = dtype
                    if isinstance(dtype, Scalar):
                        arg_type = Pointer(dtype)

                    self.builder.add_container(ret_name, arg_type, is_argument=True)
                    self.symbol_table[ret_name] = arg_type

                    if res in self.array_info:
                        self.array_info[ret_name] = self.array_info[res]

            self.infer_return_type = False

        for i, res in enumerate(parsed_values):
            ret_name = f"_docc_ret_{i}"

            is_array_return = False
            if res in self.array_info:
                # Only treat as array return if it has dimensions
                # 0-d arrays (scalars) should be handled by scalar assignment
                if self.array_info[res]["ndim"] > 0:
                    is_array_return = True
            elif res in self.symbol_table:
                if isinstance(self.symbol_table[res], Pointer):
                    is_array_return = True

            # Simple Scalar Assignment
            if not is_array_return:
                block = self.builder.add_block(debug_info)
                t_dst = self.builder.add_access(block, ret_name, debug_info)

                t_src, src_sub = self._add_read(block, res, debug_info)

                t_task = self.builder.add_tasklet(
                    block, TaskletCode.assign, ["_in"], ["_out"], debug_info
                )
                self.builder.add_memlet(
                    block, t_src, "void", t_task, "_in", src_sub, None, debug_info
                )
                # The output memlet uses subset "0" on the return container.
                self.builder.add_memlet(
                    block, t_task, "_out", t_dst, "void", "0", None, debug_info
                )

            # Array Assignment (Copy)
            else:
                # Record shape for metadata
                if res in self.array_info:
                    # Prefer runtime shapes if available (for indirect access patterns)
                    # Fall back to regular shapes otherwise
                    if "shapes_runtime" in self.array_info[res]:
                        shape = self.array_info[res]["shapes_runtime"]
                    else:
                        shape = self.array_info[res]["shapes"]
                    # Convert to string expressions
                    self.captured_return_shapes[ret_name] = [str(s) for s in shape]

                    # Ensure destination array info exists
                    if ret_name not in self.array_info:
                        self.array_info[ret_name] = self.array_info[res]

                # Copy Logic using visit_Assign
                ndim = 1
                if ret_name in self.array_info:
                    ndim = self.array_info[ret_name]["ndim"]

                if ndim > 1:
                    # One distinct Slice node per dimension (no shared node
                    # objects inside the synthesized tree).
                    target_slice = ast.Tuple(
                        elts=[
                            ast.Slice(lower=None, upper=None, step=None)
                            for _ in range(ndim)
                        ],
                        ctx=ast.Load(),
                    )
                else:
                    target_slice = ast.Slice(lower=None, upper=None, step=None)

                target_sub = ast.Subscript(
                    value=ast.Name(id=ret_name, ctx=ast.Load()),
                    slice=target_slice,
                    ctx=ast.Store(),
                )

                # Value node reconstruction
                if isinstance(values[i], ast.Name):
                    val_node = values[i]
                else:
                    val_node = ast.Name(id=res, ctx=ast.Load())

                assign_node = ast.Assign(targets=[target_sub], value=val_node)
                # Keep the return statement's source position on the
                # synthesized copy so its debug info is not lost.
                ast.copy_location(assign_node, node)
                self.visit_Assign(assign_node)

        # Add control flow return to exit the function/path
        self.builder.add_return("", debug_info)
4✔
1227

1228
    def visit_Call(self, node):
        """Dispatch a call expression to the handler responsible for it.

        Recognised call shapes:

        * ``math.<f>(...)`` and ``numpy.<f>(...)`` / ``np.<f>(...)``;
        * ``<array>.astype(...)`` and ``<array>.copy()`` for names registered
          in ``self.array_info``;
        * ``scipy.<submodule>.<f>(...)``;
        * ``numpy.<ufunc>.outer(...)`` / ``np.<ufunc>.outer(...)``;
        * plain names handled by the Python handler, or plain names bound to
          Python functions in ``self.globals_dict`` (these are inlined).

        Returns:
            Whatever the selected handler returns.

        Raises:
            NotImplementedError: If no handler accepts the call.
        """
        numpy_aliases = ["numpy", "np"]
        callee = node.func

        func_name = ""
        module_name = ""
        submodule_name = ""

        if isinstance(callee, ast.Name):
            func_name = callee.id
        elif isinstance(callee, ast.Attribute):
            owner = callee.value
            if isinstance(owner, ast.Name):
                if owner.id == "math":
                    module_name, func_name = "math", callee.attr
                elif owner.id in numpy_aliases:
                    module_name, func_name = "numpy", callee.attr
                elif owner.id in self.array_info:
                    # Method call on a known array
                    if callee.attr == "astype":
                        return self.numpy_visitor.handle_numpy_astype(node, owner.id)
                    if callee.attr == "copy":
                        return self.numpy_visitor.handle_numpy_copy(node, owner.id)
            elif isinstance(owner, ast.Attribute) and isinstance(
                owner.value, ast.Name
            ):
                root = owner.value.id
                if root == "scipy":
                    module_name = "scipy"
                    submodule_name = owner.attr
                    func_name = callee.attr
                elif root in numpy_aliases and callee.attr == "outer":
                    # np.<ufunc>.outer(...)
                    return self.numpy_visitor.handle_ufunc_outer(node, owner.attr)

        # Module-specific handlers; a module without a matching handler falls
        # through to the generic lookups below.
        if module_name == "numpy" and self.numpy_visitor.has_handler(func_name):
            return self.numpy_visitor.handle_numpy_call(node, func_name)

        if module_name == "math" and self.math_handler.has_handler(func_name):
            return self.math_handler.handle_math_call(node, func_name)

        if module_name == "scipy" and self.scipy_handler.has_handler(
            submodule_name, func_name
        ):
            return self.scipy_handler.handle_scipy_call(
                node, submodule_name, func_name
            )

        if self.python_handler.has_handler(func_name):
            return self.python_handler.handle_python_call(node, func_name)

        # Functions reachable through the globals are inlined.
        if func_name in self.globals_dict:
            candidate = self.globals_dict[func_name]
            if inspect.isfunction(candidate):
                return self._handle_inline_call(node, candidate)

        raise NotImplementedError(f"Function call {func_name} not supported")
×
1289

1290
    def _handle_inline_call(self, node, func_obj):
4✔
1291
        try:
4✔
1292
            source_lines, start_line = inspect.getsourcelines(func_obj)
4✔
1293
            source = textwrap.dedent("".join(source_lines))
4✔
1294
            tree = ast.parse(source)
4✔
1295
            func_def = tree.body[0]
4✔
NEW
1296
        except Exception as e:
×
NEW
1297
            raise NotImplementedError(
×
1298
                f"Could not parse function {func_obj.__name__}: {e}"
1299
            )
1300

1301
        arg_vars = [self.visit(arg) for arg in node.args]
4✔
1302

1303
        if len(arg_vars) != len(func_def.args.args):
4✔
NEW
1304
            raise NotImplementedError(
×
1305
                f"Argument count mismatch for {func_obj.__name__}"
1306
            )
1307

1308
        suffix = f"_{func_obj.__name__}_{self.builder.find_new_name()}"
4✔
1309
        res_name = f"_res{suffix}"
4✔
1310

1311
        # Combine globals with closure variables of the inlined function
1312
        combined_globals = dict(self.globals_dict)
4✔
1313
        closure_constants = {}  # name -> value for numeric closure vars
4✔
1314
        if func_obj.__closure__ is not None and func_obj.__code__.co_freevars:
4✔
1315
            for name, cell in zip(func_obj.__code__.co_freevars, func_obj.__closure__):
4✔
1316
                val = cell.cell_contents
4✔
1317
                combined_globals[name] = val
4✔
1318
                # Track numeric constants for injection
1319
                if isinstance(val, (int, float)) and not isinstance(val, bool):
4✔
1320
                    closure_constants[name] = val
4✔
1321

1322
        class VariableRenamer(ast.NodeTransformer):
4✔
1323
            BUILTINS = {
4✔
1324
                "range",
1325
                "len",
1326
                "int",
1327
                "float",
1328
                "bool",
1329
                "str",
1330
                "list",
1331
                "dict",
1332
                "tuple",
1333
                "set",
1334
                "print",
1335
                "abs",
1336
                "min",
1337
                "max",
1338
                "sum",
1339
                "enumerate",
1340
                "zip",
1341
                "map",
1342
                "filter",
1343
                "sorted",
1344
                "reversed",
1345
                "True",
1346
                "False",
1347
                "None",
1348
            }
1349

1350
            def __init__(self, suffix, globals_dict):
4✔
1351
                self.suffix = suffix
4✔
1352
                self.globals_dict = globals_dict
4✔
1353

1354
            def visit_Name(self, node):
4✔
1355
                if node.id in self.globals_dict or node.id in self.BUILTINS:
4✔
1356
                    return node
4✔
1357
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
4✔
1358

1359
            def visit_Return(self, node):
4✔
1360
                if node.value:
4✔
1361
                    val = self.visit(node.value)
4✔
1362
                    return ast.Assign(
4✔
1363
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
1364
                        value=val,
1365
                    )
NEW
1366
                return node
×
1367

1368
        renamer = VariableRenamer(suffix, combined_globals)
4✔
1369
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
4✔
1370

1371
        param_assignments = []
4✔
1372

1373
        # Inject closure constants as assignments
1374
        for name, val in closure_constants.items():
4✔
1375
            if isinstance(val, int):
4✔
1376
                self.symbol_table[name] = Scalar(PrimitiveType.Int64)
4✔
1377
                self.builder.add_container(name, Scalar(PrimitiveType.Int64), False)
4✔
1378
                val_node = ast.Constant(value=val)
4✔
1379
            else:
NEW
1380
                self.symbol_table[name] = Scalar(PrimitiveType.Double)
×
NEW
1381
                self.builder.add_container(name, Scalar(PrimitiveType.Double), False)
×
NEW
1382
                val_node = ast.Constant(value=val)
×
1383
            assign = ast.Assign(
4✔
1384
                targets=[ast.Name(id=name, ctx=ast.Store())], value=val_node
1385
            )
1386
            param_assignments.append(assign)
4✔
1387

1388
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
4✔
1389
            param_name = f"{arg_def.arg}{suffix}"
4✔
1390

1391
            if arg_val in self.symbol_table:
4✔
1392
                self.symbol_table[param_name] = self.symbol_table[arg_val]
4✔
1393
                self.builder.add_container(
4✔
1394
                    param_name, self.symbol_table[arg_val], False
1395
                )
1396
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
4✔
NEW
1397
            elif self._is_int(arg_val):
×
NEW
1398
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
NEW
1399
                self.builder.add_container(
×
1400
                    param_name, Scalar(PrimitiveType.Int64), False
1401
                )
NEW
1402
                val_node = ast.Constant(value=int(arg_val))
×
1403
            else:
NEW
1404
                try:
×
NEW
1405
                    val = float(arg_val)
×
NEW
1406
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
NEW
1407
                    self.builder.add_container(
×
1408
                        param_name, Scalar(PrimitiveType.Double), False
1409
                    )
NEW
1410
                    val_node = ast.Constant(value=val)
×
NEW
1411
                except ValueError:
×
NEW
1412
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
1413

1414
            assign = ast.Assign(
4✔
1415
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
1416
            )
1417
            param_assignments.append(assign)
4✔
1418

1419
        final_body = param_assignments + new_body
4✔
1420

1421
        # Create a new parser instance for the inlined function
1422
        parser = ASTParser(
4✔
1423
            self.builder,
1424
            self.array_info,
1425
            self.symbol_table,
1426
            globals_dict=combined_globals,
1427
            unique_counter_ref=self._unique_counter_ref,
1428
        )
1429

1430
        for stmt in final_body:
4✔
1431
            parser.visit(stmt)
4✔
1432

1433
        return res_name
4✔
1434

1435
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit a block that assigns a constant to the container ``target_name``.

        Args:
            target_name: Name of the destination container.
            value_str: Constant value, passed through to ``add_constant``.
            dtype: Type handed to ``add_constant`` for the constant node.
        """
        builder = self.builder
        state = builder.add_block()
        const_node = builder.add_constant(state, value_str, dtype)
        dst_node = builder.add_access(state, target_name)
        assign = builder.add_tasklet(state, TaskletCode.assign, ["_in"], ["_out"])
        # Wire: constant -> tasklet input, tasklet output -> destination.
        builder.add_memlet(state, const_node, "void", assign, "_in", "")
        builder.add_memlet(state, assign, "_out", dst_node, "void", "")
4✔
1442

1443
    def _handle_expression_slicing(self, node, value_str, indices_nodes, shapes, ndim):
        """Copy a sliced read (e.g. ``arr[1:, :, k+1]``) into a fresh temporary.

        Every ``ast.Slice`` entry of ``indices_nodes`` becomes one loop and one
        dimension of the result; every other entry is a fixed index into the
        source.  A bound or index whose string form starts with ``-`` or ``(-``
        is offset by the size of its dimension.

        Args:
            node: The node being lowered (not read by this method).
            value_str: Name of the source container.
            indices_nodes: One AST node per indexed source dimension.
            shapes: Known sizes of the source dimensions; a missing entry
                falls back to the ``_<value_str>_shape_<i>`` symbol.
            ndim: Number of source dimensions.

        Returns:
            Name of the temporary container that holds the result.

        Raises:
            ValueError: If no builder is attached.
        """
        if not self.builder:
            raise ValueError("Builder required for expression slicing")

        def looks_negative(expr):
            # Negative values appear either bare ("-1") or wrapped ("(-1)").
            return isinstance(expr, str) and expr.startswith(("-", "(-"))

        def resolve_bound(bound_node, extent):
            # Returns (expression, runtime expression) for one slice bound.
            if self._contains_indirect_access(bound_node):
                expr, expr_runtime = self._materialize_indirect_access(
                    bound_node, return_original_expr=True
                )
            else:
                expr = self.visit(bound_node)
                expr_runtime = expr
            if looks_negative(expr):
                expr = f"({extent} + {expr})"
                expr_runtime = f"({extent} + {expr_runtime})"
            return expr, expr_runtime

        # Element type defaults to double unless the source is a typed pointer.
        elem_type = Scalar(PrimitiveType.Double)
        if value_str in self.symbol_table:
            declared = self.symbol_table[value_str]
            if isinstance(declared, Pointer) and declared.has_pointee_type():
                elem_type = declared.pointee_type

        out_shapes = []
        out_shapes_runtime = []
        sliced_dims = []
        fixed_dims = []

        for axis, index_node in enumerate(indices_nodes):
            extent = (
                shapes[axis] if axis < len(shapes) else f"_{value_str}_shape_{axis}"
            )

            if not isinstance(index_node, ast.Slice):
                if self._contains_indirect_access(index_node):
                    point = self._materialize_indirect_access(index_node)
                else:
                    point = self.visit(index_node)
                if looks_negative(point):
                    point = f"({extent} + {point})"
                fixed_dims.append((axis, point))
                continue

            lower, lower_runtime = "0", "0"
            if index_node.lower is not None:
                lower, lower_runtime = resolve_bound(index_node.lower, extent)

            upper = upper_runtime = str(extent)
            if index_node.upper is not None:
                upper, upper_runtime = resolve_bound(index_node.upper, extent)

            stride = "1"
            if index_node.step is not None:
                stride = self.visit(index_node.step)

            # NOTE(review): the extent is (stop - start) regardless of the
            # step -- confirm non-unit steps are sized correctly elsewhere.
            out_shapes.append(f"({upper} - {lower})")
            out_shapes_runtime.append(f"({upper_runtime} - {lower_runtime})")
            sliced_dims.append((axis, lower, upper, stride))

        tmp_name = self.builder.find_new_name("_slice_tmp_")
        result_ndim = len(out_shapes)

        if result_ndim == 0:
            # Only point indices: the result is a single element.
            self.builder.add_container(tmp_name, elem_type, False)
            self.symbol_table[tmp_name] = elem_type
        else:
            element_count = "1"
            for extent_expr in out_shapes:
                element_count = f"({element_count} * {extent_expr})"

            element_size = self.builder.get_sizeof(elem_type)
            byte_count = f"({element_count} * {element_size})"

            ptr_type = Pointer(elem_type)
            self.builder.add_container(tmp_name, ptr_type, False)
            self.symbol_table[tmp_name] = ptr_type
            self.array_info[tmp_name] = {
                "ndim": result_ndim,
                "shapes": out_shapes,
                "shapes_runtime": out_shapes_runtime,
            }

            alloc_debug = DebugInfo()
            alloc_block = self.builder.add_block(alloc_debug)
            malloc_node = self.builder.add_malloc(alloc_block, byte_count)
            ptr_node = self.builder.add_access(alloc_block, tmp_name, alloc_debug)
            self.builder.add_memlet(
                alloc_block,
                malloc_node,
                "_ret",
                ptr_node,
                "void",
                "",
                ptr_type,
                alloc_debug,
            )

        debug_info = DebugInfo()
        loops = []

        # One zero-based loop per sliced dimension.
        for depth, (axis, lower, upper, stride) in enumerate(sliced_dims):
            counter = self.builder.find_new_name(f"_slice_loop_{depth}_")
            loops.append((counter, axis, lower, stride))

            if not self.builder.exists(counter):
                self.builder.add_container(counter, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[counter] = Scalar(PrimitiveType.Int64)

            trip_count = f"({upper} - {lower})"
            self.builder.begin_for(counter, "0", trip_count, "1", debug_info)

        src_indices = [""] * ndim
        for axis, point in fixed_dims:
            src_indices[axis] = point

        dst_indices = []
        for counter, axis, lower, stride in loops:
            if stride == "1":
                src_indices[axis] = f"({lower} + {counter})"
            else:
                src_indices[axis] = f"({lower} + {counter} * {stride})"
            dst_indices.append(counter)

        src_linear = self._compute_linear_index(src_indices, shapes, value_str, ndim)
        if result_ndim > 0:
            dst_linear = self._compute_linear_index(
                dst_indices, out_shapes, tmp_name, result_ndim
            )
        else:
            dst_linear = "0"

        copy_block = self.builder.add_block(debug_info)
        src_node = self.builder.add_access(copy_block, value_str, debug_info)
        dst_node = self.builder.add_access(copy_block, tmp_name, debug_info)
        copy_task = self.builder.add_tasklet(
            copy_block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )

        self.builder.add_memlet(
            copy_block, src_node, "void", copy_task, "_in", src_linear, None, debug_info
        )
        self.builder.add_memlet(
            copy_block, copy_task, "_out", dst_node, "void", dst_linear, None, debug_info
        )

        for _ in loops:
            self.builder.end_for()

        return tmp_name
4✔
1601

1602
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
1603
        """Compute linear index from multi-dimensional indices."""
1604
        if ndim == 0:
4✔
NEW
1605
            return "0"
×
1606

1607
        linear_index = ""
4✔
1608
        for i in range(ndim):
4✔
1609
            term = str(indices[i])
4✔
1610
            for j in range(i + 1, ndim):
4✔
1611
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
4✔
1612
                term = f"(({term}) * {shape_val})"
4✔
1613

1614
            if i == 0:
4✔
1615
                linear_index = term
4✔
1616
            else:
1617
                linear_index = f"({linear_index} + {term})"
4✔
1618

1619
        return linear_index
4✔
1620

1621
    def _is_array_index(self, node):
4✔
1622
        """Check if a node represents an array that could be used as an index (gather)."""
1623
        if isinstance(node, ast.Name):
4✔
1624
            return node.id in self.array_info
4✔
1625
        return False
4✔
1626

1627
    def _handle_gather(self, value_str, index_node, debug_info=None):
        """Lower ``x[indices]`` where ``indices`` is an array (gather).

        Allocates a 1-D temporary with one element per index and fills it in a
        loop: each iteration loads ``indices[i]`` into a scalar and then copies
        ``x[that scalar]`` into ``tmp[i]``.

        Args:
            value_str: Name of the source container.
            index_node: AST node naming the index array.
            debug_info: Debug info attached to generated elements; a fresh
                ``DebugInfo`` is used when omitted.

        Returns:
            Name of the temporary container holding the gathered values.

        Raises:
            ValueError: If the index is not a registered array.
            NotImplementedError: If the index array is not 1-D.
        """
        if debug_info is None:
            debug_info = DebugInfo()

        if isinstance(index_node, ast.Name):
            idx_array_name = index_node.id
        else:
            idx_array_name = self.visit(index_node)

        if idx_array_name not in self.array_info:
            raise ValueError(f"Gather index must be an array, got {idx_array_name}")

        idx_info = self.array_info[idx_array_name]
        idx_shapes = idx_info.get("shapes", [])
        if idx_info["ndim"] != 1:
            raise NotImplementedError("Only 1D index arrays supported for gather")

        # The result has as many elements as the index array.
        result_shape = idx_shapes[0] if idx_shapes else f"_{idx_array_name}_shape_0"

        def pointee_or(name, fallback):
            # Element type of a typed pointer container, else the fallback.
            if name in self.symbol_table:
                declared = self.symbol_table[name]
                if isinstance(declared, Pointer) and declared.has_pointee_type():
                    return declared.pointee_type
            return fallback

        dtype = pointee_or(value_str, Scalar(PrimitiveType.Double))
        idx_dtype = pointee_or(idx_array_name, Scalar(PrimitiveType.Int64))

        builder = self.builder
        tmp_name = builder.find_new_name("_gather_")

        element_size = builder.get_sizeof(dtype)
        total_size = f"({result_shape} * {element_size})"

        ptr_type = Pointer(dtype)
        builder.add_container(tmp_name, ptr_type, False)
        self.symbol_table[tmp_name] = ptr_type
        self.array_info[tmp_name] = {"ndim": 1, "shapes": [result_shape]}

        # Allocate the result buffer.
        alloc_block = builder.add_block(debug_info)
        malloc_node = builder.add_malloc(alloc_block, total_size)
        ptr_node = builder.add_access(alloc_block, tmp_name, debug_info)
        builder.add_memlet(
            alloc_block, malloc_node, "_ret", ptr_node, "void", "", ptr_type, debug_info
        )

        loop_var = builder.find_new_name("_gather_i_")
        builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
        self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        idx_var = builder.find_new_name("_gather_idx_")
        builder.add_container(idx_var, idx_dtype, False)
        self.symbol_table[idx_var] = idx_dtype

        builder.begin_for(loop_var, "0", str(result_shape), "1", debug_info)

        # Step 1: idx_var = indices[loop_var]
        load_block = builder.add_block(debug_info)
        indices_node = builder.add_access(load_block, idx_array_name, debug_info)
        idx_var_node = builder.add_access(load_block, idx_var, debug_info)
        load_task = builder.add_tasklet(
            load_block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )
        builder.add_memlet(
            load_block, indices_node, "void", load_task, "_in", loop_var, None, debug_info
        )
        builder.add_memlet(
            load_block, load_task, "_out", idx_var_node, "void", "", None, debug_info
        )

        # Step 2: tmp[loop_var] = source[idx_var]
        gather_block = builder.add_block(debug_info)
        src_node = builder.add_access(gather_block, value_str, debug_info)
        dst_node = builder.add_access(gather_block, tmp_name, debug_info)
        gather_task = builder.add_tasklet(
            gather_block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )
        builder.add_memlet(
            gather_block, src_node, "void", gather_task, "_in", idx_var, None, debug_info
        )
        builder.add_memlet(
            gather_block, gather_task, "_out", dst_node, "void", loop_var, None, debug_info
        )

        builder.end_for()

        return tmp_name
4✔
1747

1748
    def _get_max_array_ndim_in_expr(self, node):
4✔
1749
        """Get the maximum array dimensionality in an expression."""
1750
        max_ndim = 0
4✔
1751

1752
        class NdimVisitor(ast.NodeVisitor):
4✔
1753
            def __init__(self, array_info):
4✔
1754
                self.array_info = array_info
4✔
1755
                self.max_ndim = 0
4✔
1756

1757
            def visit_Name(self, node):
4✔
1758
                if node.id in self.array_info:
4✔
1759
                    ndim = self.array_info[node.id].get("ndim", 0)
4✔
1760
                    self.max_ndim = max(self.max_ndim, ndim)
4✔
1761
                return self.generic_visit(node)
4✔
1762

1763
        visitor = NdimVisitor(self.array_info)
4✔
1764
        visitor.visit(node)
4✔
1765
        return visitor.max_ndim
4✔
1766

1767
    def _handle_broadcast_slice_assignment(
        self, target, value, target_name, indices, target_ndim, value_ndim, debug_info
    ):
        """Handle slice assignment with broadcasting (e.g., 2D -= 1D).

        Opens one loop per leading (broadcast) dimension of the target, makes
        a pointer view onto the current inner sub-array through a reference
        memlet, and delegates the now same-rank assignment on that view to
        ``_handle_slice_assignment``.

        Args:
            target: Original target node (not read by this method).
            value: Right-hand-side expression node.
            target_name: Name of the target array.
            indices: Original target indices (not read by this method).
            target_ndim: Dimensionality of the target.
            value_ndim: Dimensionality of the value expression.
            debug_info: Debug info attached to generated elements.

        NOTE(review): since ``indices`` is not read, the inner assignment
        always covers full slices of the view -- confirm partial target slices
        never reach this path.
        """
        # Number of broadcast dimensions (outer loops)
        broadcast_dims = target_ndim - value_ndim

        shapes = self.array_info[target_name].get("shapes", [])

        def dim_size_of(dim_idx):
            # Known size, or the symbolic shape name when it is not recorded.
            if dim_idx < len(shapes):
                return shapes[dim_idx]
            return f"_{target_name}_shape_{dim_idx}"

        # Create outer loops for broadcast dimensions
        outer_loop_vars = []
        for i in range(broadcast_dims):
            loop_var = self.builder.find_new_name(f"_bcast_iter_{i}_")
            outer_loop_vars.append(loop_var)

            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(loop_var, "0", dim_size_of(i), "1", debug_info)

        # Create a row view (reference) for the inner dimensions
        row_view_name = self.builder.find_new_name("_row_view_")

        # Get inner shape for the row view
        inner_shapes = shapes[broadcast_dims:] if len(shapes) > broadcast_dims else []

        # Determine element type from the target
        target_type = self.symbol_table.get(target_name)
        if isinstance(target_type, Pointer) and target_type.has_pointee_type():
            element_type = target_type.pointee_type
        else:
            element_type = Scalar(PrimitiveType.Double)

        # Create pointer type for row view
        row_type = Pointer(element_type)
        self.builder.add_container(row_view_name, row_type, False)
        self.symbol_table[row_view_name] = row_type

        # Register row view in array_info
        self.array_info[row_view_name] = {"ndim": value_ndim, "shapes": inner_shapes}

        # Compute linearized index for the reference: row_view = &target[i, 0, ...]
        # For target[i, j] with shape (n, m), linear index for row i is i * m
        linear_idx = outer_loop_vars[0] if outer_loop_vars else "0"
        for dim_idx in range(1, broadcast_dims):
            linear_idx = (
                f"({linear_idx}) * ({dim_size_of(dim_idx)}) + {outer_loop_vars[dim_idx]}"
            )

        # Multiply by inner dimension sizes to get the start of the row
        for dim_idx in range(broadcast_dims, target_ndim):
            linear_idx = f"({linear_idx}) * ({dim_size_of(dim_idx)})"

        # Create the reference memlet block
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, target_name, debug_info)
        t_dst = self.builder.add_access(block, row_view_name, debug_info)
        self.builder.add_reference_memlet(
            block, t_src, t_dst, linear_idx, row_type, debug_info
        )

        # Now handle the inner slice assignment with the row view
        # Create inner indices (all slices for the inner dimensions)
        inner_indices = [
            ast.Slice(lower=None, upper=None, step=None) for _ in range(value_ndim)
        ]

        # Create new target using row view
        new_target = ast.Subscript(
            value=ast.Name(id=row_view_name, ctx=ast.Load()),
            slice=(
                ast.Tuple(elts=inner_indices, ctx=ast.Load())
                if len(inner_indices) > 1
                else inner_indices[0]
            ),
            ctx=ast.Store(),
        )

        # Recursively handle the inner assignment (now same-dimension)
        self._handle_slice_assignment(
            new_target, value, row_view_name, inner_indices, debug_info
        )

        # Close outer loops
        for _ in outer_loop_vars:
            self.builder.end_for()
×
1869

1870
    def _handle_slice_assignment(
        self, target, value, target_name, indices, debug_info=None
    ):
        """Lower an assignment whose target is a sliced subscript.

        Missing trailing indices are padded with full slices.  A right-hand
        side containing a ufunc ``outer`` call, or one of lower rank than the
        sliced target, is delegated to the dedicated handlers.  Otherwise one
        loop is opened per slice, the target and value are rewritten to use
        the loop variables, and a scalar assignment is emitted inside the
        loop nest.

        Args:
            target: Target subscript node.
            value: Right-hand-side expression node.
            target_name: Name of the target array.
            indices: Index nodes of the target subscript.
            debug_info: Debug info attached to generated elements; a fresh
                ``DebugInfo`` is used when omitted.
        """
        if debug_info is None:
            debug_info = DebugInfo()

        if target_name in self.array_info:
            ndim = self.array_info[target_name]["ndim"]
            if len(indices) < ndim:
                indices = list(indices)
                for _ in range(ndim - len(indices)):
                    indices.append(ast.Slice(lower=None, upper=None, step=None))

        # Check if the RHS contains a ufunc outer operation
        # If so, we handle it specially to avoid the loop transformation
        # which would destroy the slice shape information
        has_outer, ufunc_name, outer_node = contains_ufunc_outer(value)
        if has_outer:
            self._handle_ufunc_outer_slice_assignment(
                target, value, target_name, indices, debug_info
            )
            return

        # Count slice dimensions to determine effective target dimensionality
        # (slice indices produce array dimensions, point indices collapse them)
        target_slice_ndim = sum(1 for idx in indices if isinstance(idx, ast.Slice))
        value_max_ndim = self._get_max_array_ndim_in_expr(value)

        if (
            target_slice_ndim > 0
            and value_max_ndim > 0
            and target_slice_ndim > value_max_ndim
        ):
            # Broadcasting case: use row-by-row approach with reference memlets
            self._handle_broadcast_slice_assignment(
                target,
                value,
                target_name,
                indices,
                target_slice_ndim,
                value_max_ndim,
                debug_info,
            )
            return

        def dim_extent(axis):
            # Known size of the dimension, or its symbolic shape name.
            shapes = self.array_info[target_name].get("shapes", [])
            return shapes[axis] if axis < len(shapes) else f"_{target_name}_shape_{axis}"

        def is_negative(bound):
            # Negative bounds appear either bare ("-1") or wrapped ("(-1)").
            return bound.startswith(("-", "(-"))

        loop_vars = []
        new_target_indices = []

        for i, idx in enumerate(indices):
            if isinstance(idx, ast.Slice):
                loop_var = self.builder.find_new_name(f"_slice_iter_{len(loop_vars)}_")
                loop_vars.append(loop_var)

                if not self.builder.exists(loop_var):
                    self.builder.add_container(
                        loop_var, Scalar(PrimitiveType.Int64), False
                    )
                    self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

                start_str = "0"
                if idx.lower:
                    start_str = self.visit(idx.lower)
                    if is_negative(start_str):
                        # Count from the end; the explicit "+" keeps the
                        # expression well-formed for wrapped bounds too.
                        start_str = f"({dim_extent(i)} + {start_str})"

                stop_str = ""
                if idx.upper and not (
                    isinstance(idx.upper, ast.Constant) and idx.upper.value is None
                ):
                    stop_str = self.visit(idx.upper)
                    if is_negative(stop_str):
                        stop_str = f"({dim_extent(i)} + {stop_str})"
                else:
                    stop_str = str(dim_extent(i))

                step_str = "1"
                if idx.step:
                    step_str = self.visit(idx.step)

                # NOTE(review): the trip count is (stop - start) regardless of
                # the step -- confirm non-unit steps are handled as intended.
                count_str = f"({stop_str} - {start_str})"

                self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

                new_target_indices.append(
                    ast.Name(
                        id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load()
                    )
                )
            else:
                # Handle non-slice indices - need to normalize negative indices
                normalized_idx = normalize_negative_index(idx, dim_extent(i))
                new_target_indices.append(normalized_idx)

        # Rewrite slices on the RHS to use the same loop variables.
        rewriter = SliceRewriter(loop_vars, self.array_info, self)
        new_value = rewriter.visit(copy.deepcopy(value))

        new_target = copy.deepcopy(target)
        if len(new_target_indices) == 1:
            new_target.slice = new_target_indices[0]
        else:
            new_target.slice = ast.Tuple(elts=new_target_indices, ctx=ast.Load())

        target_str = self.visit(new_target)
        value_str = self.visit(new_value)
        self.builder.add_assignment(target_str, value_str, debug_info)

        for _ in loop_vars:
            self.builder.end_for()
4✔
1998

1999
    def _handle_ufunc_outer_slice_assignment(
4✔
2000
        self, target, value, target_name, indices, debug_info=None
2001
    ):
2002
        """Handle slice assignment where RHS contains a ufunc outer operation.
2003

2004
        Example: path[:] = np.minimum(path[:], np.add.outer(path[:, k], path[k, :]))
2005

2006
        The strategy is:
2007
        1. Evaluate the entire RHS expression, which will create a temporary array
2008
           containing the result of the ufunc outer (potentially wrapped in other ops)
2009
        2. Copy the temporary result to the target slice
2010

2011
        This avoids the loop transformation that would destroy slice shape info.
2012
        """
2013
        if debug_info is None:
4✔
2014
            from docc.sdfg import DebugInfo
×
2015

2016
            debug_info = DebugInfo()
×
2017

2018
        # Evaluate the full RHS expression
2019
        # This will:
2020
        # - Create temp arrays for ufunc outer results
2021
        # - Apply any wrapping operations (np.minimum, etc.)
2022
        # - Return the name of the final result array
2023
        result_name = self.visit(value)
4✔
2024

2025
        # Now we need to copy result to target slice
2026
        # Count slice dimensions to determine if we need loops
2027
        target_slice_ndim = sum(1 for idx in indices if isinstance(idx, ast.Slice))
4✔
2028

2029
        if target_slice_ndim == 0:
4✔
2030
            # No slices on target - just simple assignment
NEW
2031
            target_str = self.visit(target)
×
2032
            block = self.builder.add_block(debug_info)
×
NEW
2033
            t_src, src_sub = self._add_read(block, result_name, debug_info)
×
2034
            t_dst = self.builder.add_access(block, target_str, debug_info)
×
2035
            t_task = self.builder.add_tasklet(
×
2036
                block, TaskletCode.assign, ["_in"], ["_out"], debug_info
2037
            )
2038
            self.builder.add_memlet(
×
2039
                block, t_src, "void", t_task, "_in", src_sub, None, debug_info
2040
            )
2041
            self.builder.add_memlet(
×
2042
                block, t_task, "_out", t_dst, "void", "", None, debug_info
2043
            )
2044
            return
×
2045

2046
        # We have slices on the target - need to create loops for copying
2047
        # Get target array info
2048
        target_info = self.array_info.get(target_name, {})
4✔
2049
        target_shapes = target_info.get("shapes", [])
4✔
2050

2051
        loop_vars = []
4✔
2052
        new_target_indices = []
4✔
2053

2054
        for i, idx in enumerate(indices):
4✔
2055
            if isinstance(idx, ast.Slice):
4✔
2056
                loop_var = self.builder.find_new_name(f"_copy_iter_{len(loop_vars)}_")
4✔
2057
                loop_vars.append(loop_var)
4✔
2058

2059
                if not self.builder.exists(loop_var):
4✔
2060
                    self.builder.add_container(
4✔
2061
                        loop_var, Scalar(PrimitiveType.Int64), False
2062
                    )
2063
                    self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
2064

2065
                start_str = "0"
4✔
2066
                if idx.lower:
4✔
NEW
2067
                    start_str = self.visit(idx.lower)
×
2068

2069
                stop_str = ""
4✔
2070
                if idx.upper and not (
4✔
2071
                    isinstance(idx.upper, ast.Constant) and idx.upper.value is None
2072
                ):
NEW
2073
                    stop_str = self.visit(idx.upper)
×
2074
                else:
2075
                    stop_str = (
4✔
2076
                        target_shapes[i]
2077
                        if i < len(target_shapes)
2078
                        else f"_{target_name}_shape_{i}"
2079
                    )
2080

2081
                step_str = "1"
4✔
2082
                if idx.step:
4✔
NEW
2083
                    step_str = self.visit(idx.step)
×
2084

2085
                count_str = f"({stop_str} - {start_str})"
4✔
2086

2087
                self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
4✔
2088
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
2089

2090
                new_target_indices.append(
4✔
2091
                    ast.Name(
2092
                        id=f"{start_str} + {loop_var} * {step_str}", ctx=ast.Load()
2093
                    )
2094
                )
2095
            else:
2096
                # Handle non-slice indices - need to normalize negative indices
2097
                dim_size = (
×
2098
                    target_shapes[i]
2099
                    if i < len(target_shapes)
2100
                    else f"_{target_name}_shape_{i}"
2101
                )
2102
                normalized_idx = normalize_negative_index(idx, dim_size)
×
2103
                new_target_indices.append(normalized_idx)
×
2104

2105
        # Create assignment block: target[i,j,...] = result[i,j,...]
2106
        block = self.builder.add_block(debug_info)
4✔
2107

2108
        # Access nodes
2109
        t_src = self.builder.add_access(block, result_name, debug_info)
4✔
2110
        t_dst = self.builder.add_access(block, target_name, debug_info)
4✔
2111
        t_task = self.builder.add_tasklet(
4✔
2112
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
2113
        )
2114

2115
        # Source index - just use loop vars for flat array from ufunc outer
2116
        # The ufunc outer result is a flat array of size M*N
2117
        if len(loop_vars) == 2:
4✔
2118
            # 2D case: result is indexed as i * N + j
2119
            # Get the second dimension size from target shapes
2120
            n_dim = (
4✔
2121
                target_shapes[1]
2122
                if len(target_shapes) > 1
2123
                else f"_{target_name}_shape_1"
2124
            )
2125
            src_index = f"(({loop_vars[0]}) * ({n_dim}) + ({loop_vars[1]}))"
4✔
2126
        elif len(loop_vars) == 1:
×
2127
            src_index = loop_vars[0]
×
2128
        else:
2129
            # General case - compute linear index
2130
            src_terms = []
×
2131
            stride = "1"
×
2132
            for i in range(len(loop_vars) - 1, -1, -1):
×
2133
                if stride == "1":
×
2134
                    src_terms.insert(0, loop_vars[i])
×
2135
                else:
2136
                    src_terms.insert(0, f"({loop_vars[i]} * {stride})")
×
2137
                if i > 0:
×
2138
                    dim_size = (
×
2139
                        target_shapes[i]
2140
                        if i < len(target_shapes)
2141
                        else f"_{target_name}_shape_{i}"
2142
                    )
2143
                    stride = (
×
2144
                        f"({stride} * {dim_size})" if stride != "1" else str(dim_size)
2145
                    )
2146
            src_index = " + ".join(src_terms) if src_terms else "0"
×
2147

2148
        # Target index - compute linear index (row-major order)
2149
        # For 2D array with shape (M, N): linear_index = i * N + j
2150
        target_index_parts = []
4✔
2151
        for idx in new_target_indices:
4✔
2152
            if isinstance(idx, ast.Name):
4✔
2153
                target_index_parts.append(idx.id)
4✔
2154
            else:
NEW
2155
                target_index_parts.append(self.visit(idx))
×
2156

2157
        # Convert to linear index
2158
        if len(target_index_parts) == 2:
4✔
2159
            # 2D case
2160
            n_dim = (
4✔
2161
                target_shapes[1]
2162
                if len(target_shapes) > 1
2163
                else f"_{target_name}_shape_1"
2164
            )
2165
            target_index = (
4✔
2166
                f"(({target_index_parts[0]}) * ({n_dim}) + ({target_index_parts[1]}))"
2167
            )
2168
        elif len(target_index_parts) == 1:
×
2169
            target_index = target_index_parts[0]
×
2170
        else:
2171
            # General case - compute linear index with strides
2172
            stride = "1"
×
2173
            target_index = "0"
×
2174
            for i in range(len(target_index_parts) - 1, -1, -1):
×
2175
                idx_part = target_index_parts[i]
×
2176
                if stride == "1":
×
2177
                    term = idx_part
×
2178
                else:
2179
                    term = f"(({idx_part}) * ({stride}))"
×
2180

2181
                if target_index == "0":
×
2182
                    target_index = term
×
2183
                else:
2184
                    target_index = f"({term} + {target_index})"
×
2185

2186
                if i > 0:
×
2187
                    dim_size = (
×
2188
                        target_shapes[i]
2189
                        if i < len(target_shapes)
2190
                        else f"_{target_name}_shape_{i}"
2191
                    )
2192
                    stride = (
×
2193
                        f"({stride} * {dim_size})" if stride != "1" else str(dim_size)
2194
                    )
2195

2196
        # Connect memlets
2197
        self.builder.add_memlet(
4✔
2198
            block, t_src, "void", t_task, "_in", src_index, None, debug_info
2199
        )
2200
        self.builder.add_memlet(
4✔
2201
            block, t_task, "_out", t_dst, "void", target_index, None, debug_info
2202
        )
2203

2204
        # End loops
2205
        for _ in loop_vars:
4✔
2206
            self.builder.end_for()
4✔
2207

2208
    def _is_indirect_access(self, node):
4✔
2209
        """Check if a node represents an indirect array access (e.g., A[B[i]]).
2210

2211
        Returns True if the node is a subscript where the index itself is a subscript
2212
        into an array (indirect access pattern).
2213
        """
NEW
2214
        if not isinstance(node, ast.Subscript):
×
NEW
2215
            return False
×
NEW
2216
        if isinstance(node.value, ast.Name):
×
NEW
2217
            arr_name = node.value.id
×
NEW
2218
            if arr_name in self.array_info:
×
NEW
2219
                if isinstance(node.slice, ast.Subscript):
×
NEW
2220
                    if isinstance(node.slice.value, ast.Name):
×
NEW
2221
                        idx_arr_name = node.slice.value.id
×
NEW
2222
                        if idx_arr_name in self.array_info:
×
NEW
2223
                            return True
×
NEW
2224
        return False
×
2225

2226
    def _contains_indirect_access(self, node):
4✔
2227
        """Check if an AST node contains any indirect array access."""
2228
        if isinstance(node, ast.Subscript):
4✔
2229
            if isinstance(node.value, ast.Name):
4✔
2230
                arr_name = node.value.id
4✔
2231
                if arr_name in self.array_info:
4✔
2232
                    return True
4✔
2233
        elif isinstance(node, ast.BinOp):
4✔
2234
            return self._contains_indirect_access(
4✔
2235
                node.left
2236
            ) or self._contains_indirect_access(node.right)
2237
        elif isinstance(node, ast.UnaryOp):
4✔
2238
            return self._contains_indirect_access(node.operand)
4✔
2239
        return False
4✔
2240

2241
    def _materialize_indirect_access(
4✔
2242
        self, node, debug_info=None, return_original_expr=False
2243
    ):
2244
        """Materialize an array access into a scalar variable using tasklet+memlets."""
2245
        if not self.builder:
4✔
NEW
2246
            expr = self.visit(node)
×
NEW
2247
            return (expr, expr) if return_original_expr else expr
×
2248

2249
        if debug_info is None:
4✔
2250
            debug_info = DebugInfo()
4✔
2251

2252
        if not isinstance(node, ast.Subscript):
4✔
NEW
2253
            expr = self.visit(node)
×
NEW
2254
            return (expr, expr) if return_original_expr else expr
×
2255

2256
        if not isinstance(node.value, ast.Name):
4✔
NEW
2257
            expr = self.visit(node)
×
NEW
2258
            return (expr, expr) if return_original_expr else expr
×
2259

2260
        arr_name = node.value.id
4✔
2261
        if arr_name not in self.array_info:
4✔
NEW
2262
            expr = self.visit(node)
×
NEW
2263
            return (expr, expr) if return_original_expr else expr
×
2264

2265
        dtype = Scalar(PrimitiveType.Int64)
4✔
2266
        if arr_name in self.symbol_table:
4✔
2267
            t = self.symbol_table[arr_name]
4✔
2268
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
2269
                dtype = t.pointee_type
4✔
2270

2271
        tmp_name = self.builder.find_new_name("_idx_")
4✔
2272
        self.builder.add_container(tmp_name, dtype, False)
4✔
2273
        self.symbol_table[tmp_name] = dtype
4✔
2274

2275
        ndim = self.array_info[arr_name]["ndim"]
4✔
2276
        shapes = self.array_info[arr_name].get("shapes", [])
4✔
2277

2278
        if isinstance(node.slice, ast.Tuple):
4✔
NEW
2279
            indices = [self.visit(elt) for elt in node.slice.elts]
×
2280
        else:
2281
            indices = [self.visit(node.slice)]
4✔
2282

2283
        materialized_indices = []
4✔
2284
        for idx_str in indices:
4✔
2285
            if "(" in idx_str and idx_str.endswith(")"):
4✔
NEW
2286
                materialized_indices.append(idx_str)
×
2287
            else:
2288
                materialized_indices.append(idx_str)
4✔
2289

2290
        linear_index = self._compute_linear_index(
4✔
2291
            materialized_indices, shapes, arr_name, ndim
2292
        )
2293

2294
        block = self.builder.add_block(debug_info)
4✔
2295
        t_src = self.builder.add_access(block, arr_name, debug_info)
4✔
2296
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
2297
        t_task = self.builder.add_tasklet(
4✔
2298
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
2299
        )
2300

2301
        self.builder.add_memlet(
4✔
2302
            block, t_src, "void", t_task, "_in", linear_index, None, debug_info
2303
        )
2304
        self.builder.add_memlet(
4✔
2305
            block, t_task, "_out", t_dst, "void", "", None, debug_info
2306
        )
2307

2308
        if return_original_expr:
4✔
2309
            original_expr = f"{arr_name}({linear_index})"
4✔
2310
            return (tmp_name, original_expr)
4✔
2311

NEW
2312
        return tmp_name
×
2313

2314
    def _get_unique_id(self):
4✔
2315
        self._unique_counter_ref[0] += 1
4✔
2316
        return self._unique_counter_ref[0]
4✔
2317

2318
    def _element_type(self, name):
        """Return the element type associated with ``name``.

        Names present in ``self.symbol_table`` are resolved through
        ``element_type_from_sdfg_type``. Anything else is treated as a
        constant: ``Int64`` when ``_is_int`` accepts it, ``Double`` otherwise.
        """
        if name in self.symbol_table:
            return element_type_from_sdfg_type(self.symbol_table[name])

        # Not a declared symbol, so classify it as a literal.
        primitive = (
            PrimitiveType.Int64 if self._is_int(name) else PrimitiveType.Double
        )
        return Scalar(primitive)
4✔
2326

2327
    def _is_int(self, operand):
4✔
2328
        try:
4✔
2329
            if operand.lstrip("-").isdigit():
4✔
2330
                return True
4✔
NEW
2331
        except ValueError:
×
NEW
2332
            pass
×
2333

2334
        name = operand
4✔
2335
        if "(" in operand and operand.endswith(")"):
4✔
2336
            name = operand.split("(")[0]
4✔
2337

2338
        if name in self.symbol_table:
4✔
2339
            t = self.symbol_table[name]
4✔
2340

2341
            def is_int_ptype(pt):
4✔
2342
                return pt in [
4✔
2343
                    PrimitiveType.Int64,
2344
                    PrimitiveType.Int32,
2345
                    PrimitiveType.Int8,
2346
                    PrimitiveType.Int16,
2347
                    PrimitiveType.UInt64,
2348
                    PrimitiveType.UInt32,
2349
                    PrimitiveType.UInt8,
2350
                    PrimitiveType.UInt16,
2351
                ]
2352

2353
            if isinstance(t, Scalar):
4✔
2354
                return is_int_ptype(t.primitive_type)
4✔
2355

2356
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
NEW
2357
                et = t.element_type
×
NEW
2358
                if callable(et):
×
NEW
2359
                    et = et()
×
NEW
2360
                if isinstance(et, Scalar):
×
NEW
2361
                    return is_int_ptype(et.primitive_type)
×
2362

2363
            if type(t).__name__ == "Pointer":
4✔
2364
                if hasattr(t, "pointee_type"):
4✔
2365
                    et = t.pointee_type
4✔
2366
                    if callable(et):
4✔
NEW
2367
                        et = et()
×
2368
                    if isinstance(et, Scalar):
4✔
2369
                        return is_int_ptype(et.primitive_type)
4✔
NEW
2370
                if hasattr(t, "element_type"):
×
NEW
2371
                    et = t.element_type
×
NEW
2372
                    if callable(et):
×
NEW
2373
                        et = et()
×
NEW
2374
                    if isinstance(et, Scalar):
×
NEW
2375
                        return is_int_ptype(et.primitive_type)
×
2376

2377
        return False
4✔
2378

2379
    def _add_read(self, block, expr_str, debug_info=None):
4✔
2380
        try:
4✔
2381
            if (block, expr_str) in self._access_cache:
4✔
NEW
2382
                return self._access_cache[(block, expr_str)]
×
NEW
2383
        except TypeError:
×
NEW
2384
            pass
×
2385

2386
        if debug_info is None:
4✔
2387
            debug_info = DebugInfo()
4✔
2388

2389
        if "(" in expr_str and expr_str.endswith(")"):
4✔
2390
            name = expr_str.split("(")[0]
4✔
2391
            subset = expr_str[expr_str.find("(") + 1 : -1]
4✔
2392
            access = self.builder.add_access(block, name, debug_info)
4✔
2393
            try:
4✔
2394
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
NEW
2395
            except TypeError:
×
NEW
2396
                pass
×
2397
            return access, subset
4✔
2398

2399
        if self.builder.exists(expr_str):
4✔
2400
            access = self.builder.add_access(block, expr_str, debug_info)
4✔
2401
            subset = ""
4✔
2402
            if expr_str in self.symbol_table:
4✔
2403
                sym_type = self.symbol_table[expr_str]
4✔
2404
                if isinstance(sym_type, Pointer):
4✔
NEW
2405
                    if expr_str in self.array_info:
×
NEW
2406
                        ndim = self.array_info[expr_str].get("ndim", 0)
×
NEW
2407
                        if ndim == 0:
×
NEW
2408
                            subset = "0"
×
2409
                    else:
NEW
2410
                        subset = "0"
×
2411
            try:
4✔
2412
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
NEW
2413
            except TypeError:
×
NEW
2414
                pass
×
2415
            return access, subset
4✔
2416

2417
        dtype = Scalar(PrimitiveType.Double)
4✔
2418
        if self._is_int(expr_str):
4✔
2419
            dtype = Scalar(PrimitiveType.Int64)
4✔
2420
        elif expr_str == "true" or expr_str == "false":
4✔
NEW
2421
            dtype = Scalar(PrimitiveType.Bool)
×
2422

2423
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
4✔
2424
        try:
4✔
2425
            self._access_cache[(block, expr_str)] = (const_node, "")
4✔
NEW
2426
        except TypeError:
×
NEW
2427
            pass
×
2428
        return const_node, ""
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc