• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21478974613

29 Jan 2026 12:55PM UTC coverage: 65.778% (-0.07%) from 65.843%
21478974613

push

github

web-flow
Merge pull request #485 from daisytuner/npbench-cavity-flow

Adds support for npbench's cavity_flow

59 of 130 new or added lines in 6 files covered. (45.38%)

1 existing line in 1 file now uncovered.

22446 of 34124 relevant lines covered (65.78%)

382.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.52
/python/docc/expression_visitor.py
1
import ast
4✔
2
import inspect
4✔
3
import textwrap
4✔
4
from ._sdfg import (
4✔
5
    Scalar,
6
    PrimitiveType,
7
    Pointer,
8
    Type,
9
    DebugInfo,
10
    Structure,
11
    TaskletCode,
12
    CMathFunction,
13
)
14

15

16
class ExpressionVisitor(ast.NodeVisitor):
4✔
17
    def __init__(
        self,
        array_info=None,
        builder=None,
        symbol_table=None,
        globals_dict=None,
        inliner=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Set up the visitor's lookup tables and code-generation context.

        ``array_info``, ``symbol_table``, ``globals_dict`` and
        ``structure_member_info`` default to fresh empty dicts owned by this
        instance when passed as ``None``; ``builder`` and ``inliner`` are
        stored as given. ``unique_counter_ref`` is a one-element list holding
        the temp-name counter; it is mutated in place, so visitors sharing
        the same list share one counter. A new list ``[0]`` is used when
        omitted.
        """

        def _or_empty(mapping):
            return {} if mapping is None else mapping

        self.array_info = _or_empty(array_info)
        self.builder = builder
        self.symbol_table = _or_empty(symbol_table)
        self.globals_dict = _or_empty(globals_dict)
        self.inliner = inliner
        self._unique_counter_ref = (
            [0] if unique_counter_ref is None else unique_counter_ref
        )
        # Memo of (block, expr_str) -> (node, subset) used by _add_read.
        self._access_cache = {}
        self.la_handler = None
        self.structure_member_info = _or_empty(structure_member_info)
        self._init_numpy_handlers()
41

42
    def _get_unique_id(self):
4✔
43
        self._unique_counter_ref[0] += 1
4✔
44
        return self._unique_counter_ref[0]
4✔
45

46
    def _get_temp_name(self, prefix="_tmp_"):
4✔
47
        if hasattr(self.builder, "find_new_name"):
4✔
48
            return self.builder.find_new_name(prefix)
×
49
        return f"{prefix}{self._get_unique_id()}"
4✔
50

51
    def _is_indirect_access(self, node):
4✔
52
        """Check if a node represents an indirect array access (e.g., A[B[i]]).
53

54
        Returns True if the node is a subscript where the index itself is a subscript
55
        into an array (indirect access pattern).
56
        """
57
        if not isinstance(node, ast.Subscript):
×
58
            return False
×
59
        # Check if value is a subscripted array access
60
        if isinstance(node.value, ast.Name):
×
61
            arr_name = node.value.id
×
62
            if arr_name in self.array_info:
×
63
                # Check if slice/index is itself an array access
64
                if isinstance(node.slice, ast.Subscript):
×
65
                    if isinstance(node.slice.value, ast.Name):
×
66
                        idx_arr_name = node.slice.value.id
×
67
                        if idx_arr_name in self.array_info:
×
68
                            return True
×
69
        return False
×
70

71
    def _contains_indirect_access(self, node):
4✔
72
        """Check if an AST node contains any indirect array access.
73

74
        Used to detect expressions like A_row[i] that would be used as slice bounds.
75
        """
76
        if isinstance(node, ast.Subscript):
4✔
77
            if isinstance(node.value, ast.Name):
4✔
78
                arr_name = node.value.id
4✔
79
                if arr_name in self.array_info:
4✔
80
                    return True
4✔
81
        elif isinstance(node, ast.BinOp):
4✔
82
            return self._contains_indirect_access(
4✔
83
                node.left
84
            ) or self._contains_indirect_access(node.right)
85
        elif isinstance(node, ast.UnaryOp):
4✔
86
            return self._contains_indirect_access(node.operand)
4✔
87
        return False
4✔
88

89
    def _materialize_indirect_access(
4✔
90
        self, node, debug_info=None, return_original_expr=False
91
    ):
92
        """Materialize an array access into a scalar variable using tasklet+memlets.
93

94
        For indirect memory access patterns in SDFGs, we need to:
95
        1. Create a scalar container for the result
96
        2. Create a tasklet that performs the assignment
97
        3. Use memlets to read from the array and write to the scalar
98
        4. Return the scalar name (which can be used as a symbolic expression)
99

100
        This is the canonical SDFG pattern for indirect access.
101

102
        If return_original_expr is True, also returns the original array access
103
        expression using parentheses notation (e.g., "A_row(0)") which is consistent
104
        with SDFG subset notation. The runtime evaluator will convert this to
105
        bracket notation for Python evaluation.
106
        """
107
        if not self.builder:
4✔
108
            # Without builder, just return the expression string
109
            expr = self.visit(node)
×
110
            return (expr, expr) if return_original_expr else expr
×
111

112
        if debug_info is None:
4✔
113
            debug_info = DebugInfo()
4✔
114

115
        if not isinstance(node, ast.Subscript):
4✔
116
            expr = self.visit(node)
×
117
            return (expr, expr) if return_original_expr else expr
×
118

119
        if not isinstance(node.value, ast.Name):
4✔
120
            expr = self.visit(node)
×
121
            return (expr, expr) if return_original_expr else expr
×
122

123
        arr_name = node.value.id
4✔
124
        if arr_name not in self.array_info:
4✔
125
            expr = self.visit(node)
×
126
            return (expr, expr) if return_original_expr else expr
×
127

128
        # Determine the element type
129
        dtype = Scalar(PrimitiveType.Int64)  # Default for indices
4✔
130
        if arr_name in self.symbol_table:
4✔
131
            t = self.symbol_table[arr_name]
4✔
132
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
133
                dtype = t.pointee_type
4✔
134

135
        # Create scalar container for the result
136
        tmp_name = self._get_temp_name("_idx_")
4✔
137
        self.builder.add_container(tmp_name, dtype, False)
4✔
138
        self.symbol_table[tmp_name] = dtype
4✔
139

140
        # Get the index expression
141
        ndim = self.array_info[arr_name]["ndim"]
4✔
142
        shapes = self.array_info[arr_name].get("shapes", [])
4✔
143

144
        # Compute linear index from the subscript
145
        if isinstance(node.slice, ast.Tuple):
4✔
146
            indices = [self.visit(elt) for elt in node.slice.elts]
×
147
        else:
148
            indices = [self.visit(node.slice)]
4✔
149

150
        # Handle cases where we need recursive materialization
151
        materialized_indices = []
4✔
152
        for i, idx_str in enumerate(indices):
4✔
153
            # Check if the index itself needs materialization (nested indirect)
154
            # This happens when idx_str looks like an array access e.g., "arr(i)"
155
            if "(" in idx_str and idx_str.endswith(")"):
4✔
156
                # This is an array access, it should already be a valid symbolic expression
157
                # or a scalar variable name
158
                materialized_indices.append(idx_str)
×
159
            else:
160
                materialized_indices.append(idx_str)
4✔
161

162
        # Compute linear index
163
        linear_index = self._compute_linear_index(
4✔
164
            materialized_indices, shapes, arr_name, ndim
165
        )
166

167
        # Create block with tasklet and memlets
168
        block = self.builder.add_block(debug_info)
4✔
169
        t_src = self.builder.add_access(block, arr_name, debug_info)
4✔
170
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
171
        t_task = self.builder.add_tasklet(
4✔
172
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
173
        )
174

175
        self.builder.add_memlet(
4✔
176
            block, t_src, "void", t_task, "_in", linear_index, None, debug_info
177
        )
178
        self.builder.add_memlet(
4✔
179
            block, t_task, "_out", t_dst, "void", "", None, debug_info
180
        )
181

182
        if return_original_expr:
4✔
183
            # Return both the materialized variable name and the original array access expression
184
            # Use parentheses notation which is consistent with SDFG subset syntax
185
            original_expr = f"{arr_name}({linear_index})"
4✔
186
            return (tmp_name, original_expr)
4✔
187

188
        return tmp_name
×
189

190
    def _init_numpy_handlers(self):
        """Build ``self.numpy_handlers``: numpy function name -> handler.

        visit_Call looks up ``np.<name>`` / ``numpy.<name>`` calls here and
        invokes the handler as ``handler(node, name)``.
        """
        alloc = self._handle_numpy_alloc
        empty_like = self._handle_numpy_empty_like
        zeros_like = self._handle_numpy_zeros_like
        eye = self._handle_numpy_eye
        binary = self._handle_numpy_binary_op
        unary = self._handle_numpy_unary_op
        reduction = self._handle_numpy_reduce
        matmul = self._handle_numpy_matmul
        outer = self._handle_numpy_outer
        where = self._handle_numpy_where

        self.numpy_handlers = dict(
            # Allocation / construction (ndarray is the np.ndarray() constructor)
            empty=alloc,
            empty_like=empty_like,
            zeros=alloc,
            zeros_like=zeros_like,
            ones=alloc,
            ndarray=alloc,
            eye=eye,
            # Element-wise binary ufuncs
            add=binary,
            subtract=binary,
            multiply=binary,
            divide=binary,
            power=binary,
            # Element-wise unary ufuncs
            exp=unary,
            abs=unary,
            absolute=unary,
            sqrt=unary,
            tanh=unary,
            # Reductions
            sum=reduction,
            max=reduction,
            min=reduction,
            mean=reduction,
            std=reduction,
            # Linear algebra
            matmul=matmul,
            dot=matmul,
            matvec=matmul,
            outer=outer,
            # Element-wise extrema and selection
            minimum=binary,
            maximum=binary,
            where=where,
        )
222

223
    def generic_visit(self, node):
        # Fallback for node types without a dedicated visit_* method. It
        # defers to ast.NodeVisitor.generic_visit, which visits each child
        # node (so the specific visit_* handlers may still emit code) and
        # returns None. An unsupported expression therefore yields None
        # rather than an expression string.
        return super().generic_visit(node)
225

226
    def visit_Constant(self, node):
4✔
227
        if isinstance(node.value, bool):
4✔
228
            return "true" if node.value else "false"
×
229
        return str(node.value)
4✔
230

231
    def visit_Name(self, node):
4✔
232
        name = node.id
4✔
233
        # Check if it's a global constant (not a local variable/array)
234
        if name not in self.symbol_table and self.globals_dict is not None:
4✔
235
            if name in self.globals_dict:
4✔
236
                val = self.globals_dict[name]
4✔
237
                # Only substitute simple numeric constants
238
                if isinstance(val, (int, float)):
4✔
239
                    return str(val)
4✔
240
        return name
4✔
241

242
    def _map_numpy_dtype(self, dtype_node):
        """Map a ``dtype=`` AST node to an sdfg Scalar type.

        Recognized forms:

        * ``None`` (no dtype given) -> Double;
        * builtin names ``float`` / ``int`` / ``bool``;
        * ``arr.dtype`` for a Pointer-typed symbol with a pointee type ->
          that pointee type;
        * ``np.<name>`` / ``numpy.<name>`` for the numpy scalar names in
          ``numpy_types`` below;
        * string literals naming any of the above (e.g. ``"float32"``).

        Anything unrecognized falls back to Double.
        """
        if dtype_node is None:
            return Scalar(PrimitiveType.Double)

        builtin_types = {
            "float": PrimitiveType.Double,
            "int": PrimitiveType.Int64,
            "bool": PrimitiveType.Bool,
        }
        numpy_types = {
            "float64": PrimitiveType.Double,
            "double": PrimitiveType.Double,
            "float32": PrimitiveType.Float,
            "single": PrimitiveType.Float,
            "int64": PrimitiveType.Int64,
            "int32": PrimitiveType.Int32,
            "int16": PrimitiveType.Int16,
            "int8": PrimitiveType.Int8,
            "uint64": PrimitiveType.UInt64,
            "uint32": PrimitiveType.UInt32,
            "uint16": PrimitiveType.UInt16,
            "uint8": PrimitiveType.UInt8,
            "bool_": PrimitiveType.Bool,
            "bool": PrimitiveType.Bool,
        }

        if isinstance(dtype_node, ast.Name):
            if dtype_node.id in builtin_types:
                return Scalar(builtin_types[dtype_node.id])

        elif isinstance(dtype_node, ast.Attribute):
            base = dtype_node.value
            if isinstance(base, ast.Name):
                # Handle array.dtype
                if base.id in self.symbol_table and dtype_node.attr == "dtype":
                    sym_type = self.symbol_table[base.id]
                    if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                        return sym_type.pointee_type
                if base.id in ("numpy", "np") and dtype_node.attr in numpy_types:
                    return Scalar(numpy_types[dtype_node.attr])

        elif isinstance(dtype_node, ast.Constant) and isinstance(
            dtype_node.value, str
        ):
            # numpy also accepts dtype names as strings, e.g. dtype="float32".
            type_name = dtype_node.value
            if type_name in builtin_types:
                return Scalar(builtin_types[type_name])
            if type_name in numpy_types:
                return Scalar(numpy_types[type_name])

        # Fallback
        return Scalar(PrimitiveType.Double)
283

284
    def _is_int(self, operand):
4✔
285
        try:
4✔
286
            if operand.lstrip("-").isdigit():
4✔
287
                return True
4✔
288
        except ValueError:
×
289
            pass
×
290

291
        name = operand
4✔
292
        if "(" in operand and operand.endswith(")"):
4✔
293
            name = operand.split("(")[0]
4✔
294

295
        if name in self.symbol_table:
4✔
296
            t = self.symbol_table[name]
4✔
297

298
            def is_int_ptype(pt):
4✔
299
                return pt in [
4✔
300
                    PrimitiveType.Int64,
301
                    PrimitiveType.Int32,
302
                    PrimitiveType.Int8,
303
                    PrimitiveType.Int16,
304
                    PrimitiveType.UInt64,
305
                    PrimitiveType.UInt32,
306
                    PrimitiveType.UInt8,
307
                    PrimitiveType.UInt16,
308
                ]
309

310
            if isinstance(t, Scalar):
4✔
311
                return is_int_ptype(t.primitive_type)
4✔
312

313
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
314
                et = t.element_type
×
315
                if callable(et):
×
316
                    et = et()
×
317
                if isinstance(et, Scalar):
×
318
                    return is_int_ptype(et.primitive_type)
×
319

320
            if type(t).__name__ == "Pointer":
4✔
321
                if hasattr(t, "pointee_type"):
4✔
322
                    et = t.pointee_type
4✔
323
                    if callable(et):
4✔
324
                        et = et()
×
325
                    if isinstance(et, Scalar):
4✔
326
                        return is_int_ptype(et.primitive_type)
4✔
327
                # Fallback: check if it has element_type (maybe alias?)
328
                if hasattr(t, "element_type"):
×
329
                    et = t.element_type
×
330
                    if callable(et):
×
331
                        et = et()
×
332
                    if isinstance(et, Scalar):
×
333
                        return is_int_ptype(et.primitive_type)
×
334

335
        return False
4✔
336

337
    def _add_read(self, block, expr_str, debug_info=None):
4✔
338
        # Try to reuse access node
339
        try:
4✔
340
            if (block, expr_str) in self._access_cache:
4✔
341
                return self._access_cache[(block, expr_str)]
4✔
342
        except TypeError:
×
343
            # block might not be hashable
344
            pass
×
345

346
        if debug_info is None:
4✔
347
            debug_info = DebugInfo()
4✔
348

349
        if "(" in expr_str and expr_str.endswith(")"):
4✔
350
            name = expr_str.split("(")[0]
4✔
351
            subset = expr_str[expr_str.find("(") + 1 : -1]
4✔
352
            access = self.builder.add_access(block, name, debug_info)
4✔
353
            try:
4✔
354
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
355
            except TypeError:
×
356
                pass
×
357
            return access, subset
4✔
358

359
        if self.builder.exists(expr_str):
4✔
360
            access = self.builder.add_access(block, expr_str, debug_info)
4✔
361
            # For pointer types representing 0-D arrays, dereference with "0"
362
            subset = ""
4✔
363
            if expr_str in self.symbol_table:
4✔
364
                sym_type = self.symbol_table[expr_str]
4✔
365
                if isinstance(sym_type, Pointer):
4✔
366
                    # Check if it's a 0-D array (scalar wrapped in pointer)
367
                    if expr_str in self.array_info:
×
368
                        ndim = self.array_info[expr_str].get("ndim", 0)
×
369
                        if ndim == 0:
×
370
                            subset = "0"
×
371
                    else:
372
                        # Pointer without array_info is treated as 0-D
373
                        subset = "0"
×
374
            try:
4✔
375
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
376
            except TypeError:
×
377
                pass
×
378
            return access, subset
4✔
379

380
        dtype = Scalar(PrimitiveType.Double)
4✔
381
        if self._is_int(expr_str):
4✔
382
            dtype = Scalar(PrimitiveType.Int64)
4✔
383
        elif expr_str == "true" or expr_str == "false":
4✔
384
            dtype = Scalar(PrimitiveType.Bool)
×
385

386
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
4✔
387
        try:
4✔
388
            self._access_cache[(block, expr_str)] = (const_node, "")
4✔
389
        except TypeError:
×
390
            pass
×
391
        return const_node, ""
4✔
392

393
    def _handle_min_max(self, node, func_name):
        """Lower a two-argument builtin ``max(a, b)`` / ``min(a, b)`` call.

        If any argument is floating point (a Double or Float symbol, or a
        literal that is not an integer), the result is a Double produced by
        the cmath ``fmax``/``fmin`` intrinsic. Non-Double arguments are first
        copied into Double temporaries. Otherwise the result is an Int64
        produced by the ``int_smax``/``int_smin`` tasklet.

        Returns:
            Name of the new scalar container holding the result.

        Raises:
            NotImplementedError: if the call does not have exactly two
                arguments (checked after the arguments have been visited).
        """
        args = [self.visit(arg) for arg in node.args]
        if len(args) != 2:
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")

        float_ptypes = (PrimitiveType.Double, PrimitiveType.Float)

        # Classify each argument: its float primitive type, or Int64 for
        # every non-floating-point operand.
        arg_types = []
        for arg in args:
            name = arg.split("(")[0] if "(" in arg and arg.endswith(")") else arg
            if name in self.symbol_table:
                t = self.symbol_table[name]
                # Arrays are pointers: classify by the element type.
                if isinstance(t, Pointer) and t.has_pointee_type():
                    t = t.pointee_type
                ptype = t.primitive_type
                arg_types.append(
                    ptype if ptype in float_ptypes else PrimitiveType.Int64
                )
            elif self._is_int(arg):
                arg_types.append(PrimitiveType.Int64)
            else:
                # Assume float constant
                arg_types.append(PrimitiveType.Double)

        is_float = any(ptype in float_ptypes for ptype in arg_types)
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)

        tmp_name = self._get_temp_name("_tmp_")
        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        if is_float:
            # Copy every non-Double operand into a Double temporary (implicit
            # conversion on assignment) so all fmax/fmin inputs are Double.
            inputs = []
            for arg, ptype in zip(args, arg_types):
                if ptype == PrimitiveType.Double:
                    inputs.append(arg)
                    continue
                tmp_cast = self._get_temp_name("_cast_")
                self.builder.add_container(
                    tmp_cast, Scalar(PrimitiveType.Double), False
                )
                self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
                self.builder.add_assignment(tmp_cast, arg)
                inputs.append(tmp_cast)

            block = self.builder.add_block()
            t_out = self.builder.add_access(block, tmp_name)
            intrinsic = CMathFunction.fmax if func_name == "max" else CMathFunction.fmin
            t_task = self.builder.add_cmath(block, intrinsic)
        else:
            inputs = args
            block = self.builder.add_block()
            t_out = self.builder.add_access(block, tmp_name)
            opcode = TaskletCode.int_smax if func_name == "max" else TaskletCode.int_smin
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])

        # Wire the operands to _in1/_in2 and the result to tmp_name.
        for position, arg in enumerate(inputs, start=1):
            t_arg, arg_sub = self._add_read(block, arg)
            self.builder.add_memlet(
                block, t_arg, "void", t_task, f"_in{position}", arg_sub
            )

        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
        return tmp_name
481

482
    def _handle_python_cast(self, node, func_name):
        """Handle Python type casts: int(), float(), bool().

        The visited argument is copied by an ``assign`` tasklet into a fresh
        scalar of the target type (Int64, Double or Bool). The cast relies on
        that typed assignment to perform the conversion, so the argument's
        own type is not inspected.

        Returns:
            Name of the new scalar container holding the converted value.

        Raises:
            NotImplementedError: unless exactly one argument is given, or if
                ``func_name`` is not ``int``, ``float`` or ``bool``.
        """
        if len(node.args) != 1:
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")

        arg = self.visit(node.args[0])

        target_ptypes = {
            "int": PrimitiveType.Int64,
            "float": PrimitiveType.Double,
            "bool": PrimitiveType.Bool,
        }
        if func_name not in target_ptypes:
            raise NotImplementedError(f"Cast to {func_name} not supported")
        target_dtype = Scalar(target_ptypes[func_name])

        # Create temporary variable for result
        tmp_name = self._get_temp_name("_tmp_")
        self.builder.add_container(tmp_name, target_dtype, False)
        self.symbol_table[tmp_name] = target_dtype

        # arg --(assign tasklet)--> tmp_name
        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, arg)
        t_dst = self.builder.add_access(block, tmp_name)
        t_task = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
531

532
    def visit_Call(self, node):
        """Lower a call expression; return the expression string of its result.

        Dispatch order (first match wins):

        1. ``arr.astype(...)`` / ``arr.copy()`` where ``arr`` is in
           ``self.array_info``.
        2. ``scipy.special.softmax(...)``.
        3. ``np.<ufunc>.outer(...)`` / ``numpy.<ufunc>.outer(...)``.
        4. ``np.<name>(...)`` / ``numpy.<name>(...)`` with an entry in
           ``self.numpy_handlers``.
        5. Calls named ``max`` / ``min`` via ``_handle_min_max``.
        6. ``int`` / ``float`` / ``bool`` casts via ``_handle_python_cast``.
        7. Names in the local ``math_funcs`` table, whether bare, ``math.``
           prefixed, or an ``np.`` name without a numpy handler. These emit a
           cmath node writing a new Double temporary, whose name is returned.
        8. Plain Python functions found in ``globals_dict``, inlined via
           ``_handle_inline_call``.

        Raises:
            NotImplementedError: for any other call. Method calls not covered
                by (1)-(3) never set ``func_name``, so the message then
                contains an empty function name.
        """
        func_name = ""
        module_name = ""
        if isinstance(node.func, ast.Attribute):
            if isinstance(node.func.value, ast.Name):
                if node.func.value.id == "math":
                    module_name = "math"
                    func_name = node.func.attr
                elif node.func.value.id in ["numpy", "np"]:
                    module_name = "numpy"
                    func_name = node.func.attr
                else:
                    # Check if it's a method call on an array (e.g., arr.astype(...), arr.copy())
                    array_name = node.func.value.id
                    method_name = node.func.attr
                    if array_name in self.array_info and method_name == "astype":
                        return self._handle_numpy_astype(node, array_name)
                    elif array_name in self.array_info and method_name == "copy":
                        return self._handle_numpy_copy(node, array_name)
            elif isinstance(node.func.value, ast.Attribute):
                if (
                    isinstance(node.func.value.value, ast.Name)
                    and node.func.value.value.id == "scipy"
                    and node.func.value.attr == "special"
                ):
                    if node.func.attr == "softmax":
                        return self._handle_scipy_softmax(node, "softmax")
                # Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.
                elif (
                    isinstance(node.func.value.value, ast.Name)
                    and node.func.value.value.id in ["numpy", "np"]
                    and node.func.attr == "outer"
                ):
                    ufunc_name = node.func.value.attr  # "add", "subtract", etc.
                    return self._handle_ufunc_outer(node, ufunc_name)

        elif isinstance(node.func, ast.Name):
            func_name = node.func.id

        if module_name == "numpy":
            if func_name in self.numpy_handlers:
                return self.numpy_handlers[func_name](node, func_name)

        # Checked before the math table: np.max/np.min were already routed to
        # the numpy reduce handler above, so this catches builtin max/min.
        if func_name in ["max", "min"]:
            return self._handle_min_max(node, func_name)

        # Handle Python type casts (int, float, bool)
        if func_name in ["int", "float", "bool"]:
            return self._handle_python_cast(node, func_name)

        # Name -> cmath intrinsic. The table is matched on the bare function
        # name regardless of module, and it takes precedence over same-named
        # functions in globals_dict. NOTE(review): it is rebuilt on every
        # call; it could be hoisted.
        math_funcs = {
            # Trigonometric functions
            "sin": CMathFunction.sin,
            "cos": CMathFunction.cos,
            "tan": CMathFunction.tan,
            "asin": CMathFunction.asin,
            "acos": CMathFunction.acos,
            "atan": CMathFunction.atan,
            "atan2": CMathFunction.atan2,
            # Hyperbolic functions
            "sinh": CMathFunction.sinh,
            "cosh": CMathFunction.cosh,
            "tanh": CMathFunction.tanh,
            "asinh": CMathFunction.asinh,
            "acosh": CMathFunction.acosh,
            "atanh": CMathFunction.atanh,
            # Exponential and logarithmic functions
            "exp": CMathFunction.exp,
            "exp2": CMathFunction.exp2,
            "expm1": CMathFunction.expm1,
            "log": CMathFunction.log,
            "log2": CMathFunction.log2,
            "log10": CMathFunction.log10,
            "log1p": CMathFunction.log1p,
            # Power functions
            "pow": CMathFunction.pow,
            "sqrt": CMathFunction.sqrt,
            "cbrt": CMathFunction.cbrt,
            "hypot": CMathFunction.hypot,
            # Rounding and remainder functions
            # (builtin abs() maps to fabs, so its result is always Double)
            "abs": CMathFunction.fabs,
            "fabs": CMathFunction.fabs,
            "ceil": CMathFunction.ceil,
            "floor": CMathFunction.floor,
            "trunc": CMathFunction.trunc,
            "fmod": CMathFunction.fmod,
            "remainder": CMathFunction.remainder,
            # Floating-point manipulation functions
            "copysign": CMathFunction.copysign,
            # Other functions
            "fma": CMathFunction.fma,
        }

        if func_name in math_funcs:
            args = [self.visit(arg) for arg in node.args]

            # The result of every cmath call is stored in a Double temporary.
            tmp_name = self._get_temp_name("_tmp_")
            dtype = Scalar(PrimitiveType.Double)
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype

            block = self.builder.add_block()
            t_out = self.builder.add_access(block, tmp_name)

            t_task = self.builder.add_cmath(block, math_funcs[func_name])

            # Arguments feed the intrinsic's inputs _in1, _in2, ... in order.
            for i, arg in enumerate(args):
                t_arg, arg_sub = self._add_read(block, arg)
                self.builder.add_memlet(
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
                )

            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
            return tmp_name

        if func_name in self.globals_dict:
            obj = self.globals_dict[func_name]
            if inspect.isfunction(obj):
                return self._handle_inline_call(node, obj)

        raise NotImplementedError(f"Function call {func_name} not supported")
653

654
    def _handle_inline_call(self, node, func_obj):
4✔
655
        # 1. Parse function source
656
        try:
4✔
657
            source_lines, start_line = inspect.getsourcelines(func_obj)
4✔
658
            source = textwrap.dedent("".join(source_lines))
4✔
659
            tree = ast.parse(source)
4✔
660
            func_def = tree.body[0]
4✔
661
        except Exception as e:
×
662
            raise NotImplementedError(
×
663
                f"Could not parse function {func_obj.__name__}: {e}"
664
            )
665

666
        # 2. Evaluate arguments
667
        arg_vars = [self.visit(arg) for arg in node.args]
4✔
668

669
        if len(arg_vars) != len(func_def.args.args):
4✔
670
            raise NotImplementedError(
×
671
                f"Argument count mismatch for {func_obj.__name__}"
672
            )
673

674
        # 3. Generate unique suffix
675
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
4✔
676
        res_name = f"_res{suffix}"
4✔
677

678
        # Assume Int64 for now as match returns 0/1
679
        dtype = Scalar(PrimitiveType.Int64)
4✔
680
        self.builder.add_container(res_name, dtype, False)
4✔
681
        self.symbol_table[res_name] = dtype
4✔
682

683
        # 4. Rename variables
684
        class VariableRenamer(ast.NodeTransformer):
4✔
685
            # Builtins that should not be renamed
686
            BUILTINS = {
4✔
687
                "range",
688
                "len",
689
                "int",
690
                "float",
691
                "bool",
692
                "str",
693
                "list",
694
                "dict",
695
                "tuple",
696
                "set",
697
                "print",
698
                "abs",
699
                "min",
700
                "max",
701
                "sum",
702
                "enumerate",
703
                "zip",
704
                "map",
705
                "filter",
706
                "sorted",
707
                "reversed",
708
                "True",
709
                "False",
710
                "None",
711
            }
712

713
            def __init__(self, suffix, globals_dict):
4✔
714
                self.suffix = suffix
4✔
715
                self.globals_dict = globals_dict
4✔
716

717
            def visit_Name(self, node):
4✔
718
                # Don't rename builtins or globals
719
                if node.id in self.globals_dict or node.id in self.BUILTINS:
4✔
720
                    return node
4✔
721
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
4✔
722

723
            def visit_Return(self, node):
4✔
724
                if node.value:
4✔
725
                    val = self.visit(node.value)
4✔
726
                    return ast.Assign(
4✔
727
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
728
                        value=val,
729
                    )
730
                return node
×
731

732
        renamer = VariableRenamer(suffix, self.globals_dict)
4✔
733
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
4✔
734

735
        # 5. Assign arguments to parameters
736
        param_assignments = []
4✔
737
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
4✔
738
            param_name = f"{arg_def.arg}{suffix}"
4✔
739

740
            # Infer type and create container
741
            if arg_val in self.symbol_table:
4✔
742
                self.symbol_table[param_name] = self.symbol_table[arg_val]
4✔
743
                self.builder.add_container(
4✔
744
                    param_name, self.symbol_table[arg_val], False
745
                )
746
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
4✔
747
            elif self._is_int(arg_val):
×
748
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
749
                self.builder.add_container(
×
750
                    param_name, Scalar(PrimitiveType.Int64), False
751
                )
752
                val_node = ast.Constant(value=int(arg_val))
×
753
            else:
754
                # Assume float constant
755
                try:
×
756
                    val = float(arg_val)
×
757
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
758
                    self.builder.add_container(
×
759
                        param_name, Scalar(PrimitiveType.Double), False
760
                    )
761
                    val_node = ast.Constant(value=val)
×
762
                except ValueError:
×
763
                    # Fallback to Name, might fail later if not in symbol table
764
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
765

766
            assign = ast.Assign(
4✔
767
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
768
            )
769
            param_assignments.append(assign)
4✔
770

771
        final_body = param_assignments + new_body
4✔
772

773
        # 6. Visit new body using ASTParser
774
        from .ast_parser import ASTParser
4✔
775

776
        parser = ASTParser(
4✔
777
            self.builder,
778
            self.array_info,
779
            self.symbol_table,
780
            globals_dict=self.globals_dict,
781
            unique_counter_ref=self._unique_counter_ref,
782
        )
783

784
        for stmt in final_body:
4✔
785
            parser.visit(stmt)
4✔
786

787
        return res_name
4✔
788

789
    def visit_BinOp(self, node):
        """Lower a binary arithmetic/bitwise expression and return the result name.

        ``@`` is delegated to ``_handle_numpy_matmul_op`` and any operation
        with an array operand to ``_handle_array_binary_op``.  For scalars a
        fresh temporary is created:
        - ``Int64`` when both operands are integers and the operator is not
          ``/`` or ``**``, since those produce floats in Python;
        - ``Double`` otherwise.  Integer operands of a ``Double`` operation are
          first copied into ``Double`` temporaries.

        ``%`` and ``//`` follow Python's floored semantics: the remainder takes
        the sign of the divisor.  The underlying C-style ``rem``/``div``
        tasklets truncate instead.

        Returns:
            Name of the container holding the result.

        Raises:
            NotImplementedError: if the operator is not supported for the
                operand kind (array, integer or floating point).
        """
        if isinstance(node.op, ast.MatMult):
            return self._handle_numpy_matmul_op(node.left, node.right)

        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)

        if left in self.array_info or right in self.array_info:
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
            if op not in op_map:
                raise NotImplementedError(f"Array operation {op} not supported")
            return self._handle_array_binary_op(op_map[op], left, right)

        left_is_int = self._is_int(left)
        right_is_int = self._is_int(right)
        is_int = left_is_int and right_is_int and op not in ("/", "**")

        # Operators that map onto a single two-input tasklet.
        if is_int:
            dtype = Scalar(PrimitiveType.Int64)
            simple_ops = {
                "+": TaskletCode.int_add,
                "-": TaskletCode.int_sub,
                "*": TaskletCode.int_mul,
                "|": TaskletCode.int_or,
                "^": TaskletCode.int_xor,
            }
        else:
            dtype = Scalar(PrimitiveType.Double)
            simple_ops = {
                "+": TaskletCode.fp_add,
                "-": TaskletCode.fp_sub,
                "*": TaskletCode.fp_mul,
                "/": TaskletCode.fp_div,
            }

        # Reject unsupported operators before emitting any container or block.
        # Otherwise an unknown integer operator would reach add_tasklet with no
        # tasklet code.
        if op not in simple_ops and op not in ("**", "%", "//"):
            kind = "integers" if is_int else "floats"
            raise NotImplementedError(f"Operation {op} not supported for {kind}")

        tmp_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        real_left, real_right = left, right
        if not is_int:
            if left_is_int:
                real_left = self._binop_cast_to_double(left)
            if right_is_int:
                real_right = self._binop_cast_to_double(right)

        block = self.builder.add_block()
        t_left, left_sub = self._add_read(block, real_left)
        t_right, right_sub = self._add_read(block, real_right)
        t_out = self.builder.add_access(block, tmp_name)
        lhs = (t_left, left_sub)
        rhs = (t_right, right_sub)

        if op == "**":
            t_task = self.builder.add_cmath(block, CMathFunction.pow)
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
        elif op == "%":
            self._binop_emit_py_mod(block, dtype, lhs, rhs, t_out)
        elif op == "//":
            self._binop_emit_floordiv(block, dtype, lhs, rhs, t_out)
        else:
            self._binop_emit_tasklet(block, simple_ops[op], dtype, lhs, rhs, t_out)

        return tmp_name

    def _binop_cast_to_double(self, name):
        """Copy the integer operand ``name`` into a new ``Double`` temporary.

        The copy is done in its own block by an ``assign`` tasklet.

        Returns:
            Name of the new ``Double`` temporary.
        """
        cast_name = f"_tmp_{self._get_unique_id()}"
        double = Scalar(PrimitiveType.Double)
        self.builder.add_container(cast_name, double, False)
        self.symbol_table[cast_name] = double

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, name)
        t_dst = self.builder.add_access(block, cast_name)
        t_task = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        return cast_name

    def _binop_emit_tasklet(self, block, code, dtype, lhs, rhs, out=None):
        """Add the two-input tasklet ``code`` to ``block``, reading ``lhs`` and ``rhs``.

        ``lhs`` and ``rhs`` are ``(access_node, subset)`` pairs.  The result
        is written to the access node ``out``.  If ``out`` is None, it is
        written to a new container of type ``dtype``, which is not entered
        in the symbol table.

        Returns:
            The access node that receives the result.
        """
        t_task = self.builder.add_tasklet(block, code, ["_in1", "_in2"], ["_out"])
        self.builder.add_memlet(block, lhs[0], "void", t_task, "_in1", lhs[1])
        self.builder.add_memlet(block, rhs[0], "void", t_task, "_in2", rhs[1])
        if out is None:
            tmp = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp, dtype, False)
            out = self.builder.add_access(block, tmp)
        self.builder.add_memlet(block, t_task, "_out", out, "void", "")
        return out

    def _binop_emit_py_mod(self, block, dtype, lhs, rhs, out=None):
        """Emit Python's ``a % b`` as ``((a rem b) + b) rem b``.

        The truncating remainder (``int_srem`` / ``fp_rem``) has the sign of
        the dividend.  Adding ``b`` and reducing again yields the sign of the
        divisor, as Python does.

        Returns:
            The access node that receives the result.
        """
        if dtype.primitive_type == PrimitiveType.Int64:
            rem, add = TaskletCode.int_srem, TaskletCode.int_add
        else:
            rem, add = TaskletCode.fp_rem, TaskletCode.fp_add
        t_rem = self._binop_emit_tasklet(block, rem, dtype, lhs, rhs)
        t_sum = self._binop_emit_tasklet(block, add, dtype, (t_rem, ""), rhs)
        return self._binop_emit_tasklet(block, rem, dtype, (t_sum, ""), rhs, out)

    def _binop_emit_floordiv(self, block, dtype, lhs, rhs, out):
        """Emit Python's ``a // b`` as ``(a - (a % b)) / b``.

        With Python's modulo, ``a - (a % b)`` is an exact multiple of ``b``.
        For integers, the truncating ``int_sdiv`` of it is therefore exactly
        ``floor(a / b)``, also for mixed signs, where plain ``int_sdiv``
        would round toward zero.

        For Double operands the quotient is computed in floating point and
        can be off from the exact integral value by rounding error.
        NOTE(review): consider snapping the result with a floor/round cmath
        call if exact float floor division is required.

        Returns:
            The access node that receives the result.
        """
        if dtype.primitive_type == PrimitiveType.Int64:
            sub, div = TaskletCode.int_sub, TaskletCode.int_sdiv
        else:
            sub, div = TaskletCode.fp_sub, TaskletCode.fp_div
        t_mod = self._binop_emit_py_mod(block, dtype, lhs, rhs)
        t_diff = self._binop_emit_tasklet(block, sub, dtype, lhs, (t_mod, ""))
        return self._binop_emit_tasklet(block, div, dtype, (t_diff, ""), rhs, out)
1011

1012
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit a new block that stores the literal ``value_str`` into ``target_name``.

        The literal is created as a constant node of type ``dtype``.  It is
        copied into the target's access node by an ``assign`` tasklet.
        """
        builder = self.builder
        block = builder.add_block()
        source = builder.add_constant(block, value_str, dtype)
        target = builder.add_access(block, target_name)
        copy = builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
        edges = (
            (source, "void", copy, "_in", ""),
            (copy, "_out", target, "void", ""),
        )
        for edge in edges:
            builder.add_memlet(block, *edge)
1019

1020
    def visit_BoolOp(self, node):
        """Lower ``a and b`` / ``a or b`` into a new ``Bool`` temporary.

        Every operand is visited in source order before any branching.
        There is no short-circuiting: code for all operands is emitted.
        Each operand is tested with ``!= 0``, and the tests are joined with
        the visited operator into one condition.  An if/else then stores
        ``true`` or ``false`` into the temporary.

        Returns:
            Name of the ``Bool`` temporary.  This differs from Python, where
            ``and``/``or`` return one of the operands.
        """
        joiner = f" {self.visit(node.op)} "
        clauses = []
        for operand in node.values:
            clauses.append(f"({self.visit(operand)} != 0)")
        condition = joiner.join(clauses)

        result = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        # Materialise the condition through control flow.
        self.builder.begin_if(condition)
        self._add_assign_constant(result, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result, "false", bool_type)
        self.builder.end_if()

        self.symbol_table[result] = bool_type
        return result
1038

1039
    def visit_Compare(self, node):
        """Lower a single (non-chained) comparison.

        If either side is an array, the work is delegated to
        ``_handle_array_compare``, which returns a boolean array.  Otherwise
        a new ``Bool`` temporary is set to ``true``/``false`` through an
        if/else on the textual condition ``"<left> <op> <right>"``.

        Returns:
            Name of the result container.

        Raises:
            NotImplementedError: for chained comparisons such as ``a < b < c``.
                Note that the left operand has already been visited by then.
        """
        lhs = self.visit(node.left)
        if len(node.ops) > 1:
            raise NotImplementedError("Chained comparisons not supported yet")

        operator = self.visit(node.ops[0])
        rhs = self.visit(node.comparators[0])

        lhs_array = lhs in self.array_info
        rhs_array = rhs in self.array_info
        if lhs_array or rhs_array:
            return self._handle_array_compare(
                lhs, operator, rhs, lhs_array, rhs_array
            )

        result = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        builder = self.builder
        builder.begin_if(f"{lhs} {operator} {rhs}")
        builder.add_transition(result, "true")
        builder.begin_else()
        builder.add_transition(result, "false")
        builder.end_if()

        self.symbol_table[result] = bool_type
        return result
1073

1074
    def visit_UnaryOp(self, node):
        """Lower a unary operator (``-x``, ``+x``, ``not x``, ``~x``).

        Special cases:
        - A negated numeric literal is folded into a literal string such as
          ``"-3"``.
        - Negation of an array operand is delegated to
          ``_handle_array_negate``.

        Otherwise a new temporary is created and one block computes the
        result into it:

        - ``not x``: a ``Bool`` result equal to ``x == 0`` for integer and
          ``Double`` operands.  For ``Bool`` operands, and operands whose
          type is unknown, the operand is XOR-ed with ``true``.
        - ``-x``: ``0 - x`` for ``Int32``/``Int64`` operands, ``fp_neg``
          otherwise.
        - ``~x``: ``x ^ -1`` for integers (two's complement, i.e.
          ``-(x + 1)``).  For ``Bool`` operands it is ``x ^ true``, following
          NumPy's bool semantics.
        - ``+x``: a plain copy.

        Returns:
            Name of the temporary, or the folded literal string.

        Raises:
            NotImplementedError: for ``~`` on an operand that is neither an
                integer nor ``Bool``.
        """
        if (
            isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float))
        ):
            return f"-{node.operand.value}"

        op = self.visit(node.op)
        operand = self.visit(node.operand)

        # Check if operand is an array - handle as array operation
        if operand in self.array_info and op == "-":
            return self._handle_array_negate(operand)

        # Element type of the operand; pointers are read through their pointee.
        src_type = Scalar(PrimitiveType.Double)
        if operand in self.symbol_table:
            src_type = self.symbol_table[operand]
            if isinstance(src_type, Pointer) and src_type.has_pointee_type():
                src_type = src_type.pointee_type
        elif self._is_int(operand):
            src_type = Scalar(PrimitiveType.Int64)
        elif isinstance(node.op, ast.Not):
            src_type = Scalar(PrimitiveType.Bool)

        src_prim = getattr(src_type, "primitive_type", None)
        is_int = src_prim in (PrimitiveType.Int32, PrimitiveType.Int64)
        is_double = src_prim == PrimitiveType.Double
        is_not = isinstance(node.op, ast.Not)
        is_invert = isinstance(node.op, ast.Invert)

        if is_invert and not (is_int or src_prim == PrimitiveType.Bool):
            raise NotImplementedError(
                f"Operator {op} not supported for operand '{operand}'"
            )

        # For integer/double operands, `not x` is `x == 0` with a Bool
        # result.  XOR-ing with `true` is only a logical not for Bool
        # operands; for an Int64 operand it would turn `not 2` into 3.
        logical_not = is_not and (is_int or is_double)
        dtype = Scalar(PrimitiveType.Bool) if logical_not else src_type

        tmp_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)

        if logical_not:
            if is_int:
                zero, code = "0", TaskletCode.int_eq
            else:
                zero, code = "0.0", TaskletCode.fp_oeq
            t_const = self.builder.add_constant(block, zero, src_type)
            t_task = self.builder.add_tasklet(
                block, code, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif is_not or is_invert:
            # Bool (or unknown-type `not`) operand: flip with XOR `true`.
            # Integer `~x`: XOR with all ones (-1).
            if is_invert and is_int:
                t_const = self.builder.add_constant(block, "-1", dtype)
            else:
                t_const = self.builder.add_constant(
                    block, "true", Scalar(PrimitiveType.Bool)
                )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif op == "-":
            if is_int:
                # Integer negation as 0 - x (covers Int32 as well as Int64).
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_sub, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
            else:
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.fp_neg, ["_in"], ["_out"]
                )
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        else:
            # Unary plus: copy the operand unchanged.
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
1142

1143
    def _handle_array_negate(self, operand):
        """Return a new array temporary holding ``-operand``.

        A scalar zero of the array's element type is materialised first
        (``"0.0"`` for ``Double``, ``"0"`` otherwise).  The result is then
        produced by the builder's elementwise ``sub`` as ``0 - operand``
        over the operand's shape.
        """
        shape = self.array_info[operand]["shapes"]
        elem_type = self._get_dtype(operand)
        result = self._create_array_temp(shape, elem_type)

        zero = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(zero, elem_type, False)
        self.symbol_table[zero] = elem_type

        if elem_type.primitive_type == PrimitiveType.Double:
            literal = "0.0"
        else:
            literal = "0"
        self._add_assign_constant(zero, literal, elem_type)

        self.builder.add_elementwise_op("sub", zero, operand, result, shape)
        return result
1174

1175
    def _handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Handle elementwise comparison of arrays, returning a boolean array.

        Supports: arr > 0, arr < scalar, arr1 > arr2, etc.

        The output is a new ``Bool`` array with the shape of the array
        operand (the left one if both are arrays).  It is filled by a loop
        nest over that shape.  Arrays with ``Int32``/``Int64`` elements use
        signed integer comparisons.  All other arrays use ordered
        floating-point comparisons; in that case an integer scalar operand
        is first converted into a ``Double`` temporary.

        NOTE(review): when both operands are arrays, both are indexed with
        the output's linear index, so they are assumed to have the same
        shape and layout. No broadcasting is performed.

        Args:
            left, right: operand names (array containers, scalars or literals).
            op: comparison operator string (``">"``, ``"<="``, ``"=="``, ...).
            left_is_array, right_is_array: whether each operand is an array.

        Returns:
            Name of the boolean result array.

        Raises:
            NotImplementedError: for an unsupported comparison operator.
        """
        arr_name = left if left_is_array else right
        shape = self.array_info[arr_name]["shapes"]

        # Integer element types compare as signed ints, everything else as floats.
        arr_dtype = self._get_dtype(arr_name)
        use_int_cmp = arr_dtype.primitive_type in (
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )

        # Create output boolean array
        dtype = Scalar(PrimitiveType.Bool)
        tmp_name = self._create_array_temp(shape, dtype)

        if use_int_cmp:
            cmp_ops = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            # Floating point ordered comparisons
            cmp_ops = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in cmp_ops:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )
        tasklet_code = cmp_ops[op]

        scalar_name = None
        if not left_is_array:
            scalar_name = left
        elif not right_is_array:
            scalar_name = right

        if scalar_name is not None and not use_int_cmp and self._is_int(scalar_name):
            # Convert the integer scalar to Double so both tasklet inputs are
            # floating point.  The value is read through _add_read and copied
            # with an assign tasklet, as visit_BinOp does for its casts.  This
            # works for integer variables as well as literals, whereas a
            # "<name>.0" constant is only valid for literals.
            double = Scalar(PrimitiveType.Double)
            float_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(float_name, double, False)
            self.symbol_table[float_name] = double

            block_conv = self.builder.add_block()
            t_src, src_sub = self._add_read(block_conv, scalar_name)
            t_float = self.builder.add_access(block_conv, float_name)
            t_assign = self.builder.add_tasklet(
                block_conv, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_conv, t_src, "void", t_assign, "_in", src_sub
            )
            self.builder.add_memlet(
                block_conv, t_assign, "_out", t_float, "void", ""
            )

            # Replace the scalar name with the converted float
            if not left_is_array:
                left = float_name
            else:
                right = float_name

        # Generate nested loops
        loop_vars = []
        for i, dim in enumerate(shape):
            loop_var = f"_cmp_i{i}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(dim), "1")

        # Compute linear index for array access
        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))

        # Create comparison block
        block = self.builder.add_block()

        if left_is_array:
            t_left = self.builder.add_access(block, left)
            left_sub = linear_idx
        else:
            t_left, left_sub = self._add_read(block, left)

        if right_is_array:
            t_right = self.builder.add_access(block, right)
            right_sub = linear_idx
        else:
            t_right, right_sub = self._add_read(block, right)

        t_out = self.builder.add_access(block, tmp_name)
        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", linear_idx)

        # Close loops
        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
1314

1315
    def _parse_array_arg(self, node, simple_visitor):
        """Resolve an array argument (``A`` or ``A[i, lo:hi, ...]``) to a flat view.

        Args:
            node: AST node of the argument.
            simple_visitor: visitor used to turn index and bound expressions
                into expression strings. Presumably it returns strings such
                as ``"-1"`` for ``-1``; verify against callers.

        Returns:
            A tuple ``(name, indices, shape)``:
            - ``name`` is the array name.
            - ``indices`` is ``[]`` for a bare array name.  Otherwise it is a
              one-element list holding the row-major linear offset of the
              first selected element.
            - ``shape`` lists the extents of the sliced dimensions;
              integer-indexed dimensions are dropped.

            ``(None, None, None)`` is returned if ``node`` is not a
            recognised array reference.  Negative integer literals in
            indices or slice bounds count from the end of the dimension,
            as in Python.

        Raises:
            IndexError: if more indices than dimensions are given.
            NotImplementedError: for slices whose step is not 1.
        """
        if isinstance(node, ast.Name):
            if node.id in self.array_info:
                return node.id, [], self.array_info[node.id]["shapes"]
            return None, None, None

        if not (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            and node.value.id in self.array_info
        ):
            return None, None, None

        name = node.value.id
        ndim = self.array_info[name]["ndim"]
        shapes = self.array_info[name]["shapes"]

        def dim_size(i):
            # Known extent of dimension i, or the symbolic shape name as fallback.
            if i < len(shapes) and shapes[i] is not None:
                return shapes[i]
            return f"_{name}_shape_{i}"

        def normalize(expr, i):
            # A negative integer literal counts from the end of dimension i.
            try:
                value = int(str(expr).strip().strip("()"))
            except ValueError:
                return expr
            if value < 0:
                return f"({dim_size(i)} - {-value})"
            return expr

        if isinstance(node.slice, ast.Tuple):
            indices = list(node.slice.elts)
        else:
            indices = [node.slice]

        if len(indices) > ndim:
            raise IndexError(
                f"too many indices for array '{name}': array is {ndim}-dimensional,"
                f" but {len(indices)} were indexed"
            )
        # Missing trailing indices select the whole dimension.
        indices.extend(
            ast.Slice(lower=None, upper=None, step=None)
            for _ in range(ndim - len(indices))
        )

        start_indices = []
        slice_shape = []
        for i, idx in enumerate(indices):
            if isinstance(idx, ast.Slice):
                if idx.step is not None and not (
                    isinstance(idx.step, ast.Constant) and idx.step.value == 1
                ):
                    # The result is a single contiguous offset + shape, so a
                    # stride cannot be represented.
                    raise NotImplementedError(
                        f"Strided slices are not supported for array '{name}'"
                    )
                start = "0"
                if idx.lower is not None:
                    start = normalize(simple_visitor.visit(idx.lower), i)
                stop = dim_size(i)
                if idx.upper is not None:
                    stop = normalize(simple_visitor.visit(idx.upper), i)
                start_indices.append(start)
                slice_shape.append(f"({stop} - {start})")
            else:
                start_indices.append(normalize(simple_visitor.visit(idx), i))

        # Row-major linear offset: sum_i start_i * prod_{j > i} dim_j.
        linear_index = ""
        for i in range(ndim):
            term = start_indices[i]
            for j in range(i + 1, ndim):
                term = f"({term} * {dim_size(j)})"
            linear_index = term if i == 0 else f"({linear_index} + {term})"

        return name, [linear_index], slice_shape
1376

1377
    def visit_Attribute(self, node):
4✔
1378
        if node.attr == "shape":
4✔
1379
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
4✔
1380
                return f"_shape_proxy_{node.value.id}"
4✔
1381

1382
        if isinstance(node.value, ast.Name) and node.value.id == "math":
4✔
1383
            val = ""
4✔
1384
            if node.attr == "pi":
4✔
1385
                val = "M_PI"
4✔
1386
            elif node.attr == "e":
4✔
1387
                val = "M_E"
4✔
1388

1389
            if val:
4✔
1390
                tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1391
                dtype = Scalar(PrimitiveType.Double)
4✔
1392
                self.builder.add_container(tmp_name, dtype, False)
4✔
1393
                self.symbol_table[tmp_name] = dtype
4✔
1394
                self._add_assign_constant(tmp_name, val, dtype)
4✔
1395
                return tmp_name
4✔
1396

1397
        # Handle class member access (e.g., obj.x, obj.y)
1398
        if isinstance(node.value, ast.Name):
4✔
1399
            obj_name = node.value.id
4✔
1400
            attr_name = node.attr
4✔
1401

1402
            # Check if the object is a class instance (has a Structure type)
1403
            if obj_name in self.symbol_table:
4✔
1404
                obj_type = self.symbol_table[obj_name]
4✔
1405
                if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
4✔
1406
                    pointee_type = obj_type.pointee_type
4✔
1407
                    if isinstance(pointee_type, Structure):
4✔
1408
                        struct_name = pointee_type.name
4✔
1409

1410
                        # Look up member index and type from structure info
1411
                        if (
4✔
1412
                            struct_name in self.structure_member_info
1413
                            and attr_name in self.structure_member_info[struct_name]
1414
                        ):
1415
                            member_index, member_type = self.structure_member_info[
4✔
1416
                                struct_name
1417
                            ][attr_name]
1418
                        else:
1419
                            # This should not happen if structure was registered properly
1420
                            raise RuntimeError(
×
1421
                                f"Member '{attr_name}' not found in structure '{struct_name}'. "
1422
                                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
1423
                            )
1424

1425
                        # Generate a tasklet to access the member
1426
                        tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1427

1428
                        self.builder.add_container(tmp_name, member_type, False)
4✔
1429
                        self.symbol_table[tmp_name] = member_type
4✔
1430

1431
                        # Create a tasklet that reads the member
1432
                        block = self.builder.add_block()
4✔
1433
                        obj_access = self.builder.add_access(block, obj_name)
4✔
1434
                        tmp_access = self.builder.add_access(block, tmp_name)
4✔
1435

1436
                        # Use tasklet to pass through the value
1437
                        # The actual member selection is done via the memlet subset
1438
                        tasklet = self.builder.add_tasklet(
4✔
1439
                            block, TaskletCode.assign, ["_in"], ["_out"]
1440
                        )
1441

1442
                        # Use member index in the subset to select the correct member
1443
                        subset = "0," + str(member_index)
4✔
1444
                        self.builder.add_memlet(
4✔
1445
                            block, obj_access, "", tasklet, "_in", subset
1446
                        )
1447
                        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")
4✔
1448

1449
                        return tmp_name
4✔
1450

1451
        raise NotImplementedError(f"Attribute access {node.attr} not supported")
×
1452

1453
    def _handle_expression_slicing(self, node, value_str, indices_nodes, shapes, ndim):
        """Materialize a sliced read such as ``arr[1:, :, k+1]`` into a temporary.

        Each subscript entry is either an ``ast.Slice`` (the dimension is kept
        with extent ``stop - start``) or a point index (the dimension is
        collapsed). Dimensions not mentioned in the subscript are taken in
        full, as in NumPy. A temporary is allocated for the result, a loop nest
        with one loop per kept dimension copies the selected elements from
        ``value_str`` into it, and the temporary's name is returned.

        Bounds and indices whose lowered string starts with ``-`` or ``(-`` are
        treated as negative and offset by the dimension size.

        Args:
            node: The originating subscript node (not used here).
            value_str: Name of the source array container.
            indices_nodes: One AST node per subscripted dimension.
            shapes: Known shape expressions of the source array; missing
                entries fall back to ``_<value_str>_shape_<i>`` symbols.
            ndim: Number of dimensions of the source array.

        Returns:
            Name of the temporary holding the slice.

        Raises:
            ValueError: if no builder is attached, or more indices than
                dimensions are given.
            NotImplementedError: if a slice has a step other than 1. The
                allocation size and copy loop assume ``stop - start`` contiguous
                elements, so any other step would read out of bounds.
        """
        if not self.builder:
            raise ValueError("Builder required for expression slicing")

        indices_nodes = list(indices_nodes)
        if len(indices_nodes) > ndim:
            raise ValueError(
                f"Array {value_str} has {ndim} dimensions, but accessed with "
                f"{len(indices_nodes)} indices"
            )
        # NumPy semantics: omitted trailing dimensions (e.g. axis 1 of ``u[1:]``
        # on a 2-D array) are taken in full. Without padding, their source
        # index would stay empty and the linear index would be malformed.
        indices_nodes.extend(
            ast.Slice(lower=None, upper=None, step=None)
            for _ in range(ndim - len(indices_nodes))
        )

        # Element type of the source array; double when it cannot be determined.
        dtype = Scalar(PrimitiveType.Double)
        if value_str in self.symbol_table:
            t = self.symbol_table[value_str]
            if isinstance(t, Pointer) and t.has_pointee_type():
                dtype = t.pointee_type

        result_shapes = []  # extents of kept dimensions (SDFG expressions)
        result_shapes_runtime = []  # same extents using original expressions
        slice_info = []  # (dim_idx, start_str, stop_str) for kept dimensions
        index_info = []  # (dim_idx, index_str) for collapsed dimensions

        for i, idx in enumerate(indices_nodes):
            shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"

            if isinstance(idx, ast.Slice):
                if idx.lower is not None:
                    start_str, start_str_runtime = self._lower_expression_slice_bound(
                        idx.lower, shape_val
                    )
                else:
                    start_str = start_str_runtime = "0"

                if idx.upper is not None:
                    stop_str, stop_str_runtime = self._lower_expression_slice_bound(
                        idx.upper, shape_val
                    )
                else:
                    stop_str = stop_str_runtime = str(shape_val)

                if idx.step is not None:
                    step_str = self.visit(idx.step)
                    if str(step_str) != "1":
                        raise NotImplementedError(
                            f"Slicing {value_str} with step {step_str} is not "
                            "supported; only unit steps are implemented"
                        )

                result_shapes.append(f"({stop_str} - {start_str})")
                result_shapes_runtime.append(
                    f"({stop_str_runtime} - {start_str_runtime})"
                )
                slice_info.append((i, start_str, stop_str))
            else:
                # Point index: this dimension is collapsed in the result.
                if self._contains_indirect_access(idx):
                    index_str = self._materialize_indirect_access(idx)
                else:
                    index_str = self.visit(idx)
                if isinstance(index_str, str) and index_str.startswith(("-", "(-")):
                    index_str = f"({shape_val} + {index_str})"
                index_info.append((i, index_str))

        # Destination: a scalar if every dimension was collapsed, otherwise a
        # malloc'ed buffer sized from the symbolic result shape.
        tmp_name = self._get_temp_name("_slice_tmp_")
        result_ndim = len(result_shapes)

        if result_ndim == 0:
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
        else:
            size_str = "1"
            for dim in result_shapes:
                size_str = f"({size_str} * {dim})"
            element_size = self.builder.get_sizeof(dtype)
            total_size = f"({size_str} * {element_size})"

            ptr_type = Pointer(dtype)
            self.builder.add_container(tmp_name, ptr_type, False)
            self.symbol_table[tmp_name] = ptr_type
            # "shapes" holds the SDFG-side expressions (with materialized
            # variables); "shapes_runtime" keeps the original expressions for
            # Python runtime evaluation.
            self.array_info[tmp_name] = {
                "ndim": result_ndim,
                "shapes": result_shapes,
                "shapes_runtime": result_shapes_runtime,
            }

            debug_info = DebugInfo()
            block_alloc = self.builder.add_block(debug_info)
            t_malloc = self.builder.add_malloc(block_alloc, total_size)
            t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
            self.builder.add_memlet(
                block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
            )

        # One loop per kept dimension, each running over 0 .. (stop - start).
        loop_vars = []  # (loop_var, dim_idx, start_str)
        debug_info = DebugInfo()
        for dim_idx, (orig_dim, start_str, stop_str) in enumerate(slice_info):
            loop_var = f"_slice_loop_{dim_idx}_{self._get_unique_id()}"
            loop_vars.append((loop_var, orig_dim, start_str))

            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            count_str = f"({stop_str} - {start_str})"
            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)

        # Source index per dimension: fixed point index or start + loop var.
        src_indices = [""] * ndim
        for orig_dim, index_str in index_info:
            src_indices[orig_dim] = index_str
        dst_indices = []
        for loop_var, orig_dim, start_str in loop_vars:
            src_indices[orig_dim] = f"({start_str} + {loop_var})"
            dst_indices.append(loop_var)

        src_linear = self._compute_linear_index(src_indices, shapes, value_str, ndim)
        if result_ndim > 0:
            dst_linear = self._compute_linear_index(
                dst_indices, result_shapes, tmp_name, result_ndim
            )
        else:
            dst_linear = "0"

        # Element copy inside the innermost loop.
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, value_str, debug_info)
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_linear, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", dst_linear, None, debug_info
        )

        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name

    def _lower_expression_slice_bound(self, bound_node, shape_val):
        """Lower one explicit slice bound to ``(sdfg_expr, runtime_expr)``.

        A bound containing an indirect array access is materialized through
        ``_materialize_indirect_access(..., return_original_expr=True)``. That
        yields separate expressions for the SDFG and for runtime shape
        evaluation. Any other bound is visited, and the same string serves both.
        A result beginning with ``-`` or ``(-`` is treated as a negative index
        and offset by ``shape_val``.
        """
        if self._contains_indirect_access(bound_node):
            sdfg_expr, runtime_expr = self._materialize_indirect_access(
                bound_node, return_original_expr=True
            )
        else:
            sdfg_expr = self.visit(bound_node)
            runtime_expr = sdfg_expr

        if isinstance(sdfg_expr, str) and sdfg_expr.startswith(("-", "(-")):
            sdfg_expr = f"({shape_val} + {sdfg_expr})"
            runtime_expr = f"({shape_val} + {runtime_expr})"
        return sdfg_expr, runtime_expr
4✔
1645

1646
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
1647
        """Compute linear index from multi-dimensional indices."""
1648
        if ndim == 0:
4✔
1649
            return "0"
×
1650

1651
        linear_index = ""
4✔
1652
        for i in range(ndim):
4✔
1653
            term = str(indices[i])
4✔
1654
            for j in range(i + 1, ndim):
4✔
1655
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
4✔
1656
                term = f"(({term}) * {shape_val})"
4✔
1657

1658
            if i == 0:
4✔
1659
                linear_index = term
4✔
1660
            else:
1661
                linear_index = f"({linear_index} + {term})"
4✔
1662

1663
        return linear_index
4✔
1664

1665
    def _is_array_index(self, node):
4✔
1666
        """Check if a node represents an array that could be used as an index (gather).
1667

1668
        Returns True if the node is a Name referring to an array in array_info.
1669
        """
1670
        if isinstance(node, ast.Name):
4✔
1671
            return node.id in self.array_info
4✔
1672
        return False
4✔
1673

1674
    def _handle_gather(self, value_str, index_node, debug_info=None):
4✔
1675
        """Handle gather operation: x[indices] where indices is an array.
1676

1677
        Creates a temporary array and generates a loop to gather elements
1678
        from the source array using the index array.
1679

1680
        This is the canonical SDFG pattern for gather operations:
1681
        - Create a loop over the index array
1682
        - Load the index value using a tasklet+memlets
1683
        - Use that index in the memlet subset for the source array
1684
        """
1685
        if debug_info is None:
4✔
1686
            debug_info = DebugInfo()
4✔
1687

1688
        # Get the index array name
1689
        if isinstance(index_node, ast.Name):
4✔
1690
            idx_array_name = index_node.id
4✔
1691
        else:
1692
            # Visit the index to get its name (handles slices like cols)
1693
            idx_array_name = self.visit(index_node)
×
1694

1695
        if idx_array_name not in self.array_info:
4✔
1696
            raise ValueError(f"Gather index must be an array, got {idx_array_name}")
×
1697

1698
        # Get shapes
1699
        idx_shapes = self.array_info[idx_array_name].get("shapes", [])
4✔
1700
        src_ndim = self.array_info[value_str]["ndim"]
4✔
1701
        idx_ndim = self.array_info[idx_array_name]["ndim"]
4✔
1702

1703
        if idx_ndim != 1:
4✔
1704
            raise NotImplementedError("Only 1D index arrays supported for gather")
×
1705

1706
        # Result array has same shape as index array
1707
        result_shape = idx_shapes[0] if idx_shapes else f"_{idx_array_name}_shape_0"
4✔
1708

1709
        # Determine element type from source array
1710
        dtype = Scalar(PrimitiveType.Double)
4✔
1711
        if value_str in self.symbol_table:
4✔
1712
            t = self.symbol_table[value_str]
4✔
1713
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1714
                dtype = t.pointee_type
4✔
1715

1716
        # Determine index type from index array
1717
        idx_dtype = Scalar(PrimitiveType.Int64)
4✔
1718
        if idx_array_name in self.symbol_table:
4✔
1719
            t = self.symbol_table[idx_array_name]
4✔
1720
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1721
                idx_dtype = t.pointee_type
4✔
1722

1723
        # Create result array
1724
        tmp_name = self._get_temp_name("_gather_")
4✔
1725

1726
        # Calculate size for malloc
1727
        element_size = self.builder.get_sizeof(dtype)
4✔
1728
        total_size = f"({result_shape} * {element_size})"
4✔
1729

1730
        # Create pointer for result
1731
        ptr_type = Pointer(dtype)
4✔
1732
        self.builder.add_container(tmp_name, ptr_type, False)
4✔
1733
        self.symbol_table[tmp_name] = ptr_type
4✔
1734
        self.array_info[tmp_name] = {"ndim": 1, "shapes": [result_shape]}
4✔
1735

1736
        # Malloc for the result array
1737
        block_alloc = self.builder.add_block(debug_info)
4✔
1738
        t_malloc = self.builder.add_malloc(block_alloc, total_size)
4✔
1739
        t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
4✔
1740
        self.builder.add_memlet(
4✔
1741
            block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
1742
        )
1743

1744
        # Create loop variable
1745
        loop_var = f"_gather_i_{self._get_unique_id()}"
4✔
1746
        self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1747
        self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1748

1749
        # Create variable to hold the loaded index
1750
        idx_var = f"_gather_idx_{self._get_unique_id()}"
4✔
1751
        self.builder.add_container(idx_var, idx_dtype, False)
4✔
1752
        self.symbol_table[idx_var] = idx_dtype
4✔
1753

1754
        # Begin loop
1755
        self.builder.begin_for(loop_var, "0", str(result_shape), "1", debug_info)
4✔
1756

1757
        # Block 1: Load the index from index array using tasklet+memlets
1758
        block_load_idx = self.builder.add_block(debug_info)
4✔
1759
        idx_arr_access = self.builder.add_access(
4✔
1760
            block_load_idx, idx_array_name, debug_info
1761
        )
1762
        idx_var_access = self.builder.add_access(block_load_idx, idx_var, debug_info)
4✔
1763
        tasklet_load = self.builder.add_tasklet(
4✔
1764
            block_load_idx, TaskletCode.assign, ["_in"], ["_out"], debug_info
1765
        )
1766
        self.builder.add_memlet(
4✔
1767
            block_load_idx,
1768
            idx_arr_access,
1769
            "void",
1770
            tasklet_load,
1771
            "_in",
1772
            loop_var,
1773
            None,
1774
            debug_info,
1775
        )
1776
        self.builder.add_memlet(
4✔
1777
            block_load_idx,
1778
            tasklet_load,
1779
            "_out",
1780
            idx_var_access,
1781
            "void",
1782
            "",
1783
            None,
1784
            debug_info,
1785
        )
1786

1787
        # Block 2: Use the loaded index to gather from source array
1788
        block_gather = self.builder.add_block(debug_info)
4✔
1789
        src_access = self.builder.add_access(block_gather, value_str, debug_info)
4✔
1790
        dst_access = self.builder.add_access(block_gather, tmp_name, debug_info)
4✔
1791
        tasklet_gather = self.builder.add_tasklet(
4✔
1792
            block_gather, TaskletCode.assign, ["_in"], ["_out"], debug_info
1793
        )
1794

1795
        # Use the symbolic variable name (idx_var) in the memlet subset - this is key!
1796
        self.builder.add_memlet(
4✔
1797
            block_gather,
1798
            src_access,
1799
            "void",
1800
            tasklet_gather,
1801
            "_in",
1802
            idx_var,
1803
            None,
1804
            debug_info,
1805
        )
1806
        self.builder.add_memlet(
4✔
1807
            block_gather,
1808
            tasklet_gather,
1809
            "_out",
1810
            dst_access,
1811
            "void",
1812
            loop_var,
1813
            None,
1814
            debug_info,
1815
        )
1816

1817
        # End loop
1818
        self.builder.end_for()
4✔
1819

1820
        return tmp_name
4✔
1821

1822
    def visit_Subscript(self, node):
4✔
1823
        value_str = self.visit(node.value)
4✔
1824

1825
        if value_str.startswith("_shape_proxy_"):
4✔
1826
            array_name = value_str[len("_shape_proxy_") :]
4✔
1827
            if isinstance(node.slice, ast.Constant):
4✔
1828
                idx = node.slice.value
4✔
1829
            elif isinstance(node.slice, ast.Index):
×
1830
                idx = node.slice.value.value
×
1831
            else:
1832
                try:
×
1833
                    idx = int(self.visit(node.slice))
×
1834
                except:
×
1835
                    raise NotImplementedError(
×
1836
                        "Dynamic shape indexing not fully supported yet"
1837
                    )
1838

1839
            if (
4✔
1840
                array_name in self.array_info
1841
                and "shapes" in self.array_info[array_name]
1842
            ):
1843
                return self.array_info[array_name]["shapes"][idx]
4✔
1844

1845
            return f"_{array_name}_shape_{idx}"
×
1846

1847
        if value_str in self.array_info:
4✔
1848
            ndim = self.array_info[value_str]["ndim"]
4✔
1849
            shapes = self.array_info[value_str].get("shapes", [])
4✔
1850

1851
            indices = []
4✔
1852
            if isinstance(node.slice, ast.Tuple):
4✔
1853
                indices_nodes = node.slice.elts
4✔
1854
            else:
1855
                indices_nodes = [node.slice]
4✔
1856

1857
            # Check if all indices are full slices (e.g., path[:] or path[:, :])
1858
            # In this case, return just the array name since it's the full array
1859
            all_full_slices = True
4✔
1860
            for idx in indices_nodes:
4✔
1861
                if isinstance(idx, ast.Slice):
4✔
1862
                    # A full slice has no lower, upper bounds or only None
1863
                    if idx.lower is not None or idx.upper is not None:
4✔
1864
                        all_full_slices = False
4✔
1865
                        break
4✔
1866
                else:
1867
                    all_full_slices = False
4✔
1868
                    break
4✔
1869

1870
            # path[:] on an nD array returns the full array
1871
            # So if we have a single full slice, it covers all dimensions
1872
            if all_full_slices:
4✔
1873
                # This is path[:] or path[:,:] - return the array name
1874
                return value_str
4✔
1875

1876
            # Check if there are any slices in the indices
1877
            has_slices = any(isinstance(idx, ast.Slice) for idx in indices_nodes)
4✔
1878
            if has_slices:
4✔
1879
                # Handle mixed slicing (e.g., arr[1:, :, k] or arr[:-1, :, k+1])
1880
                return self._handle_expression_slicing(
4✔
1881
                    node, value_str, indices_nodes, shapes, ndim
1882
                )
1883

1884
            # Check for gather operation: x[indices_array] where indices_array is an array
1885
            # This happens when we have a 1D source array and a 1D index array
1886
            if len(indices_nodes) == 1 and self._is_array_index(indices_nodes[0]):
4✔
1887
                if self.builder:
4✔
1888
                    return self._handle_gather(value_str, indices_nodes[0])
4✔
1889

1890
            if isinstance(node.slice, ast.Tuple):
4✔
1891
                indices = [self.visit(elt) for elt in node.slice.elts]
4✔
1892
            else:
1893
                indices = [self.visit(node.slice)]
4✔
1894

1895
            if len(indices) != ndim:
4✔
1896
                raise ValueError(
×
1897
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
1898
                )
1899

1900
            # Normalize negative indices
1901
            normalized_indices = []
4✔
1902
            for i, idx_str in enumerate(indices):
4✔
1903
                shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
4✔
1904
                # Check if index is negative (starts with "-" or "(-")
1905
                if isinstance(idx_str, str) and (
4✔
1906
                    idx_str.startswith("-") or idx_str.startswith("(-")
1907
                ):
1908
                    # Normalize: size + negative_index
1909
                    normalized_indices.append(f"({shape_val} + {idx_str})")
×
1910
                else:
1911
                    normalized_indices.append(idx_str)
4✔
1912

1913
            linear_index = ""
4✔
1914
            for i in range(ndim):
4✔
1915
                term = normalized_indices[i]
4✔
1916
                for j in range(i + 1, ndim):
4✔
1917
                    shape_val = shapes[j] if j < len(shapes) else None
4✔
1918
                    shape_sym = (
4✔
1919
                        shape_val
1920
                        if shape_val is not None
1921
                        else f"_{value_str}_shape_{j}"
1922
                    )
1923
                    term = f"(({term}) * {shape_sym})"
4✔
1924

1925
                if i == 0:
4✔
1926
                    linear_index = term
4✔
1927
                else:
1928
                    linear_index = f"({linear_index} + {term})"
4✔
1929

1930
            access_str = f"{value_str}({linear_index})"
4✔
1931

1932
            if self.builder and isinstance(node.ctx, ast.Load):
4✔
1933
                dtype = Scalar(PrimitiveType.Double)
4✔
1934
                if value_str in self.symbol_table:
4✔
1935
                    t = self.symbol_table[value_str]
4✔
1936
                    if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
1937
                        et = t.element_type
×
1938
                        if callable(et):
×
1939
                            et = et()
×
1940
                        dtype = et
×
1941
                    elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
4✔
1942
                        et = t.pointee_type
4✔
1943
                        if callable(et):
4✔
1944
                            et = et()
×
1945
                        dtype = et
4✔
1946

1947
                tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1948
                self.builder.add_container(tmp_name, dtype, False)
4✔
1949

1950
                block = self.builder.add_block()
4✔
1951
                t_src = self.builder.add_access(block, value_str)
4✔
1952
                t_dst = self.builder.add_access(block, tmp_name)
4✔
1953
                t_task = self.builder.add_tasklet(
4✔
1954
                    block, TaskletCode.assign, ["_in"], ["_out"]
1955
                )
1956

1957
                self.builder.add_memlet(
4✔
1958
                    block, t_src, "void", t_task, "_in", linear_index
1959
                )
1960
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
4✔
1961

1962
                self.symbol_table[tmp_name] = dtype
4✔
1963
                return tmp_name
4✔
1964

1965
            return access_str
4✔
1966

1967
        slice_val = self.visit(node.slice)
×
1968
        access_str = f"{value_str}({slice_val})"
×
1969

1970
        if (
×
1971
            self.builder
1972
            and isinstance(node.ctx, ast.Load)
1973
            and value_str in self.array_info
1974
        ):
1975
            tmp_name = f"_tmp_{self._get_unique_id()}"
×
1976
            self.builder.add_container(tmp_name, Scalar(PrimitiveType.Double), False)
×
1977
            self.builder.add_assignment(tmp_name, access_str)
×
1978
            self.symbol_table[tmp_name] = Scalar(PrimitiveType.Double)
×
1979
            return tmp_name
×
1980

1981
        return access_str
×
1982

1983
    def visit_Add(self, node):
        """Map ``ast.Add`` to its infix symbol ``+``."""
        symbol = "+"
        return symbol
4✔
1985

1986
    def visit_Sub(self, node):
        """Map ``ast.Sub`` to its infix symbol ``-``."""
        symbol = "-"
        return symbol
4✔
1988

1989
    def visit_Mult(self, node):
        """Map ``ast.Mult`` to its infix symbol ``*``."""
        symbol = "*"
        return symbol
4✔
1991

1992
    def visit_Div(self, node):
        """Map ``ast.Div`` (true division) to the infix symbol ``/``."""
        symbol = "/"
        return symbol
4✔
1994

1995
    def visit_FloorDiv(self, node):
        """Map ``ast.FloorDiv`` to the infix symbol ``//``."""
        symbol = "//"
        return symbol
4✔
1997

1998
    def visit_Mod(self, node):
        """Map ``ast.Mod`` to the infix symbol ``%``."""
        symbol = "%"
        return symbol
4✔
2000

2001
    def visit_Pow(self, node):
        """Map ``ast.Pow`` to the infix symbol ``**``."""
        symbol = "**"
        return symbol
4✔
2003

2004
    def visit_Eq(self, node):
        """Map the ``ast.Eq`` comparison to ``==``."""
        symbol = "=="
        return symbol
×
2006

2007
    def visit_NotEq(self, node):
        """Map the ``ast.NotEq`` comparison to ``!=``."""
        symbol = "!="
        return symbol
×
2009

2010
    def visit_Lt(self, node):
        """Map the ``ast.Lt`` comparison to ``<``."""
        symbol = "<"
        return symbol
4✔
2012

2013
    def visit_LtE(self, node):
        """Map the ``ast.LtE`` comparison to ``<=``."""
        symbol = "<="
        return symbol
×
2015

2016
    def visit_Gt(self, node):
        """Map the ``ast.Gt`` comparison to ``>``."""
        symbol = ">"
        return symbol
4✔
2018

2019
    def visit_GtE(self, node):
        """Map the ``ast.GtE`` comparison to ``>=``."""
        symbol = ">="
        return symbol
×
2021

2022
    def visit_And(self, node):
        """Map the boolean operator ``ast.And`` to the symbol ``&``."""
        symbol = "&"
        return symbol
4✔
2024

2025
    def visit_Or(self, node):
        """Map the boolean operator ``ast.Or`` to the symbol ``|``."""
        symbol = "|"
        return symbol
4✔
2027

2028
    def visit_BitAnd(self, node):
        """Map ``ast.BitAnd`` to the infix symbol ``&``."""
        symbol = "&"
        return symbol
×
2030

2031
    def visit_BitOr(self, node):
        """Map ``ast.BitOr`` to the infix symbol ``|``."""
        symbol = "|"
        return symbol
4✔
2033

2034
    def visit_BitXor(self, node):
        """Map ``ast.BitXor`` to the infix symbol ``^``."""
        symbol = "^"
        return symbol
4✔
2036

2037
    def visit_Not(self, node):
        """Map the unary ``ast.Not`` operator to the prefix symbol ``!``."""
        symbol = "!"
        return symbol
4✔
2039

2040
    def visit_USub(self, node):
        """Map unary minus (``ast.USub``) to the prefix symbol ``-``."""
        symbol = "-"
        return symbol
4✔
2042

2043
    def visit_UAdd(self, node):
        """Map unary plus (``ast.UAdd``) to the prefix symbol ``+``."""
        symbol = "+"
        return symbol
×
2045

2046
    def visit_Invert(self, node):
        """Map bitwise inversion (``ast.Invert``) to the prefix symbol ``~``."""
        symbol = "~"
        return symbol
×
2048

2049
    def _get_dtype(self, name):
        """Return the scalar element type associated with ``name``.

        If ``name`` is declared in ``symbol_table``, the declared type is used
        when it is a ``Scalar``. Otherwise its ``pointee_type`` and then its
        ``element_type`` are tried; either may be an attribute or a callable
        returning the type. The first one that is a ``Scalar`` wins.

        With no usable declaration, falls back to ``Int64`` when
        ``self._is_int(name)`` is true and to ``Double`` otherwise.
        """
        if name in self.symbol_table:
            declared = self.symbol_table[name]
            if isinstance(declared, Scalar):
                return declared

            for inner_attr in ("pointee_type", "element_type"):
                if not hasattr(declared, inner_attr):
                    continue
                inner = getattr(declared, inner_attr)
                if callable(inner):
                    inner = inner()
                if isinstance(inner, Scalar):
                    return inner

        fallback = PrimitiveType.Int64 if self._is_int(name) else PrimitiveType.Double
        return Scalar(fallback)
4✔
2073

2074
    def _promote_dtypes(self, dtype_left, dtype_right):
        """Pick the wider of two dtypes, approximating NumPy promotion.

        Ranking (high to low): Double, Float, Int64, Int32. Any other primitive
        type ranks lowest (0). On a tie the left operand is returned.

        NOTE(review): unsigned and narrower integer types all rank 0, so this
        only approximates NumPy for those; confirm it matches the types used.
        """
        rank = {
            PrimitiveType.Int32: 1,
            PrimitiveType.Int64: 2,
            PrimitiveType.Float: 3,
            PrimitiveType.Double: 4,
        }
        left_rank = rank.get(dtype_left.primitive_type, 0)
        right_rank = rank.get(dtype_right.primitive_type, 0)
        return dtype_right if right_rank > left_rank else dtype_left
4✔
2089

2090
    def _create_array_temp(
4✔
2091
        self, shape, dtype, zero_init=False, ones_init=False, shapes_runtime=None
2092
    ):
2093
        tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
2094

2095
        # Handle 0-dimensional arrays as scalars
2096
        if not shape or (len(shape) == 0):
4✔
2097
            # 0-D array is just a scalar
2098
            self.builder.add_container(tmp_name, dtype, False)
4✔
2099
            self.symbol_table[tmp_name] = dtype
4✔
2100
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
4✔
2101

2102
            if zero_init:
4✔
2103
                self.builder.add_assignment(
×
2104
                    tmp_name,
2105
                    "0.0" if dtype.primitive_type == PrimitiveType.Double else "0",
2106
                )
2107
            elif ones_init:
4✔
2108
                self.builder.add_assignment(
×
2109
                    tmp_name,
2110
                    "1.0" if dtype.primitive_type == PrimitiveType.Double else "1",
2111
                )
2112

2113
            return tmp_name
4✔
2114

2115
        # Calculate size
2116
        size_str = "1"
4✔
2117
        for dim in shape:
4✔
2118
            size_str = f"({size_str} * {dim})"
4✔
2119

2120
        element_size = self.builder.get_sizeof(dtype)
4✔
2121
        total_size = f"({size_str} * {element_size})"
4✔
2122

2123
        # Create pointer
2124
        ptr_type = Pointer(dtype)
4✔
2125
        self.builder.add_container(tmp_name, ptr_type, False)
4✔
2126
        self.symbol_table[tmp_name] = ptr_type
4✔
2127
        array_info_entry = {"ndim": len(shape), "shapes": shape}
4✔
2128
        if shapes_runtime is not None:
4✔
2129
            array_info_entry["shapes_runtime"] = shapes_runtime
4✔
2130
        self.array_info[tmp_name] = array_info_entry
4✔
2131

2132
        # Malloc
2133
        block1 = self.builder.add_block()
4✔
2134
        t_malloc = self.builder.add_malloc(block1, total_size)
4✔
2135
        t_ptr1 = self.builder.add_access(block1, tmp_name)
4✔
2136
        self.builder.add_memlet(block1, t_malloc, "_ret", t_ptr1, "void", "", ptr_type)
4✔
2137

2138
        if zero_init:
4✔
2139
            block2 = self.builder.add_block()
4✔
2140
            t_memset = self.builder.add_memset(block2, "0", total_size)
4✔
2141
            t_ptr2 = self.builder.add_access(block2, tmp_name)
4✔
2142
            self.builder.add_memlet(
4✔
2143
                block2, t_memset, "_ptr", t_ptr2, "void", "", ptr_type
2144
            )
2145
        elif ones_init:
4✔
2146
            # Initialize array with ones using a loop
2147
            loop_var = f"_i_{self._get_unique_id()}"
4✔
2148
            if not self.builder.exists(loop_var):
4✔
2149
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
2150
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
2151

2152
            self.builder.begin_for(loop_var, "0", size_str, "1")
4✔
2153

2154
            # Determine the value to set based on dtype
2155
            val = "1.0"
4✔
2156
            if dtype.primitive_type in [
4✔
2157
                PrimitiveType.Int64,
2158
                PrimitiveType.Int32,
2159
                PrimitiveType.Int8,
2160
                PrimitiveType.Int16,
2161
                PrimitiveType.UInt64,
2162
                PrimitiveType.UInt32,
2163
                PrimitiveType.UInt8,
2164
                PrimitiveType.UInt16,
2165
            ]:
2166
                val = "1"
4✔
2167

2168
            block_assign = self.builder.add_block()
4✔
2169
            t_const = self.builder.add_constant(block_assign, val, dtype)
4✔
2170
            t_arr = self.builder.add_access(block_assign, tmp_name)
4✔
2171

2172
            t_task = self.builder.add_tasklet(
4✔
2173
                block_assign, TaskletCode.assign, ["_in"], ["_out"]
2174
            )
2175
            self.builder.add_memlet(
4✔
2176
                block_assign, t_const, "void", t_task, "_in", "", dtype
2177
            )
2178
            self.builder.add_memlet(
4✔
2179
                block_assign, t_task, "_out", t_arr, "void", loop_var
2180
            )
2181

2182
            self.builder.end_for()
4✔
2183

2184
        return tmp_name
4✔
2185

2186
    def _handle_array_unary_op(self, op_type, operand):
        """Apply an elementwise unary op to ``operand`` and return the result name.

        The result has the operand's recorded shape (``[]`` when ``operand`` is
        not in ``array_info``) and the dtype reported by ``_get_dtype``.
        0-D operands are lowered to a single CMath intrinsic node, which only
        supports sqrt, abs/absolute (as fabs), exp and tanh. Array operands are
        delegated to ``builder.add_elementwise_unary_op``.

        Args:
            op_type: NumPy-style op name, e.g. ``"sqrt"`` or ``"exp"``.
            operand: Name of the input container.

        Returns:
            Name of the newly allocated result container.

        Raises:
            NotImplementedError: for a 0-D operand whose ``op_type`` has no CMath
                mapping. This check runs before any container or block is
                created, so no half-built state is left behind.
        """
        shape = []
        if operand in self.array_info:
            shape = self.array_info[operand]["shapes"]

        dtype = self._get_dtype(operand)

        if not shape:
            # 0-D operands use an intrinsic (CMath node) instead of a library node.
            func_map = {
                "sqrt": CMathFunction.sqrt,
                "abs": CMathFunction.fabs,
                "absolute": CMathFunction.fabs,
                "exp": CMathFunction.exp,
                "tanh": CMathFunction.tanh,
            }
            cmath_func = func_map.get(op_type)
            if cmath_func is None:
                raise NotImplementedError(
                    f"Unary op '{op_type}' is not supported for 0-D operands"
                )

            tmp_name = self._create_array_temp(shape, dtype)

            block = self.builder.add_block()
            t_src = self.builder.add_access(block, operand)
            t_dst = self.builder.add_access(block, tmp_name)
            t_task = self.builder.add_cmath(block, cmath_func)

            # CMath nodes read their inputs on _in1, _in2, ... and write _out.
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", "", dtype)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "", dtype)

            return tmp_name

        tmp_name = self._create_array_temp(shape, dtype)
        self.builder.add_elementwise_unary_op(op_type, operand, tmp_name, shape)
        return tmp_name
2225

2226
    def _handle_array_binary_op(self, op_type, left, right):
4✔
2227
        # Determine output shape (handle broadcasting by picking the larger shape)
2228
        left_shape = []
4✔
2229
        right_shape = []
4✔
2230
        if left in self.array_info:
4✔
2231
            left_shape = self.array_info[left]["shapes"]
4✔
2232
        if right in self.array_info:
4✔
2233
            right_shape = self.array_info[right]["shapes"]
4✔
2234
        # Pick the shape with more dimensions for broadcasting
2235
        shape = left_shape if len(left_shape) >= len(right_shape) else right_shape
4✔
2236

2237
        # Determine dtype with promotion (float > int, wider > narrower)
2238
        dtype_left = self._get_dtype(left)
4✔
2239
        dtype_right = self._get_dtype(right)
4✔
2240

2241
        # Promote dtypes: Double > Float > Int64 > Int32
2242
        dtype = self._promote_dtypes(dtype_left, dtype_right)
4✔
2243

2244
        # Cast scalar operands to the promoted dtype if needed
2245
        real_left = left
4✔
2246
        real_right = right
4✔
2247

2248
        # Helper to check if operand is a scalar (not an array)
2249
        left_is_scalar = left not in self.array_info
4✔
2250
        right_is_scalar = right not in self.array_info
4✔
2251

2252
        # Cast left operand if needed (scalar int to float)
2253
        if left_is_scalar and dtype_left.primitive_type != dtype.primitive_type:
4✔
2254
            left_cast = f"_tmp_{self._get_unique_id()}"
4✔
2255
            self.builder.add_container(left_cast, dtype, False)
4✔
2256
            self.symbol_table[left_cast] = dtype
4✔
2257

2258
            c_block = self.builder.add_block()
4✔
2259
            t_src, src_sub = self._add_read(c_block, left)
4✔
2260
            t_dst = self.builder.add_access(c_block, left_cast)
4✔
2261
            t_task = self.builder.add_tasklet(
4✔
2262
                c_block, TaskletCode.assign, ["_in"], ["_out"]
2263
            )
2264
            self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
4✔
2265
            self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
4✔
2266

2267
            real_left = left_cast
4✔
2268

2269
        # Cast right operand if needed (scalar int to float)
2270
        if right_is_scalar and dtype_right.primitive_type != dtype.primitive_type:
4✔
2271
            right_cast = f"_tmp_{self._get_unique_id()}"
4✔
2272
            self.builder.add_container(right_cast, dtype, False)
4✔
2273
            self.symbol_table[right_cast] = dtype
4✔
2274

2275
            c_block = self.builder.add_block()
4✔
2276
            t_src, src_sub = self._add_read(c_block, right)
4✔
2277
            t_dst = self.builder.add_access(c_block, right_cast)
4✔
2278
            t_task = self.builder.add_tasklet(
4✔
2279
                c_block, TaskletCode.assign, ["_in"], ["_out"]
2280
            )
2281
            self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
4✔
2282
            self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
4✔
2283

2284
            real_right = right_cast
4✔
2285

2286
        tmp_name = self._create_array_temp(shape, dtype)
4✔
2287

2288
        # Add operation with promoted dtype for implicit casting
2289
        self.builder.add_elementwise_op(op_type, real_left, real_right, tmp_name, shape)
4✔
2290

2291
        return tmp_name
4✔
2292

2293
    def _shape_to_runtime_expr(self, shape_node):
4✔
2294
        """Convert a shape expression AST node to a runtime-evaluable string.
2295

2296
        This converts the AST to a string expression that can be evaluated
2297
        at runtime using only input arrays and shape symbols (_s0, _s1, etc.).
2298
        It does NOT visit the node (which would create SDFG variables).
2299
        """
2300
        if isinstance(shape_node, ast.Constant):
4✔
2301
            return str(shape_node.value)
4✔
2302
        elif isinstance(shape_node, ast.Name):
4✔
2303
            return shape_node.id
4✔
2304
        elif isinstance(shape_node, ast.BinOp):
4✔
2305
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2306
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2307
            op = self.visit(shape_node.op)
4✔
2308
            return f"({left} {op} {right})"
4✔
2309
        elif isinstance(shape_node, ast.UnaryOp):
4✔
2310
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
2311
            if isinstance(shape_node.op, ast.USub):
×
2312
                return f"(-{operand})"
×
2313
            elif isinstance(shape_node.op, ast.UAdd):
×
2314
                return operand
×
2315
            else:
2316
                # Fall back to visit for other unary ops
2317
                return self.visit(shape_node)
×
2318
        elif isinstance(shape_node, ast.Subscript):
4✔
2319
            # Handle arr.shape[0] -> arr.shape[0] for runtime eval
2320
            # or _shape_proxy_arr[0] -> _s<idx>
2321
            val = shape_node.value
4✔
2322
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2323
                # arr.shape[0] -> use the shape symbol
2324
                if isinstance(val.value, ast.Name):
4✔
2325
                    arr_name = val.value.id
4✔
2326
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2327
                        idx = shape_node.slice.value
4✔
2328
                        # Get the shape symbol for this array dimension
2329
                        if arr_name in self.array_info:
4✔
2330
                            shapes = self.array_info[arr_name].get("shapes", [])
4✔
2331
                            if idx < len(shapes):
4✔
2332
                                return shapes[idx]
4✔
2333
                        return f"{arr_name}.shape[{idx}]"
×
2334
            # Fall back to visit
2335
            return self.visit(shape_node)
×
2336
        elif isinstance(shape_node, ast.Tuple):
×
2337
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2338
        elif isinstance(shape_node, ast.List):
×
2339
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2340
        else:
2341
            # Fall back to visit for complex expressions
2342
            return self.visit(shape_node)
×
2343

2344
    def _handle_numpy_alloc(self, node, func_name):
4✔
2345
        # Parse shape
2346
        shape_arg = node.args[0]
4✔
2347
        dims = []
4✔
2348
        dims_runtime = []  # Runtime-evaluable shape expressions
4✔
2349
        if isinstance(shape_arg, ast.Tuple):
4✔
2350
            dims = [self.visit(elt) for elt in shape_arg.elts]
4✔
2351
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
4✔
2352
        elif isinstance(shape_arg, ast.List):
4✔
2353
            dims = [self.visit(elt) for elt in shape_arg.elts]
×
2354
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
×
2355
        else:
2356
            val = self.visit(shape_arg)
4✔
2357
            runtime_val = self._shape_to_runtime_expr(shape_arg)
4✔
2358
            if val.startswith("_shape_proxy_"):
4✔
2359
                array_name = val[len("_shape_proxy_") :]
×
2360
                if array_name in self.array_info:
×
2361
                    dims = self.array_info[array_name]["shapes"]
×
2362
                    dims_runtime = self.array_info[array_name].get(
×
2363
                        "shapes_runtime", dims
2364
                    )
2365
                else:
2366
                    dims = [val]
×
2367
                    dims_runtime = [runtime_val]
×
2368
            else:
2369
                dims = [val]
4✔
2370
                dims_runtime = [runtime_val]
4✔
2371

2372
        # Parse dtype
2373
        dtype_arg = None
4✔
2374
        if len(node.args) > 1:
4✔
2375
            dtype_arg = node.args[1]
×
2376

2377
        for kw in node.keywords:
4✔
2378
            if kw.arg == "dtype":
4✔
2379
                dtype_arg = kw.value
4✔
2380
                break
4✔
2381

2382
        element_type = self._map_numpy_dtype(dtype_arg)
4✔
2383

2384
        return self._create_array_temp(
4✔
2385
            dims,
2386
            element_type,
2387
            zero_init=(func_name == "zeros"),
2388
            ones_init=(func_name == "ones"),
2389
            shapes_runtime=dims_runtime,
2390
        )
2391

2392
    def _handle_numpy_empty_like(self, node, func_name):
4✔
2393
        prototype_arg = node.args[0]
4✔
2394
        prototype_name = self.visit(prototype_arg)
4✔
2395

2396
        # Parse shape from prototype
2397
        dims = []
4✔
2398
        if prototype_name in self.array_info:
4✔
2399
            dims = self.array_info[prototype_name]["shapes"]
4✔
2400

2401
        # Parse dtype
2402
        dtype_arg = None
4✔
2403
        if len(node.args) > 1:
4✔
2404
            dtype_arg = node.args[1]
×
2405

2406
        for kw in node.keywords:
4✔
2407
            if kw.arg == "dtype":
4✔
2408
                dtype_arg = kw.value
4✔
2409
                break
4✔
2410

2411
        element_type = None
4✔
2412
        if dtype_arg:
4✔
2413
            element_type = self._map_numpy_dtype(dtype_arg)
4✔
2414
        else:
2415
            if prototype_name in self.symbol_table:
4✔
2416
                sym_type = self.symbol_table[prototype_name]
4✔
2417
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
4✔
2418
                    element_type = sym_type.pointee_type
4✔
2419

2420
        if element_type is None:
4✔
2421
            element_type = Scalar(PrimitiveType.Double)
×
2422

2423
        return self._create_array_temp(
4✔
2424
            dims,
2425
            element_type,
2426
            zero_init=False,
2427
            ones_init=False,
2428
        )
2429

2430
    def _handle_numpy_zeros_like(self, node, func_name):
4✔
2431
        prototype_arg = node.args[0]
4✔
2432
        prototype_name = self.visit(prototype_arg)
4✔
2433

2434
        # Parse shape from prototype
2435
        dims = []
4✔
2436
        if prototype_name in self.array_info:
4✔
2437
            dims = self.array_info[prototype_name]["shapes"]
4✔
2438

2439
        # Parse dtype
2440
        dtype_arg = None
4✔
2441
        if len(node.args) > 1:
4✔
2442
            dtype_arg = node.args[1]
×
2443

2444
        for kw in node.keywords:
4✔
2445
            if kw.arg == "dtype":
4✔
2446
                dtype_arg = kw.value
4✔
2447
                break
4✔
2448

2449
        element_type = None
4✔
2450
        if dtype_arg:
4✔
2451
            element_type = self._map_numpy_dtype(dtype_arg)
4✔
2452
        else:
2453
            if prototype_name in self.symbol_table:
4✔
2454
                sym_type = self.symbol_table[prototype_name]
4✔
2455
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
4✔
2456
                    element_type = sym_type.pointee_type
4✔
2457

2458
        if element_type is None:
4✔
2459
            element_type = Scalar(PrimitiveType.Double)
×
2460

2461
        return self._create_array_temp(
4✔
2462
            dims,
2463
            element_type,
2464
            zero_init=True,
2465
            ones_init=False,
2466
        )
2467

2468
    def _handle_numpy_eye(self, node, func_name):
        """Lower ``np.eye(N, M=None, k=0, dtype=float)`` to a 2-D temporary.

        The ``[N, M]`` result is zero-initialized. A loop over rows ``i`` in
        ``[0, N)`` then writes one to element ``(i, i + k)``, guarded by
        ``0 <= i + k < M``. ``M``, ``k`` and ``dtype`` may be given
        positionally (in NumPy's parameter order) or as keywords. ``M=None``,
        positional or keyword, means ``M = N``.

        Args:
            node: The ``ast.Call`` node.
            func_name: Dispatch name (unused).

        Returns:
            Name of the newly allocated matrix.
        """
        N_str = self.visit(node.args[0])

        M_str = N_str
        if len(node.args) > 1:
            M_str = self.visit(node.args[1])
            if M_str == "None":
                M_str = N_str

        k_str = "0"
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])

        # dtype is NumPy's 4th positional parameter; a keyword overrides it.
        dtype_arg = node.args[3] if len(node.args) > 3 else None
        for kw in node.keywords:
            if kw.arg == "M":
                M_str = self.visit(kw.value)
                if M_str == "None":
                    M_str = N_str
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        element_type = self._map_numpy_dtype(dtype_arg)

        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)

        # Loop over rows to set the k-th diagonal.
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Condition: 0 <= i + k < M
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        # Assignment: A[i, i+k] = 1 (integer literal for integer dtypes)
        val = "1.0"
        if element_type.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]:
            val = "1"

        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        # Row-major flat index of (i, i + k) in an [N, M] buffer.
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"

        t_task = self.builder.add_tasklet(
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", flat_index)

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
2543

2544
    def _handle_numpy_binary_op(self, node, func_name):
4✔
2545
        args = [self.visit(arg) for arg in node.args]
4✔
2546
        if len(args) != 2:
4✔
2547
            raise NotImplementedError(
×
2548
                f"Numpy function {func_name} requires 2 arguments"
2549
            )
2550

2551
        op_map = {
4✔
2552
            "add": "add",
2553
            "subtract": "sub",
2554
            "multiply": "mul",
2555
            "divide": "div",
2556
            "power": "pow",
2557
            "minimum": "min",
2558
            "maximum": "max",
2559
        }
2560
        return self._handle_array_binary_op(op_map[func_name], args[0], args[1])
4✔
2561

2562
    def _handle_numpy_where(self, node, func_name):
        """Handle ``np.where(condition, x, y)`` as an elementwise ternary select.

        The result holds ``x`` where ``condition`` is true and ``y`` elsewhere.

        * Shape: taken from the first of condition, y, x that is a known array
          with a non-empty shape. Operands are not broadcast; every array
          operand is read at the output's linear index.
        * Dtype: the promotion of ``x``'s and ``y``'s dtypes (via
          ``_get_dtype`` and ``_promote_dtypes``, as for binary array ops).
          Selecting from a float operand therefore never truncates into an
          integer result.

        Args:
            node: The ``ast.Call`` node.
            func_name: Dispatch name (unused).

        Returns:
            Name of the newly allocated result array.

        Raises:
            NotImplementedError: if there are not exactly three positional
                arguments, or none of them is an array.
        """
        if len(node.args) != 3:
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")

        cond_name = self.visit(node.args[0])
        x_name = self.visit(node.args[1])
        y_name = self.visit(node.args[2])

        # Output shape priority: condition > y > x (x is often a scalar such as 0).
        shape = []
        for candidate in (cond_name, y_name, x_name):
            if not shape and candidate in self.array_info:
                shape = self.array_info[candidate]["shapes"]

        if not shape:
            raise NotImplementedError("np.where requires at least one array argument")

        # Promote over both value operands; taking y's dtype alone would, e.g.,
        # truncate a float x into an integer result.
        dtype = self._promote_dtypes(self._get_dtype(x_name), self._get_dtype(y_name))

        tmp_name = self._create_array_temp(shape, dtype)

        # One nested loop per output dimension.
        loop_vars = []
        for i, dim in enumerate(shape):
            loop_var = f"_where_i{i}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(dim), "1")

        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))

        def read_operand(block, name):
            """Return (access node, subset) for reading ``name`` at the current element."""
            if name in self.array_info:
                return self.builder.add_access(block, name), linear_idx
            return self._add_read(block, name)

        # Copy the condition into a Bool scalar that drives the branch.
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
        self.symbol_table[cond_tmp] = Scalar(PrimitiveType.Bool)

        block_cond = self.builder.add_block()
        t_cond_src, cond_sub = read_operand(block_cond, cond_name)
        t_cond_out = self.builder.add_access(block_cond, cond_tmp)
        t_cond_task = self.builder.add_tasklet(
            block_cond, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_cond, t_cond_src, "void", t_cond_task, "_in", cond_sub
        )
        self.builder.add_memlet(
            block_cond, t_cond_task, "_out", t_cond_out, "void", ""
        )

        def assign_from(src_name):
            """Emit out[linear_idx] = src (element-wise for arrays, broadcast for scalars)."""
            block = self.builder.add_block()
            t_out = self.builder.add_access(block, tmp_name)
            t_src, src_sub = read_operand(block, src_name)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", linear_idx)

        self.builder.begin_if(f"{cond_tmp} == true")
        assign_from(x_name)
        self.builder.begin_else()
        assign_from(y_name)
        self.builder.end_if()

        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
2712

2713
    def _handle_numpy_matmul_op(self, left_node, right_node):
4✔
2714
        return self._handle_matmul_helper(left_node, right_node)
4✔
2715

2716
    def _handle_numpy_matmul(self, node, func_name):
4✔
2717
        if len(node.args) != 2:
4✔
2718
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
2719
        return self._handle_matmul_helper(node.args[0], node.args[1])
4✔
2720

2721
    def _handle_numpy_outer(self, node, func_name):
        """Lower ``np.outer(a, b)`` through the linear-algebra handler.

        Each operand is first resolved with ``la_handler.parse_arg``. If that
        fails (the first element of the result is falsy), the operand is
        visited, which materializes it into a container, and then parsed
        again by name.

        The result is a Double temporary of shape ``[m, n]``, where ``m`` and
        ``n`` are the products of the shapes ``parse_arg`` reports for the two
        operands. ``la_handler.handle_outer`` then emits the computation into
        it.

        Raises:
            NotImplementedError: wrong argument count, or unresolvable operands.
            RuntimeError: when no linear-algebra handler is attached.
        """
        if len(node.args) != 2:
            raise NotImplementedError("outer requires 2 arguments")

        operands = [node.args[0], node.args[1]]

        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        parsed = [self.la_handler.parse_arg(expr) for expr in operands]

        # Complex expressions: materialize them via visit(), then re-parse by name.
        for pos in (0, 1):
            if not parsed[pos][0]:
                operands[pos] = ast.Name(id=self.visit(operands[pos]))
                parsed[pos] = self.la_handler.parse_arg(operands[pos])

        (name_a, _, shape_a, _), (name_b, _, shape_b, _) = parsed

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve outer operands")

        def flat_size(dims):
            # Product of the extents as an expression string; while the
            # accumulator is still "1" it is replaced instead of multiplied.
            expr = "1"
            for extent in dims:
                expr = str(extent) if expr == "1" else f"({expr} * {str(extent)})"
            return expr

        # outer() defaults to a double-precision result for now.
        tmp_name = self._create_array_temp(
            [flat_size(shape_a), flat_size(shape_b)], Scalar(PrimitiveType.Double)
        )

        outer_call = ast.Call(func=node.func, args=operands, keywords=node.keywords)
        self.la_handler.handle_outer(tmp_name, outer_call)

        return tmp_name
2779

2780
    def _handle_ufunc_outer(self, node, ufunc_name):
        """Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.

        These compute the outer operation for the given ufunc:
        - np.add.outer(a, b) -> a[:, np.newaxis] + b (outer sum)
        - np.subtract.outer(a, b) -> a[:, np.newaxis] - b (outer difference)
        - np.multiply.outer(a, b) -> a[:, np.newaxis] * b (same as np.outer)
        - np.divide / np.minimum / np.maximum .outer are lowered the same way
          with the matching tasklet opcode or cmath function.

        Except for multiply (delegated to ``_handle_numpy_outer``), both
        operands are treated as flattened sequences of lengths M and N, and
        the result is a new (M, N) temporary filled by an emitted doubly
        nested loop.

        Args:
            node: The ``ast.Call`` node of the ``<ufunc>.outer(a, b)`` call.
            ufunc_name: The ufunc attribute name, e.g. ``"add"``.

        Returns:
            The name of the temporary holding the result.

        Raises:
            NotImplementedError: Wrong argument count, unsupported ufunc, or
                operands that cannot be resolved to named containers.
            RuntimeError: If ``self.la_handler`` has not been set.
        """
        if len(node.args) != 2:
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")

        # For np.multiply.outer, use the existing GEMM-based outer handler
        if ufunc_name == "multiply":
            return self._handle_numpy_outer(node, "outer")

        # Map ufunc names to operation names and tasklet opcodes.
        # Each entry is (op_name, floating-point op, integer op); min/max use a
        # CMathFunction rather than a TaskletCode for the floating-point case.
        # NOTE(review): "divide" lowers integer inputs to int_sdiv with an
        # integer result dtype, whereas NumPy's np.divide is true division and
        # yields floats for integer inputs. Confirm this divergence is intended.
        # NOTE(review): the integer ops are the signed variants (sdiv/smin/smax)
        # but are also selected for the UInt* types listed in `is_int` below.
        # Verify the results for unsigned operands.
        op_map = {
            "add": ("add", TaskletCode.fp_add, TaskletCode.int_add),
            "subtract": ("sub", TaskletCode.fp_sub, TaskletCode.int_sub),
            "divide": ("div", TaskletCode.fp_div, TaskletCode.int_sdiv),
            "minimum": ("min", CMathFunction.fmin, TaskletCode.int_smin),
            "maximum": ("max", CMathFunction.fmax, TaskletCode.int_smax),
        }

        if ufunc_name not in op_map:
            raise NotImplementedError(f"{ufunc_name}.outer not supported")

        # op_name is not used further in this method.
        op_name, fp_opcode, int_opcode = op_map[ufunc_name]

        # Use la_handler.parse_arg to properly handle sliced arrays
        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        arg0 = node.args[0]
        arg1 = node.args[1]

        # parse_arg yields (name, subset, shape, indices); a falsy name means
        # the operand could not be resolved directly.
        res_a = self.la_handler.parse_arg(arg0)
        res_b = self.la_handler.parse_arg(arg1)

        # If parse_arg fails for complex expressions, visit the expression to
        # obtain a container name and re-parse that name instead.
        if not res_a[0]:
            left_name = self.visit(arg0)
            arg0 = ast.Name(id=left_name)
            res_a = self.la_handler.parse_arg(arg0)

        if not res_b[0]:
            right_name = self.visit(arg1)
            arg1 = ast.Name(id=right_name)
            res_b = self.la_handler.parse_arg(arg1)

        name_a, subset_a, shape_a, indices_a = res_a
        name_b, subset_b, shape_b, indices_b = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve ufunc outer operands")

        # Compute flattened sizes - outer treats inputs as 1D
        def get_flattened_size_expr(shapes):
            # Symbolic product of all dims, e.g. [N, M, K] -> "((N * M) * K)".
            if not shapes:
                return "1"
            size_expr = str(shapes[0])
            for s in shapes[1:]:
                size_expr = f"({size_expr} * {str(s)})"
            return size_expr

        m_expr = get_flattened_size_expr(shape_a)
        n_expr = get_flattened_size_expr(shape_b)

        # Output dtype is the promotion of both operand dtypes.
        dtype_left = self._get_dtype(name_a)
        dtype_right = self._get_dtype(name_b)
        dtype = self._promote_dtypes(dtype_left, dtype_right)

        # Integer results select the integer opcode from op_map.
        is_int = dtype.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]

        # Create output array with shape (M, N)
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)

        # Generate unique loop variable names
        i_var = self._get_temp_name("_outer_i_")
        j_var = self._get_temp_name("_outer_j_")

        # Ensure loop variables exist as Int64 containers and are registered
        # in the symbol table.
        if not self.builder.exists(i_var):
            self.builder.add_container(i_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[i_var] = Scalar(PrimitiveType.Int64)
        if not self.builder.exists(j_var):
            self.builder.add_container(j_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[j_var] = Scalar(PrimitiveType.Int64)

        # Helper function to compute the linear index for a sliced array access
        def compute_linear_index(name, subset, indices, loop_var):
            """
            Compute linear index for accessing element loop_var of a sliced array.

            For array A with shape (N, M):
            - A[:, k] (column k): linear_index = loop_var * M + k
            - A[k, :] (row k): linear_index = k * M + loop_var
            - A[:] (1D array): linear_index = loop_var

            The indices list contains AST nodes showing which dims are sliced vs fixed.
            subset contains start indices for each dimension.
            Returns the index as a symbolic expression string.

            NOTE(review): only the slice start from ``subset`` is applied;
            slice steps are ignored. When several dimensions are sliced, each
            one adds the same ``loop_var``, so presumably only operands with a
            single sliced dimension are indexed correctly. Confirm callers
            never pass other shapes. ``loop_var_used`` is set but never read.
            """
            if not indices:
                # No subscript information: index the container linearly.
                return loop_var

            info = self.array_info.get(name, {})
            shapes = info.get("shapes", [])
            ndim = info.get("ndim", len(shapes))

            if ndim == 0:
                return loop_var

            # Compute strides (row-major order); unknown dim sizes fall back
            # to a symbolic "_<name>_shape_<i>" name.
            strides = []
            current_stride = "1"
            for i in range(ndim - 1, -1, -1):
                strides.insert(0, current_stride)
                if i > 0:
                    dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
                    if current_stride == "1":
                        current_stride = str(dim_size)
                    else:
                        current_stride = f"({current_stride} * {dim_size})"

            # Build linear index from subset and indices info
            terms = []
            loop_var_used = False

            for i, idx in enumerate(indices):
                stride = strides[i] if i < len(strides) else "1"
                start = subset[i] if i < len(subset) else "0"

                if isinstance(idx, ast.Slice):
                    # This dimension is sliced - use loop_var
                    if stride == "1":
                        term = f"({start} + {loop_var})"
                    else:
                        term = f"(({start} + {loop_var}) * {stride})"
                    loop_var_used = True
                else:
                    # This dimension has a fixed index
                    if stride == "1":
                        term = start
                    else:
                        term = f"({start} * {stride})"

                terms.append(term)

            # Sum all terms
            if not terms:
                return loop_var

            result = terms[0]
            for t in terms[1:]:
                result = f"({result} + {t})"

            return result

        # Create nested for loops: for i in range(M): for j in range(N): C[i,j] = A[i] op B[j]
        self.builder.begin_for(i_var, "0", m_expr, "1")
        self.builder.begin_for(j_var, "0", n_expr, "1")

        # Create the assignment block: C[i, j] = A[i] op B[j]
        block = self.builder.add_block()

        # Add access nodes
        t_a = self.builder.add_access(block, name_a)
        t_b = self.builder.add_access(block, name_b)
        t_c = self.builder.add_access(block, tmp_name)

        # Determine tasklet type based on operation
        if ufunc_name in ["minimum", "maximum"]:
            # Use intrinsic for min/max. The cmath node is assumed to expose
            # the same "_in1"/"_in2"/"_out" connectors used below. TODO confirm.
            if is_int:
                t_task = self.builder.add_tasklet(
                    block, int_opcode, ["_in1", "_in2"], ["_out"]
                )
            else:
                t_task = self.builder.add_cmath(block, fp_opcode)
        else:
            # Use regular tasklet for arithmetic ops
            tasklet_code = int_opcode if is_int else fp_opcode
            t_task = self.builder.add_tasklet(
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
            )

        # Compute the linear index for A[i]
        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)

        # Compute the linear index for B[j]
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)

        # Connect A[i + offset_a] -> tasklet
        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)

        # Connect B[j + offset_b] -> tasklet
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)

        # Connect tasklet -> C[i * N + j] (linear index for 2D output)
        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)

        self.builder.end_for()  # end j loop
        self.builder.end_for()  # end i loop

        return tmp_name
2998

2999
    def _op_symbol(self, op_name):
4✔
3000
        """Convert operation name to symbol."""
3001
        symbols = {
×
3002
            "add": "+",
3003
            "sub": "-",
3004
            "mul": "*",
3005
            "div": "/",
3006
            "min": "min",  # Will need special handling
3007
            "max": "max",  # Will need special handling
3008
        }
3009
        return symbols.get(op_name, op_name)
×
3010

3011
    def _handle_matmul_helper(self, left_node, right_node):
4✔
3012
        if not self.la_handler:
4✔
3013
            raise RuntimeError("LinearAlgebraHandler not initialized")
×
3014

3015
        res_a = self.la_handler.parse_arg(left_node)
4✔
3016
        res_b = self.la_handler.parse_arg(right_node)
4✔
3017

3018
        if not res_a[0]:
4✔
3019
            left_name = self.visit(left_node)
×
3020
            left_node = ast.Name(id=left_name)
×
3021
            res_a = self.la_handler.parse_arg(left_node)
×
3022

3023
        if not res_b[0]:
4✔
3024
            right_name = self.visit(right_node)
4✔
3025
            right_node = ast.Name(id=right_name)
4✔
3026
            res_b = self.la_handler.parse_arg(right_node)
4✔
3027

3028
        name_a, subset_a, shape_a, indices_a = res_a
4✔
3029
        name_b, subset_b, shape_b, indices_b = res_b
4✔
3030

3031
        if not name_a or not name_b:
4✔
3032
            raise NotImplementedError("Could not resolve matmul operands")
×
3033

3034
        real_shape_a = shape_a
4✔
3035
        real_shape_b = shape_b
4✔
3036

3037
        ndim_a = len(real_shape_a)
4✔
3038
        ndim_b = len(real_shape_b)
4✔
3039

3040
        output_shape = []
4✔
3041
        is_scalar = False
4✔
3042

3043
        if ndim_a == 1 and ndim_b == 1:
4✔
3044
            is_scalar = True
4✔
3045
            output_shape = []
4✔
3046
        elif ndim_a == 2 and ndim_b == 2:
4✔
3047
            output_shape = [real_shape_a[0], real_shape_b[1]]
4✔
3048
        elif ndim_a == 2 and ndim_b == 1:
4✔
3049
            output_shape = [real_shape_a[0]]
4✔
3050
        elif ndim_a == 1 and ndim_b == 2:
4✔
3051
            output_shape = [real_shape_b[1]]
×
3052
        elif ndim_a > 2 or ndim_b > 2:
4✔
3053
            if ndim_a == ndim_b:
4✔
3054
                output_shape = list(real_shape_a[:-2]) + [
4✔
3055
                    real_shape_a[-2],
3056
                    real_shape_b[-1],
3057
                ]
3058
            else:
3059
                raise NotImplementedError(
×
3060
                    "Broadcasting with different ranks not fully supported yet"
3061
                )
3062
        else:
3063
            raise NotImplementedError(
×
3064
                f"Matmul with ranks {ndim_a} and {ndim_b} not supported"
3065
            )
3066

3067
        dtype = Scalar(PrimitiveType.Double)
4✔
3068

3069
        if is_scalar:
4✔
3070
            tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
3071
            self.builder.add_container(tmp_name, dtype, False)
4✔
3072
            self.symbol_table[tmp_name] = dtype
4✔
3073
        else:
3074
            tmp_name = self._create_array_temp(output_shape, dtype)
4✔
3075

3076
        if ndim_a > 2 or ndim_b > 2:
4✔
3077
            # Generate loops for broadcasting
3078
            batch_dims = ndim_a - 2
4✔
3079
            loop_vars = []
4✔
3080

3081
            for i in range(batch_dims):
4✔
3082
                loop_var = f"_i{self._get_unique_id()}"
4✔
3083
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
3084
                loop_vars.append(loop_var)
4✔
3085
                dim_size = real_shape_a[i]
4✔
3086
                self.builder.begin_for(loop_var, "0", str(dim_size), "1")
4✔
3087

3088
            def make_slice(name, indices):
4✔
3089
                elts = []
4✔
3090
                for idx in indices:
4✔
3091
                    if idx == ":":
4✔
3092
                        elts.append(ast.Slice())
4✔
3093
                    else:
3094
                        elts.append(ast.Name(id=idx))
4✔
3095

3096
                return ast.Subscript(
4✔
3097
                    value=ast.Name(id=name), slice=ast.Tuple(elts=elts), ctx=ast.Load()
3098
                )
3099

3100
            indices = loop_vars + [":", ":"]
4✔
3101
            slice_a = make_slice(name_a, indices)
4✔
3102
            slice_b = make_slice(name_b, indices)
4✔
3103
            slice_c = make_slice(tmp_name, indices)
4✔
3104

3105
            self.la_handler.handle_gemm(
4✔
3106
                slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
3107
            )
3108

3109
            for _ in range(batch_dims):
4✔
3110
                self.builder.end_for()
4✔
3111
        else:
3112
            if is_scalar:
4✔
3113
                self.la_handler.handle_dot(
4✔
3114
                    tmp_name,
3115
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
3116
                )
3117
            else:
3118
                self.la_handler.handle_gemm(
4✔
3119
                    tmp_name,
3120
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
3121
                )
3122

3123
        return tmp_name
4✔
3124

3125
    def _handle_numpy_unary_op(self, node, func_name):
4✔
3126
        args = [self.visit(arg) for arg in node.args]
4✔
3127
        if len(args) != 1:
4✔
3128
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
3129

3130
        op_name = func_name
4✔
3131
        if op_name == "absolute":
4✔
3132
            op_name = "abs"
×
3133

3134
        return self._handle_array_unary_op(op_name, args[0])
4✔
3135

3136
    def _handle_numpy_reduce(self, node, func_name):
4✔
3137
        args = node.args
4✔
3138
        keywords = {kw.arg: kw.value for kw in node.keywords}
4✔
3139

3140
        array_node = args[0]
4✔
3141
        array_name = self.visit(array_node)
4✔
3142

3143
        if array_name not in self.array_info:
4✔
3144
            raise ValueError(f"Reduction input must be an array, got {array_name}")
×
3145

3146
        input_shape = self.array_info[array_name]["shapes"]
4✔
3147
        ndim = len(input_shape)
4✔
3148

3149
        axis = None
4✔
3150
        if len(args) > 1:
4✔
3151
            axis = args[1]
×
3152
        elif "axis" in keywords:
4✔
3153
            axis = keywords["axis"]
4✔
3154

3155
        keepdims = False
4✔
3156
        if "keepdims" in keywords:
4✔
3157
            keepdims_node = keywords["keepdims"]
4✔
3158
            if isinstance(keepdims_node, ast.Constant):
4✔
3159
                keepdims = bool(keepdims_node.value)
4✔
3160

3161
        axes = []
4✔
3162
        if axis is None:
4✔
3163
            axes = list(range(ndim))
4✔
3164
        elif isinstance(axis, ast.Constant):  # Single axis
4✔
3165
            val = axis.value
4✔
3166
            if val < 0:
4✔
3167
                val += ndim
×
3168
            axes = [val]
4✔
3169
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
3170
            for elt in axis.elts:
×
3171
                if isinstance(elt, ast.Constant):
×
3172
                    val = elt.value
×
3173
                    if val < 0:
×
3174
                        val += ndim
×
3175
                    axes.append(val)
×
3176
        elif (
×
3177
            isinstance(axis, ast.UnaryOp)
3178
            and isinstance(axis.op, ast.USub)
3179
            and isinstance(axis.operand, ast.Constant)
3180
        ):
3181
            val = -axis.operand.value
×
3182
            if val < 0:
×
3183
                val += ndim
×
3184
            axes = [val]
×
3185
        else:
3186
            # Try to evaluate simple expression
3187
            try:
×
3188
                val = int(self.visit(axis))
×
3189
                if val < 0:
×
3190
                    val += ndim
×
3191
                axes = [val]
×
3192
            except:
×
3193
                raise NotImplementedError("Dynamic axis not supported")
×
3194

3195
        # Calculate output shape
3196
        output_shape = []
4✔
3197
        for i in range(ndim):
4✔
3198
            if i in axes:
4✔
3199
                if keepdims:
4✔
3200
                    output_shape.append("1")
4✔
3201
            else:
3202
                output_shape.append(input_shape[i])
4✔
3203

3204
        dtype = self._get_dtype(array_name)
4✔
3205

3206
        if not output_shape:
4✔
3207
            tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
3208
            self.builder.add_container(tmp_name, dtype, False)
4✔
3209
            self.symbol_table[tmp_name] = dtype
4✔
3210
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
4✔
3211
        else:
3212
            tmp_name = self._create_array_temp(output_shape, dtype)
4✔
3213

3214
        self.builder.add_reduce_op(
4✔
3215
            func_name, array_name, tmp_name, input_shape, axes, keepdims
3216
        )
3217

3218
        return tmp_name
4✔
3219

3220
    def _handle_numpy_astype(self, node, array_name):
4✔
3221
        """Handle numpy array.astype(dtype) method calls."""
3222
        if len(node.args) < 1:
4✔
3223
            raise ValueError("astype requires at least one argument (dtype)")
×
3224

3225
        dtype_arg = node.args[0]
4✔
3226
        target_dtype = self._map_numpy_dtype(dtype_arg)
4✔
3227

3228
        # Get input array shape
3229
        if array_name not in self.array_info:
4✔
3230
            raise ValueError(f"Array {array_name} not found in array_info")
×
3231

3232
        input_shape = self.array_info[array_name]["shapes"]
4✔
3233

3234
        # Create output array with target dtype
3235
        tmp_name = self._create_array_temp(input_shape, target_dtype)
4✔
3236

3237
        # Add cast operation
3238
        self.builder.add_cast_op(
4✔
3239
            array_name, tmp_name, input_shape, target_dtype.primitive_type
3240
        )
3241

3242
        return tmp_name
4✔
3243

3244
    def _handle_numpy_copy(self, node, array_name):
        """Handle numpy array.copy() method calls using memcpy.

        Allocates a new array temporary with the same shape and element type
        as ``array_name`` and emits one block whose memcpy copies every byte.
        ``node`` is not inspected.

        Returns:
            Name of the new array temporary.

        Raises:
            ValueError: If ``array_name`` has no entry in ``array_info``.
        """
        if array_name not in self.array_info:
            raise ValueError(f"Array {array_name} not found in array_info")

        input_shape = self.array_info[array_name]["shapes"]

        # Element type comes from the container's pointer type. Double is the
        # fallback when the symbol is unknown or is not a typed pointer.
        element_type = Scalar(PrimitiveType.Double)  # Default
        if array_name in self.symbol_table:
            sym_type = self.symbol_table[array_name]
            if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                element_type = sym_type.pointee_type

        # Create output array with same dtype
        tmp_name = self._create_array_temp(input_shape, element_type)

        # Byte count = total_elements * sizeof(element_type). A 0-d array
        # (empty shape list) holds one element. Joining an empty list would
        # otherwise produce the malformed expression "() * (size)".
        if input_shape:
            total_elements = " * ".join(f"({s})" for s in input_shape)
        else:
            total_elements = "1"
        element_size = self.builder.get_sizeof(element_type)
        count_expr = f"({total_elements}) * ({element_size})"

        # Get pointer type for memlets
        ptr_type = Pointer(element_type)

        # Add memcpy operation
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, array_name)
        t_dst = self.builder.add_access(block, tmp_name)
        t_memcpy = self.builder.add_memcpy(block, count_expr)

        # Connect source and destination
        self.builder.add_memlet(block, t_src, "void", t_memcpy, "_src", "", ptr_type)
        self.builder.add_memlet(block, t_memcpy, "_dst", t_dst, "void", "", ptr_type)

        return tmp_name
3281

3282
    def _handle_scipy_softmax(self, node, func_name):
4✔
3283
        args = node.args
4✔
3284
        keywords = {kw.arg: kw.value for kw in node.keywords}
4✔
3285

3286
        array_node = args[0]
4✔
3287
        array_name = self.visit(array_node)
4✔
3288

3289
        if array_name not in self.array_info:
4✔
3290
            raise ValueError(f"Softmax input must be an array, got {array_name}")
×
3291

3292
        input_shape = self.array_info[array_name]["shapes"]
4✔
3293
        ndim = len(input_shape)
4✔
3294

3295
        axis = None
4✔
3296
        if len(args) > 1:
4✔
3297
            axis = args[1]
×
3298
        elif "axis" in keywords:
4✔
3299
            axis = keywords["axis"]
4✔
3300

3301
        axes = []
4✔
3302
        if axis is None:
4✔
3303
            axes = list(range(ndim))
4✔
3304
        elif isinstance(axis, ast.Constant):  # Single axis
4✔
3305
            val = axis.value
4✔
3306
            if val < 0:
4✔
3307
                val += ndim
×
3308
            axes = [val]
4✔
3309
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
3310
            for elt in axis.elts:
×
3311
                if isinstance(elt, ast.Constant):
×
3312
                    val = elt.value
×
3313
                    if val < 0:
×
3314
                        val += ndim
×
3315
                    axes.append(val)
×
3316
        elif (
×
3317
            isinstance(axis, ast.UnaryOp)
3318
            and isinstance(axis.op, ast.USub)
3319
            and isinstance(axis.operand, ast.Constant)
3320
        ):
3321
            val = -axis.operand.value
×
3322
            if val < 0:
×
3323
                val += ndim
×
3324
            axes = [val]
×
3325
        else:
3326
            # Try to evaluate simple expression
3327
            try:
×
3328
                val = int(self.visit(axis))
×
3329
                if val < 0:
×
3330
                    val += ndim
×
3331
                axes = [val]
×
3332
            except:
×
3333
                raise NotImplementedError("Dynamic axis not supported")
×
3334

3335
        # Create output array
3336
        # Assume double for now, or infer from input
3337
        dtype = Scalar(PrimitiveType.Double)  # TODO: infer
4✔
3338

3339
        tmp_name = self._create_array_temp(input_shape, dtype)
4✔
3340

3341
        self.builder.add_reduce_op(
4✔
3342
            func_name, array_name, tmp_name, input_shape, axes, False
3343
        )
3344

3345
        return tmp_name
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc