• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 21732822216

05 Feb 2026 11:41PM UTC coverage: 66.429% (-0.06%) from 66.484%
21732822216

Pull #507

github

web-flow
Merge c1db8d7e7 into cb356f32d
Pull Request #507: Adds MLP npbench benchmark

95 of 173 new or added lines in 3 files covered. (54.91%)

1 existing line in 1 file now uncovered.

23163 of 34869 relevant lines covered (66.43%)

374.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.55
/python/docc/python/expression_visitor.py
1
import ast
4✔
2
import inspect
4✔
3
import textwrap
4✔
4
from docc.sdfg import (
4✔
5
    Scalar,
6
    PrimitiveType,
7
    Pointer,
8
    Type,
9
    DebugInfo,
10
    Structure,
11
    TaskletCode,
12
    CMathFunction,
13
)
14

15

16
class ExpressionVisitor(ast.NodeVisitor):
4✔
17
    def __init__(
        self,
        array_info=None,
        builder=None,
        symbol_table=None,
        globals_dict=None,
        inliner=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Store the shared translation state and register the NumPy handlers.

        Container arguments left as ``None`` are replaced by a fresh empty
        container. Containers that are passed in are stored as-is (not
        copied), so an empty dict supplied by the caller stays shared.

        Args:
            array_info: Mapping from array name to a metadata dict (entries
                are read with the keys ``"ndim"`` and ``"shapes"``).
            builder: Object used to emit containers, blocks, tasklets and
                memlets. May be ``None``.
            symbol_table: Mapping from variable name to its type object.
            globals_dict: Mapping from global name to the Python object it
                is bound to.
            inliner: Stored on the instance unchanged.
            unique_counter_ref: One-element list used as a shared, mutable
                counter. Defaults to ``[0]``.
            structure_member_info: Stored on the instance; defaults to ``{}``.
        """
        # Explicit ``is None`` tests (rather than ``or``) keep caller-supplied
        # empty containers, which may be shared with other components.
        if array_info is None:
            array_info = {}
        if symbol_table is None:
            symbol_table = {}
        if globals_dict is None:
            globals_dict = {}
        if unique_counter_ref is None:
            unique_counter_ref = [0]
        if structure_member_info is None:
            structure_member_info = {}

        self.array_info = array_info
        self.builder = builder
        self.symbol_table = symbol_table
        self.globals_dict = globals_dict
        self.inliner = inliner
        self._unique_counter_ref = unique_counter_ref
        # Maps (block, expression string) -> (access node, subset string).
        self._access_cache = {}
        self.la_handler = None
        self.structure_member_info = structure_member_info
        self._init_numpy_handlers()
4✔
41

42
    def _get_unique_id(self):
4✔
43
        self._unique_counter_ref[0] += 1
4✔
44
        return self._unique_counter_ref[0]
4✔
45

46
    def _get_temp_name(self, prefix="_tmp_"):
4✔
47
        if hasattr(self.builder, "find_new_name"):
4✔
48
            return self.builder.find_new_name(prefix)
×
49
        return f"{prefix}{self._get_unique_id()}"
4✔
50

51
    def _is_indirect_access(self, node):
4✔
52
        """Check if a node represents an indirect array access (e.g., A[B[i]]).
53

54
        Returns True if the node is a subscript where the index itself is a subscript
55
        into an array (indirect access pattern).
56
        """
57
        if not isinstance(node, ast.Subscript):
×
58
            return False
×
59
        # Check if value is a subscripted array access
60
        if isinstance(node.value, ast.Name):
×
61
            arr_name = node.value.id
×
62
            if arr_name in self.array_info:
×
63
                # Check if slice/index is itself an array access
64
                if isinstance(node.slice, ast.Subscript):
×
65
                    if isinstance(node.slice.value, ast.Name):
×
66
                        idx_arr_name = node.slice.value.id
×
67
                        if idx_arr_name in self.array_info:
×
68
                            return True
×
69
        return False
×
70

71
    def _contains_indirect_access(self, node):
4✔
72
        """Check if an AST node contains any indirect array access.
73

74
        Used to detect expressions like A_row[i] that would be used as slice bounds.
75
        """
76
        if isinstance(node, ast.Subscript):
4✔
77
            if isinstance(node.value, ast.Name):
4✔
78
                arr_name = node.value.id
4✔
79
                if arr_name in self.array_info:
4✔
80
                    return True
4✔
81
        elif isinstance(node, ast.BinOp):
4✔
82
            return self._contains_indirect_access(
4✔
83
                node.left
84
            ) or self._contains_indirect_access(node.right)
85
        elif isinstance(node, ast.UnaryOp):
4✔
86
            return self._contains_indirect_access(node.operand)
4✔
87
        return False
4✔
88

89
    def _materialize_indirect_access(
4✔
90
        self, node, debug_info=None, return_original_expr=False
91
    ):
92
        """Materialize an array access into a scalar variable using tasklet+memlets.
93

94
        For indirect memory access patterns in SDFGs, we need to:
95
        1. Create a scalar container for the result
96
        2. Create a tasklet that performs the assignment
97
        3. Use memlets to read from the array and write to the scalar
98
        4. Return the scalar name (which can be used as a symbolic expression)
99

100
        This is the canonical SDFG pattern for indirect access.
101

102
        If return_original_expr is True, also returns the original array access
103
        expression using parentheses notation (e.g., "A_row(0)") which is consistent
104
        with SDFG subset notation. The runtime evaluator will convert this to
105
        bracket notation for Python evaluation.
106
        """
107
        if not self.builder:
4✔
108
            # Without builder, just return the expression string
109
            expr = self.visit(node)
×
110
            return (expr, expr) if return_original_expr else expr
×
111

112
        if debug_info is None:
4✔
113
            debug_info = DebugInfo()
4✔
114

115
        if not isinstance(node, ast.Subscript):
4✔
116
            expr = self.visit(node)
×
117
            return (expr, expr) if return_original_expr else expr
×
118

119
        if not isinstance(node.value, ast.Name):
4✔
120
            expr = self.visit(node)
×
121
            return (expr, expr) if return_original_expr else expr
×
122

123
        arr_name = node.value.id
4✔
124
        if arr_name not in self.array_info:
4✔
125
            expr = self.visit(node)
×
126
            return (expr, expr) if return_original_expr else expr
×
127

128
        # Determine the element type
129
        dtype = Scalar(PrimitiveType.Int64)  # Default for indices
4✔
130
        if arr_name in self.symbol_table:
4✔
131
            t = self.symbol_table[arr_name]
4✔
132
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
133
                dtype = t.pointee_type
4✔
134

135
        # Create scalar container for the result
136
        tmp_name = self._get_temp_name("_idx_")
4✔
137
        self.builder.add_container(tmp_name, dtype, False)
4✔
138
        self.symbol_table[tmp_name] = dtype
4✔
139

140
        # Get the index expression
141
        ndim = self.array_info[arr_name]["ndim"]
4✔
142
        shapes = self.array_info[arr_name].get("shapes", [])
4✔
143

144
        # Compute linear index from the subscript
145
        if isinstance(node.slice, ast.Tuple):
4✔
146
            indices = [self.visit(elt) for elt in node.slice.elts]
×
147
        else:
148
            indices = [self.visit(node.slice)]
4✔
149

150
        # Handle cases where we need recursive materialization
151
        materialized_indices = []
4✔
152
        for i, idx_str in enumerate(indices):
4✔
153
            # Check if the index itself needs materialization (nested indirect)
154
            # This happens when idx_str looks like an array access e.g., "arr(i)"
155
            if "(" in idx_str and idx_str.endswith(")"):
4✔
156
                # This is an array access, it should already be a valid symbolic expression
157
                # or a scalar variable name
158
                materialized_indices.append(idx_str)
×
159
            else:
160
                materialized_indices.append(idx_str)
4✔
161

162
        # Compute linear index
163
        linear_index = self._compute_linear_index(
4✔
164
            materialized_indices, shapes, arr_name, ndim
165
        )
166

167
        # Create block with tasklet and memlets
168
        block = self.builder.add_block(debug_info)
4✔
169
        t_src = self.builder.add_access(block, arr_name, debug_info)
4✔
170
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
171
        t_task = self.builder.add_tasklet(
4✔
172
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
173
        )
174

175
        self.builder.add_memlet(
4✔
176
            block, t_src, "void", t_task, "_in", linear_index, None, debug_info
177
        )
178
        self.builder.add_memlet(
4✔
179
            block, t_task, "_out", t_dst, "void", "", None, debug_info
180
        )
181

182
        if return_original_expr:
4✔
183
            # Return both the materialized variable name and the original array access expression
184
            # Use parentheses notation which is consistent with SDFG subset syntax
185
            original_expr = f"{arr_name}({linear_index})"
4✔
186
            return (tmp_name, original_expr)
4✔
187

188
        return tmp_name
×
189

190
    def _init_numpy_handlers(self):
        """Build ``self.numpy_handlers``: NumPy function name -> handler method.

        ``visit_Call`` looks the attribute name of a ``numpy.*`` / ``np.*``
        call up in this table and invokes the handler as
        ``handler(node, func_name)``.
        """
        # Handlers shared by several NumPy functions.
        alloc = self._handle_numpy_alloc
        binary_op = self._handle_numpy_binary_op
        unary_op = self._handle_numpy_unary_op
        reduce_op = self._handle_numpy_reduce
        matmul = self._handle_numpy_matmul

        self.numpy_handlers = {
            "empty": alloc,
            "empty_like": self._handle_numpy_empty_like,
            "zeros": alloc,
            "zeros_like": self._handle_numpy_zeros_like,
            "ones": alloc,
            "ndarray": alloc,  # np.ndarray() constructor
            "eye": self._handle_numpy_eye,
            "add": binary_op,
            "subtract": binary_op,
            "multiply": binary_op,
            "divide": binary_op,
            "power": binary_op,
            "exp": unary_op,
            "abs": unary_op,
            "absolute": unary_op,
            "sqrt": unary_op,
            "tanh": unary_op,
            "sum": reduce_op,
            "max": reduce_op,
            "min": reduce_op,
            "mean": reduce_op,
            "std": reduce_op,
            "matmul": matmul,
            "dot": matmul,
            "matvec": matmul,
            "outer": self._handle_numpy_outer,
            "minimum": binary_op,
            "maximum": binary_op,
            "where": self._handle_numpy_where,
            "clip": self._handle_numpy_clip,
        }
223

224
    def generic_visit(self, node):
4✔
225
        return super().generic_visit(node)
×
226

227
    def visit_Constant(self, node):
        """Render a literal as the string used in generated expressions.

        Booleans become ``"true"`` / ``"false"``; every other constant is
        rendered with ``str()``.
        """
        value = node.value
        # bool must be tested explicitly so it is not rendered as True/False.
        if not isinstance(value, bool):
            return str(value)
        if value:
            return "true"
        return "false"
4✔
231

232
    def visit_Name(self, node):
4✔
233
        name = node.id
4✔
234
        # Check if it's a global constant (not a local variable/array)
235
        if name not in self.symbol_table and self.globals_dict is not None:
4✔
236
            if name in self.globals_dict:
4✔
237
                val = self.globals_dict[name]
4✔
238
                # Only substitute simple numeric constants
239
                if isinstance(val, (int, float)):
4✔
240
                    return str(val)
4✔
241
        return name
4✔
242

243
    def _map_numpy_dtype(self, dtype_node):
        """Translate a ``dtype=...`` AST node into an SDFG element type.

        Recognized forms:
        * the bare names ``float``, ``int`` and ``bool``;
        * ``<array>.dtype`` where ``<array>`` is a pointer-typed symbol with
          a known pointee type (that pointee type is returned);
        * ``numpy.<name>`` / ``np.<name>`` for ``float64``, ``float32``,
          ``int64``, ``int32`` and ``bool_``.

        Anything else, including ``None``, maps to a double scalar.
        """
        if isinstance(dtype_node, ast.Name):
            builtin_types = {
                "float": PrimitiveType.Double,
                "int": PrimitiveType.Int64,
                "bool": PrimitiveType.Bool,
            }
            if dtype_node.id in builtin_types:
                return Scalar(builtin_types[dtype_node.id])

        elif isinstance(dtype_node, ast.Attribute):
            owner = dtype_node.value
            if isinstance(owner, ast.Name):
                # Handle array.dtype: reuse the element type of a known array.
                if dtype_node.attr == "dtype" and owner.id in self.symbol_table:
                    sym_type = self.symbol_table[owner.id]
                    if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                        return sym_type.pointee_type

                if owner.id in ("numpy", "np"):
                    numpy_types = {
                        "float64": PrimitiveType.Double,
                        "float32": PrimitiveType.Float,
                        "int64": PrimitiveType.Int64,
                        "int32": PrimitiveType.Int32,
                        "bool_": PrimitiveType.Bool,
                    }
                    if dtype_node.attr in numpy_types:
                        return Scalar(numpy_types[dtype_node.attr])

        # Fallback (also covers dtype_node being None): default to double
        return Scalar(PrimitiveType.Double)
×
284

285
    def _is_int(self, operand):
4✔
286
        try:
4✔
287
            if operand.lstrip("-").isdigit():
4✔
288
                return True
4✔
289
        except ValueError:
×
290
            pass
×
291

292
        name = operand
4✔
293
        if "(" in operand and operand.endswith(")"):
4✔
294
            name = operand.split("(")[0]
4✔
295

296
        if name in self.symbol_table:
4✔
297
            t = self.symbol_table[name]
4✔
298

299
            def is_int_ptype(pt):
4✔
300
                return pt in [
4✔
301
                    PrimitiveType.Int64,
302
                    PrimitiveType.Int32,
303
                    PrimitiveType.Int8,
304
                    PrimitiveType.Int16,
305
                    PrimitiveType.UInt64,
306
                    PrimitiveType.UInt32,
307
                    PrimitiveType.UInt8,
308
                    PrimitiveType.UInt16,
309
                ]
310

311
            if isinstance(t, Scalar):
4✔
312
                return is_int_ptype(t.primitive_type)
4✔
313

314
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
315
                et = t.element_type
×
316
                if callable(et):
×
317
                    et = et()
×
318
                if isinstance(et, Scalar):
×
319
                    return is_int_ptype(et.primitive_type)
×
320

321
            if type(t).__name__ == "Pointer":
4✔
322
                if hasattr(t, "pointee_type"):
4✔
323
                    et = t.pointee_type
4✔
324
                    if callable(et):
4✔
325
                        et = et()
×
326
                    if isinstance(et, Scalar):
4✔
327
                        return is_int_ptype(et.primitive_type)
4✔
328
                # Fallback: check if it has element_type (maybe alias?)
329
                if hasattr(t, "element_type"):
×
330
                    et = t.element_type
×
331
                    if callable(et):
×
332
                        et = et()
×
333
                    if isinstance(et, Scalar):
×
334
                        return is_int_ptype(et.primitive_type)
×
335

336
        return False
4✔
337

338
    def _add_read(self, block, expr_str, debug_info=None):
4✔
339
        # Try to reuse access node
340
        try:
4✔
341
            if (block, expr_str) in self._access_cache:
4✔
342
                return self._access_cache[(block, expr_str)]
4✔
343
        except TypeError:
×
344
            # block might not be hashable
345
            pass
×
346

347
        if debug_info is None:
4✔
348
            debug_info = DebugInfo()
4✔
349

350
        if "(" in expr_str and expr_str.endswith(")"):
4✔
351
            name = expr_str.split("(")[0]
4✔
352
            subset = expr_str[expr_str.find("(") + 1 : -1]
4✔
353
            access = self.builder.add_access(block, name, debug_info)
4✔
354
            try:
4✔
355
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
356
            except TypeError:
×
357
                pass
×
358
            return access, subset
4✔
359

360
        if self.builder.exists(expr_str):
4✔
361
            access = self.builder.add_access(block, expr_str, debug_info)
4✔
362
            # For pointer types representing 0-D arrays, dereference with "0"
363
            subset = ""
4✔
364
            if expr_str in self.symbol_table:
4✔
365
                sym_type = self.symbol_table[expr_str]
4✔
366
                if isinstance(sym_type, Pointer):
4✔
367
                    # Check if it's a 0-D array (scalar wrapped in pointer)
368
                    if expr_str in self.array_info:
×
369
                        ndim = self.array_info[expr_str].get("ndim", 0)
×
370
                        if ndim == 0:
×
371
                            subset = "0"
×
372
                    else:
373
                        # Pointer without array_info is treated as 0-D
374
                        subset = "0"
×
375
            try:
4✔
376
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
377
            except TypeError:
×
378
                pass
×
379
            return access, subset
4✔
380

381
        dtype = Scalar(PrimitiveType.Double)
4✔
382
        if self._is_int(expr_str):
4✔
383
            dtype = Scalar(PrimitiveType.Int64)
4✔
384
        elif expr_str == "true" or expr_str == "false":
4✔
385
            dtype = Scalar(PrimitiveType.Bool)
×
386

387
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
4✔
388
        try:
4✔
389
            self._access_cache[(block, expr_str)] = (const_node, "")
4✔
390
        except TypeError:
×
391
            pass
×
392
        return const_node, ""
4✔
393

394
    def _handle_min_max(self, node, func_name):
4✔
395
        args = [self.visit(arg) for arg in node.args]
4✔
396
        if len(args) != 2:
4✔
397
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")
×
398

399
        # Check types
400
        is_float = False
4✔
401
        arg_types = []
4✔
402

403
        for arg in args:
4✔
404
            name = arg
4✔
405
            if "(" in arg and arg.endswith(")"):
4✔
406
                name = arg.split("(")[0]
×
407

408
            if name in self.symbol_table:
4✔
409
                t = self.symbol_table[name]
4✔
410
                if isinstance(t, Pointer):
4✔
411
                    t = t.base_type
×
412

413
                if t.primitive_type == PrimitiveType.Double:
4✔
414
                    is_float = True
4✔
415
                    arg_types.append(PrimitiveType.Double)
4✔
416
                else:
417
                    arg_types.append(PrimitiveType.Int64)
4✔
418
            elif self._is_int(arg):
×
419
                arg_types.append(PrimitiveType.Int64)
×
420
            else:
421
                # Assume float constant
422
                is_float = True
×
423
                arg_types.append(PrimitiveType.Double)
×
424

425
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)
4✔
426

427
        tmp_name = self._get_temp_name("_tmp_")
4✔
428
        self.builder.add_container(tmp_name, dtype, False)
4✔
429
        self.symbol_table[tmp_name] = dtype
4✔
430

431
        if is_float:
4✔
432
            # Cast args if necessary
433
            casted_args = []
4✔
434
            for i, arg in enumerate(args):
4✔
435
                if arg_types[i] != PrimitiveType.Double:
4✔
436
                    # Create temp double
437
                    tmp_cast = self._get_temp_name("_cast_")
4✔
438
                    self.builder.add_container(
4✔
439
                        tmp_cast, Scalar(PrimitiveType.Double), False
440
                    )
441
                    self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
4✔
442

443
                    # Assign int to double (implicit cast)
444
                    self.builder.add_assignment(tmp_cast, arg)
4✔
445
                    casted_args.append(tmp_cast)
4✔
446
                else:
447
                    casted_args.append(arg)
4✔
448

449
            block = self.builder.add_block()
4✔
450
            t_out = self.builder.add_access(block, tmp_name)
4✔
451

452
            intrinsic_name = (
4✔
453
                CMathFunction.fmax if func_name == "max" else CMathFunction.fmin
454
            )
455
            t_task = self.builder.add_cmath(block, intrinsic_name)
4✔
456

457
            for i, arg in enumerate(casted_args):
4✔
458
                t_arg, arg_sub = self._add_read(block, arg)
4✔
459
                self.builder.add_memlet(
4✔
460
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
461
                )
462
        else:
463
            block = self.builder.add_block()
4✔
464
            t_out = self.builder.add_access(block, tmp_name)
4✔
465

466
            # Use int_smax/int_smin tasklet
467
            opcode = None
4✔
468
            if func_name == "max":
4✔
469
                opcode = TaskletCode.int_smax
4✔
470
            else:
471
                opcode = TaskletCode.int_smin
4✔
472
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])
4✔
473

474
            for i, arg in enumerate(args):
4✔
475
                t_arg, arg_sub = self._add_read(block, arg)
4✔
476
                self.builder.add_memlet(
4✔
477
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
478
                )
479

480
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
4✔
481
        return tmp_name
4✔
482

483
    def _handle_python_cast(self, node, func_name):
4✔
484
        """Handle Python type casts: int(), float(), bool()"""
485
        if len(node.args) != 1:
4✔
486
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")
×
487

488
        arg = self.visit(node.args[0])
4✔
489

490
        # Determine target type based on cast function
491
        if func_name == "int":
4✔
492
            target_dtype = Scalar(PrimitiveType.Int64)
4✔
493
        elif func_name == "float":
4✔
494
            target_dtype = Scalar(PrimitiveType.Double)
4✔
495
        elif func_name == "bool":
4✔
496
            target_dtype = Scalar(PrimitiveType.Bool)
4✔
497
        else:
498
            raise NotImplementedError(f"Cast to {func_name} not supported")
×
499

500
        # Determine source type
501
        source_dtype = None
4✔
502
        name = arg
4✔
503
        if "(" in arg and arg.endswith(")"):
4✔
504
            name = arg.split("(")[0]
×
505

506
        if name in self.symbol_table:
4✔
507
            source_dtype = self.symbol_table[name]
4✔
508
            if isinstance(source_dtype, Pointer):
4✔
509
                source_dtype = source_dtype.base_type
×
510
        elif self._is_int(arg):
×
511
            source_dtype = Scalar(PrimitiveType.Int64)
×
512
        elif arg == "true" or arg == "false":
×
513
            source_dtype = Scalar(PrimitiveType.Bool)
×
514
        else:
515
            # Assume float constant
516
            source_dtype = Scalar(PrimitiveType.Double)
×
517

518
        # Create temporary variable for result
519
        tmp_name = self._get_temp_name("_tmp_")
4✔
520
        self.builder.add_container(tmp_name, target_dtype, False)
4✔
521
        self.symbol_table[tmp_name] = target_dtype
4✔
522

523
        # Use tasklet assign opcode for casting (as specified in problem statement)
524
        block = self.builder.add_block()
4✔
525
        t_src, src_sub = self._add_read(block, arg)
4✔
526
        t_dst = self.builder.add_access(block, tmp_name)
4✔
527
        t_task = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
4✔
528
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
4✔
529
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
4✔
530

531
        return tmp_name
4✔
532

533
    def visit_Call(self, node):
        """Translate a call expression and return the name of its result.

        Dispatch order:
        1. ``<array>.astype(...)`` / ``<array>.copy()`` on a known array,
           ``scipy.special.softmax(...)`` and ``np.<ufunc>.outer(...)`` are
           routed directly to their dedicated handlers.
        2. ``numpy.<name>`` / ``np.<name>`` calls registered in
           ``self.numpy_handlers``.
        3. ``max`` / ``min``.
        4. The Python casts ``int`` / ``float`` / ``bool``.
        5. Functions in the local math table; a double temporary is created
           and filled by a math node.
        6. Plain Python functions found in ``self.globals_dict``, which are
           inlined.

        Raises:
            NotImplementedError: If no rule above matches the call.
        """
        func_name = ""
        module_name = ""
        if isinstance(node.func, ast.Attribute):
            if isinstance(node.func.value, ast.Name):
                # <name>.<attr>(...): a module function or an array method.
                if node.func.value.id == "math":
                    module_name = "math"
                    func_name = node.func.attr
                elif node.func.value.id in ["numpy", "np"]:
                    module_name = "numpy"
                    func_name = node.func.attr
                else:
                    # Check if it's a method call on an array (e.g., arr.astype(...), arr.copy())
                    array_name = node.func.value.id
                    method_name = node.func.attr
                    if array_name in self.array_info and method_name == "astype":
                        return self._handle_numpy_astype(node, array_name)
                    elif array_name in self.array_info and method_name == "copy":
                        return self._handle_numpy_copy(node, array_name)
                    # Any other method leaves func_name empty, so the call
                    # ends at the NotImplementedError below unless "" is a
                    # key of globals_dict.
            elif isinstance(node.func.value, ast.Attribute):
                # <a>.<b>.<attr>(...): two-level attribute access.
                if (
                    isinstance(node.func.value.value, ast.Name)
                    and node.func.value.value.id == "scipy"
                    and node.func.value.attr == "special"
                ):
                    if node.func.attr == "softmax":
                        return self._handle_scipy_softmax(node, "softmax")
                # Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.
                elif (
                    isinstance(node.func.value.value, ast.Name)
                    and node.func.value.value.id in ["numpy", "np"]
                    and node.func.attr == "outer"
                ):
                    ufunc_name = node.func.value.attr  # "add", "subtract", etc.
                    return self._handle_ufunc_outer(node, ufunc_name)

        elif isinstance(node.func, ast.Name):
            # Bare function name, e.g. max(...), sqrt(...) or a user function.
            func_name = node.func.id

        if module_name == "numpy":
            if func_name in self.numpy_handlers:
                return self.numpy_handlers[func_name](node, func_name)
            # NumPy functions without a handler fall through to the checks
            # below, which look at func_name only (not at the module).

        if func_name in ["max", "min"]:
            return self._handle_min_max(node, func_name)

        # Handle Python type casts (int, float, bool)
        if func_name in ["int", "float", "bool"]:
            return self._handle_python_cast(node, func_name)

        # Function name -> math intrinsic. The lookup ignores module_name, so
        # math.<name>, a bare <name> and an unhandled np.<name> all match.
        math_funcs = {
            # Trigonometric functions
            "sin": CMathFunction.sin,
            "cos": CMathFunction.cos,
            "tan": CMathFunction.tan,
            "asin": CMathFunction.asin,
            "acos": CMathFunction.acos,
            "atan": CMathFunction.atan,
            "atan2": CMathFunction.atan2,
            # Hyperbolic functions
            "sinh": CMathFunction.sinh,
            "cosh": CMathFunction.cosh,
            "tanh": CMathFunction.tanh,
            "asinh": CMathFunction.asinh,
            "acosh": CMathFunction.acosh,
            "atanh": CMathFunction.atanh,
            # Exponential and logarithmic functions
            "exp": CMathFunction.exp,
            "exp2": CMathFunction.exp2,
            "expm1": CMathFunction.expm1,
            "log": CMathFunction.log,
            "log2": CMathFunction.log2,
            "log10": CMathFunction.log10,
            "log1p": CMathFunction.log1p,
            # Power functions
            "pow": CMathFunction.pow,
            "sqrt": CMathFunction.sqrt,
            "cbrt": CMathFunction.cbrt,
            "hypot": CMathFunction.hypot,
            # Rounding and remainder functions
            "abs": CMathFunction.fabs,
            "fabs": CMathFunction.fabs,
            "ceil": CMathFunction.ceil,
            "floor": CMathFunction.floor,
            "trunc": CMathFunction.trunc,
            "fmod": CMathFunction.fmod,
            "remainder": CMathFunction.remainder,
            # Floating-point manipulation functions
            "copysign": CMathFunction.copysign,
            # Other functions
            "fma": CMathFunction.fma,
        }

        if func_name in math_funcs:
            args = [self.visit(arg) for arg in node.args]

            # The result is always stored in a double temporary.
            tmp_name = self._get_temp_name("_tmp_")
            dtype = Scalar(PrimitiveType.Double)
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype

            block = self.builder.add_block()
            t_out = self.builder.add_access(block, tmp_name)

            t_task = self.builder.add_cmath(block, math_funcs[func_name])

            # Wire each argument to the input connectors _in1, _in2, ...
            for i, arg in enumerate(args):
                t_arg, arg_sub = self._add_read(block, arg)
                self.builder.add_memlet(
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
                )

            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
            return tmp_name

        # User-defined Python functions from the globals are inlined.
        if func_name in self.globals_dict:
            obj = self.globals_dict[func_name]
            if inspect.isfunction(obj):
                return self._handle_inline_call(node, obj)

        raise NotImplementedError(f"Function call {func_name} not supported")
×
654

655
    def _handle_inline_call(self, node, func_obj):
4✔
656
        # 1. Parse function source
657
        try:
4✔
658
            source_lines, start_line = inspect.getsourcelines(func_obj)
4✔
659
            source = textwrap.dedent("".join(source_lines))
4✔
660
            tree = ast.parse(source)
4✔
661
            func_def = tree.body[0]
4✔
662
        except Exception as e:
×
663
            raise NotImplementedError(
×
664
                f"Could not parse function {func_obj.__name__}: {e}"
665
            )
666

667
        # 2. Evaluate arguments
668
        arg_vars = [self.visit(arg) for arg in node.args]
4✔
669

670
        if len(arg_vars) != len(func_def.args.args):
4✔
671
            raise NotImplementedError(
×
672
                f"Argument count mismatch for {func_obj.__name__}"
673
            )
674

675
        # 3. Generate unique suffix
676
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
4✔
677
        res_name = f"_res{suffix}"
4✔
678

679
        # 4. Rename variables
680
        class VariableRenamer(ast.NodeTransformer):
4✔
681
            # Builtins that should not be renamed
682
            BUILTINS = {
4✔
683
                "range",
684
                "len",
685
                "int",
686
                "float",
687
                "bool",
688
                "str",
689
                "list",
690
                "dict",
691
                "tuple",
692
                "set",
693
                "print",
694
                "abs",
695
                "min",
696
                "max",
697
                "sum",
698
                "enumerate",
699
                "zip",
700
                "map",
701
                "filter",
702
                "sorted",
703
                "reversed",
704
                "True",
705
                "False",
706
                "None",
707
            }
708

709
            def __init__(self, suffix, globals_dict):
4✔
710
                self.suffix = suffix
4✔
711
                self.globals_dict = globals_dict
4✔
712

713
            def visit_Name(self, node):
4✔
714
                # Don't rename builtins or globals
715
                if node.id in self.globals_dict or node.id in self.BUILTINS:
4✔
716
                    return node
4✔
717
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
4✔
718

719
            def visit_Return(self, node):
4✔
720
                if node.value:
4✔
721
                    val = self.visit(node.value)
4✔
722
                    return ast.Assign(
4✔
723
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
724
                        value=val,
725
                    )
726
                return node
×
727

728
        renamer = VariableRenamer(suffix, self.globals_dict)
4✔
729
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
4✔
730

731
        # 5. Assign arguments to parameters
732
        param_assignments = []
4✔
733
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
4✔
734
            param_name = f"{arg_def.arg}{suffix}"
4✔
735

736
            # Infer type and create container
737
            if arg_val in self.symbol_table:
4✔
738
                self.symbol_table[param_name] = self.symbol_table[arg_val]
4✔
739
                self.builder.add_container(
4✔
740
                    param_name, self.symbol_table[arg_val], False
741
                )
742
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
4✔
743
            elif self._is_int(arg_val):
×
744
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
745
                self.builder.add_container(
×
746
                    param_name, Scalar(PrimitiveType.Int64), False
747
                )
748
                val_node = ast.Constant(value=int(arg_val))
×
749
            else:
750
                # Assume float constant
751
                try:
×
752
                    val = float(arg_val)
×
753
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
754
                    self.builder.add_container(
×
755
                        param_name, Scalar(PrimitiveType.Double), False
756
                    )
757
                    val_node = ast.Constant(value=val)
×
758
                except ValueError:
×
759
                    # Fallback to Name, might fail later if not in symbol table
760
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
761

762
            assign = ast.Assign(
4✔
763
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
764
            )
765
            param_assignments.append(assign)
4✔
766

767
        final_body = param_assignments + new_body
4✔
768

769
        # 6. Visit new body using ASTParser
770
        from .ast_parser import ASTParser
4✔
771

772
        parser = ASTParser(
4✔
773
            self.builder,
774
            self.array_info,
775
            self.symbol_table,
776
            globals_dict=self.globals_dict,
777
            unique_counter_ref=self._unique_counter_ref,
778
        )
779

780
        for stmt in final_body:
4✔
781
            parser.visit(stmt)
4✔
782

783
        return res_name
4✔
784

785
    def visit_BinOp(self, node):
        """Lower a binary operation and return the name of the result container.

        ``@`` is forwarded to the matmul handler, and any operation with an
        array operand to the elementwise array handler. Scalar operations
        write into a fresh ``_tmp_<id>`` container. That container is Int64
        when both operands are integers and the operator is neither ``/`` nor
        ``**``; otherwise it is Double, and integer operands are first copied
        into Double temporaries.

        Raises:
            NotImplementedError: for an array operator without an elementwise
                mapping, or a scalar operator that has no tasklet for the
                chosen result type.
        """
        if isinstance(node.op, ast.MatMult):
            return self._handle_numpy_matmul_op(node.left, node.right)

        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)

        # Array operands are handled elementwise by a dedicated helper.
        left_is_array = left in self.array_info
        right_is_array = right in self.array_info
        if left_is_array or right_is_array:
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
            if op not in op_map:
                raise NotImplementedError(f"Array operation {op} not supported")
            return self._handle_array_binary_op(op_map[op], left, right)

        tmp_name = f"_tmp_{self._get_unique_id()}"

        left_is_int = self._is_int(left)
        right_is_int = self._is_int(right)

        # True division and power always produce a Double result.
        use_int = bool(left_is_int and right_is_int and op not in ("/", "**"))
        if use_int:
            dtype = Scalar(PrimitiveType.Int64)
        else:
            dtype = Scalar(PrimitiveType.Double)

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        # For a Double result, integer operands go through a Double temporary.
        real_left = left
        real_right = right
        if not use_int:
            if left_is_int:
                real_left = self._binop_cast_to_double(left)
            if right_is_int:
                real_right = self._binop_cast_to_double(right)

        if op == "**":
            block = self.builder.add_block()
            t_left, left_sub = self._add_read(block, real_left)
            t_right, right_sub = self._add_read(block, real_right)
            t_out = self.builder.add_access(block, tmp_name)

            t_task = self.builder.add_cmath(block, CMathFunction.pow)
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
            return tmp_name

        if op == "%":
            block = self.builder.add_block()
            t_left, left_sub = self._add_read(block, real_left)
            t_right, right_sub = self._add_read(block, real_right)
            t_out = self.builder.add_access(block, tmp_name)

            if use_int:
                rem_code, add_code = TaskletCode.int_srem, TaskletCode.int_add
            else:
                rem_code, add_code = TaskletCode.fp_rem, TaskletCode.fp_add
            self._binop_emit_floored_mod(
                block,
                t_left,
                left_sub,
                t_right,
                right_sub,
                t_out,
                dtype,
                rem_code,
                add_code,
            )
            return tmp_name

        # Operator -> TaskletCode member name. Members are resolved by name
        # only for the operator actually used.
        if use_int:
            # "/" is absent on purpose: it never yields an Int64 result.
            # NOTE(review): "//" uses int_sdiv; if that truncates toward zero
            # it differs from Python's floor division for negative operands.
            # NOTE(review): ">>" uses int_lshr; if that is a logical shift it
            # differs from Python's ">>" for negative values -- confirm.
            code_names = {
                "+": "int_add",
                "-": "int_sub",
                "*": "int_mul",
                "//": "int_sdiv",
                "&": "int_and",
                "|": "int_or",
                "^": "int_xor",
                "<<": "int_shl",
                ">>": "int_lshr",
            }
            type_label = "integers"
        else:
            # "//" is mapped to the same tasklet as "/"; no flooring step is
            # emitted here.
            code_names = {
                "+": "fp_add",
                "-": "fp_sub",
                "*": "fp_mul",
                "/": "fp_div",
                "//": "fp_div",
            }
            type_label = "floats"

        if op not in code_names:
            # Fail loudly instead of handing ``None`` to the builder.
            raise NotImplementedError(
                f"Operation {op} not supported for {type_label}"
            )
        tasklet_code = getattr(TaskletCode, code_names[op])

        block = self.builder.add_block()
        t_left, left_sub = self._add_read(block, real_left)
        t_right, right_sub = self._add_read(block, real_right)
        t_out = self.builder.add_access(block, tmp_name)

        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")

        return tmp_name

    def _binop_cast_to_double(self, operand):
        """Copy ``operand`` into a fresh Double scalar and return its name."""
        cast_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(cast_name, Scalar(PrimitiveType.Double), False)
        self.symbol_table[cast_name] = Scalar(PrimitiveType.Double)

        c_block = self.builder.add_block()
        t_src, src_sub = self._add_read(c_block, operand)
        t_dst = self.builder.add_access(c_block, cast_name)
        t_task = self.builder.add_tasklet(
            c_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
        self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
        return cast_name

    def _binop_emit_floored_mod(
        self,
        block,
        t_left,
        left_sub,
        t_right,
        right_sub,
        t_out,
        dtype,
        rem_code,
        add_code,
    ):
        """Emit ``((a rem b) + b) rem b`` into ``block``, writing to ``t_out``.

        The double remainder is used to match Python's modulo behavior for
        negative operands. ``rem_code`` and ``add_code`` select the integer or
        floating-point tasklets.
        """
        # 1. rem1 = a rem b
        t_rem1 = self.builder.add_tasklet(
            block, rem_code, ["_in1", "_in2"], ["_out"]
        )
        self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_rem1, "_in2", right_sub)

        rem1_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(rem1_name, dtype, False)
        t_rem1_out = self.builder.add_access(block, rem1_name)
        self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")

        # 2. add = rem1 + b
        t_add = self.builder.add_tasklet(
            block, add_code, ["_in1", "_in2"], ["_out"]
        )
        self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
        self.builder.add_memlet(block, t_right, "void", t_add, "_in2", right_sub)

        add_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(add_name, dtype, False)
        t_add_out = self.builder.add_access(block, add_name)
        self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")

        # 3. result = add rem b
        t_rem2 = self.builder.add_tasklet(
            block, rem_code, ["_in1", "_in2"], ["_out"]
        )
        self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
        self.builder.add_memlet(block, t_right, "void", t_rem2, "_in2", right_sub)
        self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")
4✔
1013

1014
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit a block that assigns the constant ``value_str`` to ``target_name``.

        A new block is added containing a constant node of type ``dtype``, an
        access node for ``target_name``, and an ``assign`` tasklet connecting
        the two.
        """
        builder = self.builder
        blk = builder.add_block()
        const_node = builder.add_constant(blk, value_str, dtype)
        dst_node = builder.add_access(blk, target_name)
        assign_task = builder.add_tasklet(
            blk, TaskletCode.assign, ["_in"], ["_out"]
        )
        builder.add_memlet(blk, const_node, "void", assign_task, "_in", "")
        builder.add_memlet(blk, assign_task, "_out", dst_node, "void", "")
4✔
1021

1022
    def visit_BoolOp(self, node):
        """Lower ``and`` / ``or`` into a Bool temporary and return its name.

        Every operand is visited up front and wrapped as ``(<operand> != 0)``.
        The wrapped terms are joined with the visited operator to form the
        condition of an if/else that stores ``true`` or ``false``.
        """
        joiner = self.visit(node.op)
        terms = []
        for operand in node.values:
            terms.append(f"({self.visit(operand)} != 0)")
        condition = f" {joiner} ".join(terms)

        result = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        # The truth value is materialized through control flow.
        self.builder.begin_if(condition)
        self._add_assign_constant(result, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result, "false", bool_type)
        self.builder.end_if()

        self.symbol_table[result] = bool_type
        return result
4✔
1040

1041
    def visit_Compare(self, node):
        """Lower a single comparison and return the name holding its result.

        If either side is a known array the work is delegated to
        ``_handle_array_compare``. Otherwise the comparison becomes the
        condition of an if/else that sets a Bool temporary.

        Raises:
            NotImplementedError: for chained comparisons (``a < b < c``).
        """
        # The left operand is visited before the chain check.
        lhs = self.visit(node.left)
        if len(node.ops) > 1:
            raise NotImplementedError("Chained comparisons not supported yet")

        cmp_op = self.visit(node.ops[0])
        rhs = self.visit(node.comparators[0])

        lhs_is_array = lhs in self.array_info
        rhs_is_array = rhs in self.array_info
        if lhs_is_array or rhs_is_array:
            # Elementwise comparison producing a boolean array.
            return self._handle_array_compare(
                lhs, cmp_op, rhs, lhs_is_array, rhs_is_array
            )

        # Scalar comparison.
        condition = f"{lhs} {cmp_op} {rhs}"
        result = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)

        builder = self.builder
        builder.add_container(result, bool_type, False)

        # The boolean value is assigned through control flow.
        builder.begin_if(condition)
        builder.add_transition(result, "true")
        builder.begin_else()
        builder.add_transition(result, "false")
        builder.end_if()

        self.symbol_table[result] = bool_type
        return result
4✔
1075

1076
    def visit_UnaryOp(self, node):
        """Lower a unary operation and return the name (or literal) of its result.

        A negated numeric literal is returned directly as text. ``-arr`` on a
        known array is delegated to ``_handle_array_negate``. All other cases
        write into a fresh ``_tmp_<id>`` scalar.
        """
        operand_node = node.operand
        # Negated numeric literals are folded into their textual form.
        # (bool is a subclass of int, so bool literals also match here.)
        if (
            isinstance(node.op, ast.USub)
            and isinstance(operand_node, ast.Constant)
            and isinstance(operand_node.value, (int, float))
        ):
            return f"-{operand_node.value}"

        op = self.visit(node.op)
        operand = self.visit(operand_node)
        is_not = isinstance(node.op, ast.Not)

        # Negating an array is an elementwise operation.
        if operand in self.array_info and op == "-":
            return self._handle_array_negate(operand)

        tmp_name = f"_tmp_{self._get_unique_id()}"

        # Result type: the operand's recorded type (element type for typed
        # pointers), else Int64 for integers, else Bool for ``not``, else Double.
        if operand in self.symbol_table:
            dtype = self.symbol_table[operand]
            if isinstance(dtype, Pointer) and dtype.has_pointee_type():
                dtype = dtype.pointee_type
        elif self._is_int(operand):
            dtype = Scalar(PrimitiveType.Int64)
        elif is_not:
            dtype = Scalar(PrimitiveType.Bool)
        else:
            dtype = Scalar(PrimitiveType.Double)

        builder = self.builder
        builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        block = builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = builder.add_access(block, tmp_name)

        if is_not or op == "~":
            # Both are an XOR with a constant: ``true`` for logical not,
            # ``-1`` (all bits set) for bitwise not.
            if is_not:
                t_mask = builder.add_constant(
                    block, "true", Scalar(PrimitiveType.Bool)
                )
            else:
                t_mask = builder.add_constant(
                    block, "-1", Scalar(PrimitiveType.Int64)
                )
            t_task = builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            builder.add_memlet(block, t_mask, "void", t_task, "_in2", "")
            builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        elif op == "-":
            if dtype.primitive_type == PrimitiveType.Int64:
                # Integer negation is emitted as ``0 - x``.
                t_zero = builder.add_constant(block, "0", dtype)
                t_task = builder.add_tasklet(
                    block, TaskletCode.int_sub, ["_in1", "_in2"], ["_out"]
                )
                builder.add_memlet(block, t_zero, "void", t_task, "_in1", "")
                builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
                builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
            else:
                t_task = builder.add_tasklet(
                    block, TaskletCode.fp_neg, ["_in"], ["_out"]
                )
                builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
                builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        else:
            # Any other operator copies the operand unchanged.
            t_task = builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
4✔
1157

1158
    def _handle_array_negate(self, operand):
        """Handle negation of an array operand (``-arr``), emitted as ``0 - arr``.

        Returns the name of a new temporary array with the operand's shape and
        element type.
        """
        shape = self.array_info[operand]["shapes"]
        elem_type = self._get_dtype(operand)

        # Output array.
        result = self._create_array_temp(shape, elem_type)

        # Scalar zero of the element type, used as the left-hand side.
        zero = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(zero, elem_type, False)
        self.symbol_table[zero] = elem_type

        if elem_type.primitive_type == PrimitiveType.Double:
            zero_literal = "0.0"
        else:
            zero_literal = "0"
        self._add_assign_constant(zero, zero_literal, elem_type)

        # result = 0 - operand (elementwise, scalar on the left).
        self.builder.add_elementwise_op("sub", zero, operand, result, shape)

        return result
4✔
1189

1190
    def _handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Emit an elementwise comparison and return the boolean result array's name.

        Supports array-vs-scalar and array-vs-array forms such as ``arr > 0``
        or ``arr1 > arr2``. The shape is taken from the left operand when it is
        an array, otherwise from the right one.

        Raises:
            NotImplementedError: if ``op`` has no comparison tasklet.
        """
        array_name = left if left_is_array else right
        shape = self.array_info[array_name]["shapes"]

        # Int32/Int64 elements use integer comparisons; every other element
        # type uses the floating-point ones.
        elem_type = self._get_dtype(array_name)
        use_int_cmp = elem_type.primitive_type in (
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )

        # Boolean output array.
        tmp_name = self._create_array_temp(shape, Scalar(PrimitiveType.Bool))

        if use_int_cmp:
            cmp_ops = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            # Ordered floating-point comparisons.
            cmp_ops = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in cmp_ops:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )
        tasklet_code = cmp_ops[op]

        # Find the scalar side, if any.
        if not left_is_array:
            scalar_name = left
        elif not right_is_array:
            scalar_name = right
        else:
            scalar_name = None

        # For a floating-point comparison, an integer scalar is replaced by a
        # Double temporary holding "<scalar>.0".
        # NOTE(review): if _is_int also accepts names of integer variables,
        # "<name>.0" is not a valid constant -- confirm it only matches literals.
        if (
            scalar_name is not None
            and not use_int_cmp
            and self._is_int(scalar_name)
        ):
            float_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(
                float_name, Scalar(PrimitiveType.Double), False
            )
            self.symbol_table[float_name] = Scalar(PrimitiveType.Double)
            self._add_assign_constant(
                float_name, f"{scalar_name}.0", Scalar(PrimitiveType.Double)
            )

            if not left_is_array:
                left = float_name
            else:
                right = float_name

        # One nested loop per dimension.
        loop_vars = []
        for axis, extent in enumerate(shape):
            loop_var = f"_cmp_i{axis}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(extent), "1")

        # Linear index shared by the array operands and the output.
        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))

        block = self.builder.add_block()

        def read_operand(name, is_array):
            """Return (access node, subset) for one comparison operand."""
            if is_array:
                return self.builder.add_access(block, name), linear_idx
            return self._add_read(block, name)

        t_left, left_sub = read_operand(left, left_is_array)
        t_right, right_sub = read_operand(right, right_is_array)
        t_out = self.builder.add_access(block, tmp_name)

        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", linear_idx)

        # Close every loop that was opened.
        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
4✔
1329

1330
    def _parse_array_arg(self, node, simple_visitor):
4✔
1331
        if isinstance(node, ast.Name):
×
1332
            if node.id in self.array_info:
×
1333
                return node.id, [], self.array_info[node.id]["shapes"]
×
1334
        elif isinstance(node, ast.Subscript):
×
1335
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
×
1336
                name = node.value.id
×
1337
                ndim = self.array_info[name]["ndim"]
×
1338

1339
                indices = []
×
1340
                if isinstance(node.slice, ast.Tuple):
×
1341
                    indices = list(node.slice.elts)
×
1342
                else:
1343
                    indices = [node.slice]
×
1344

1345
                while len(indices) < ndim:
×
1346
                    indices.append(ast.Slice(lower=None, upper=None, step=None))
×
1347

1348
                start_indices = []
×
1349
                slice_shape = []
×
1350

1351
                for i, idx in enumerate(indices):
×
1352
                    if isinstance(idx, ast.Slice):
×
1353
                        start = "0"
×
1354
                        if idx.lower:
×
1355
                            start = simple_visitor.visit(idx.lower)
×
1356
                        start_indices.append(start)
×
1357

1358
                        shapes = self.array_info[name]["shapes"]
×
1359
                        dim_size = (
×
1360
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
1361
                        )
1362
                        stop = dim_size
×
1363
                        if idx.upper:
×
1364
                            stop = simple_visitor.visit(idx.upper)
×
1365

1366
                        size = f"({stop} - {start})"
×
1367
                        slice_shape.append(size)
×
1368
                    else:
1369
                        val = simple_visitor.visit(idx)
×
1370
                        start_indices.append(val)
×
1371

1372
                shapes = self.array_info[name]["shapes"]
×
1373
                linear_index = ""
×
1374
                for i in range(ndim):
×
1375
                    term = start_indices[i]
×
1376
                    for j in range(i + 1, ndim):
×
1377
                        shape_val = shapes[j] if j < len(shapes) else None
×
1378
                        shape_sym = (
×
1379
                            shape_val if shape_val is not None else f"_{name}_shape_{j}"
1380
                        )
1381
                        term = f"({term} * {shape_sym})"
×
1382

1383
                    if i == 0:
×
1384
                        linear_index = term
×
1385
                    else:
1386
                        linear_index = f"({linear_index} + {term})"
×
1387

1388
                return name, [linear_index], slice_shape
×
1389

1390
        return None, None, None
×
1391

1392
    def visit_Attribute(self, node):
4✔
1393
        if node.attr == "shape":
4✔
1394
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
4✔
1395
                return f"_shape_proxy_{node.value.id}"
4✔
1396

1397
        if isinstance(node.value, ast.Name) and node.value.id == "math":
4✔
1398
            val = ""
4✔
1399
            if node.attr == "pi":
4✔
1400
                val = "M_PI"
4✔
1401
            elif node.attr == "e":
4✔
1402
                val = "M_E"
4✔
1403

1404
            if val:
4✔
1405
                tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1406
                dtype = Scalar(PrimitiveType.Double)
4✔
1407
                self.builder.add_container(tmp_name, dtype, False)
4✔
1408
                self.symbol_table[tmp_name] = dtype
4✔
1409
                self._add_assign_constant(tmp_name, val, dtype)
4✔
1410
                return tmp_name
4✔
1411

1412
        # Handle class member access (e.g., obj.x, obj.y)
1413
        if isinstance(node.value, ast.Name):
4✔
1414
            obj_name = node.value.id
4✔
1415
            attr_name = node.attr
4✔
1416

1417
            # Check if the object is a class instance (has a Structure type)
1418
            if obj_name in self.symbol_table:
4✔
1419
                obj_type = self.symbol_table[obj_name]
4✔
1420
                if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
4✔
1421
                    pointee_type = obj_type.pointee_type
4✔
1422
                    if isinstance(pointee_type, Structure):
4✔
1423
                        struct_name = pointee_type.name
4✔
1424

1425
                        # Look up member index and type from structure info
1426
                        if (
4✔
1427
                            struct_name in self.structure_member_info
1428
                            and attr_name in self.structure_member_info[struct_name]
1429
                        ):
1430
                            member_index, member_type = self.structure_member_info[
4✔
1431
                                struct_name
1432
                            ][attr_name]
1433
                        else:
1434
                            # This should not happen if structure was registered properly
1435
                            raise RuntimeError(
×
1436
                                f"Member '{attr_name}' not found in structure '{struct_name}'. "
1437
                                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
1438
                            )
1439

1440
                        # Generate a tasklet to access the member
1441
                        tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1442

1443
                        self.builder.add_container(tmp_name, member_type, False)
4✔
1444
                        self.symbol_table[tmp_name] = member_type
4✔
1445

1446
                        # Create a tasklet that reads the member
1447
                        block = self.builder.add_block()
4✔
1448
                        obj_access = self.builder.add_access(block, obj_name)
4✔
1449
                        tmp_access = self.builder.add_access(block, tmp_name)
4✔
1450

1451
                        # Use tasklet to pass through the value
1452
                        # The actual member selection is done via the memlet subset
1453
                        tasklet = self.builder.add_tasklet(
4✔
1454
                            block, TaskletCode.assign, ["_in"], ["_out"]
1455
                        )
1456

1457
                        # Use member index in the subset to select the correct member
1458
                        subset = "0," + str(member_index)
4✔
1459
                        self.builder.add_memlet(
4✔
1460
                            block, obj_access, "", tasklet, "_in", subset
1461
                        )
1462
                        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")
4✔
1463

1464
                        return tmp_name
4✔
1465

1466
        raise NotImplementedError(f"Attribute access {node.attr} not supported")
×
1467

1468
    def _handle_expression_slicing(self, node, value_str, indices_nodes, shapes, ndim):
        """Copy a sliced read (e.g. ``arr[1:, :, k+1]``) into a fresh temporary.

        Every ``ast.Slice`` entry of ``indices_nodes`` becomes one dimension of
        the result and one generated loop; every other entry is a point index
        that collapses its dimension. Bounds or indices whose string form starts
        with ``-`` or ``(-`` are rebased onto the dimension extent.

        Args:
            node: The subscript node being lowered (not read by this method).
            value_str: Name of the source array container.
            indices_nodes: One AST node per indexed dimension.
            shapes: Extents of the source array; a dimension beyond its length
                falls back to the symbol ``_<value_str>_shape_<i>``.
            ndim: Number of dimensions of the source array.

        Returns:
            Name of the temporary container that receives the copied data.

        Raises:
            ValueError: if no builder is attached to this visitor.
        """
        if not self.builder:
            raise ValueError("Builder required for expression slicing")

        def looks_negative(expr):
            # Textual heuristic: only a leading "-" or "(-" counts as negative.
            return isinstance(expr, str) and expr.startswith(("-", "(-"))

        def resolve_bound(bound_node, extent):
            # Returns (sdfg_expr, runtime_expr) for one explicit slice bound.
            if self._contains_indirect_access(bound_node):
                sdfg_expr, runtime_expr = self._materialize_indirect_access(
                    bound_node, return_original_expr=True
                )
            else:
                sdfg_expr = self.visit(bound_node)
                runtime_expr = sdfg_expr
            if looks_negative(sdfg_expr):
                sdfg_expr = f"({extent} + {sdfg_expr})"
                runtime_expr = f"({extent} + {runtime_expr})"
            return sdfg_expr, runtime_expr

        # Element type: Double unless the source is a typed pointer.
        dtype = Scalar(PrimitiveType.Double)
        if value_str in self.symbol_table:
            source_type = self.symbol_table[value_str]
            if isinstance(source_type, Pointer) and source_type.has_pointee_type():
                dtype = source_type.pointee_type

        result_shapes = []  # extents built from SDFG-side expressions
        result_shapes_runtime = []  # extents built from the original expressions
        slice_info = []  # (source_dim, start, stop, step) per sliced dimension
        index_info = []  # (source_dim, index) per point-indexed dimension

        for axis, idx in enumerate(indices_nodes):
            extent = (
                shapes[axis] if axis < len(shapes) else f"_{value_str}_shape_{axis}"
            )

            if not isinstance(idx, ast.Slice):
                # Point index: this dimension does not appear in the result.
                if self._contains_indirect_access(idx):
                    point = self._materialize_indirect_access(idx)
                else:
                    point = self.visit(idx)
                if looks_negative(point):
                    point = f"({extent} + {point})"
                index_info.append((axis, point))
                continue

            if idx.lower is None:
                start, start_runtime = "0", "0"
            else:
                start, start_runtime = resolve_bound(idx.lower, extent)

            if idx.upper is None:
                stop, stop_runtime = str(extent), str(extent)
            else:
                stop, stop_runtime = resolve_bound(idx.upper, extent)

            step = "1" if idx.step is None else self.visit(idx.step)

            # NOTE(review): the extent (and the loop trip count further down) is
            # (stop - start) regardless of `step`, while the source index
            # advances by `step` - confirm how non-unit steps are meant to work.
            result_shapes.append(f"({stop} - {start})")
            result_shapes_runtime.append(f"({stop_runtime} - {start_runtime})")
            slice_info.append((axis, start, stop, step))

        tmp_name = self._get_temp_name("_slice_tmp_")
        result_ndim = len(result_shapes)

        if result_ndim == 0:
            # Every dimension was point-indexed: the result is a scalar.
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
        else:
            # Byte size of the temporary: product of extents times element size.
            size_str = "1"
            for dim in result_shapes:
                size_str = f"({size_str} * {dim})"
            element_size = self.builder.get_sizeof(dtype)
            total_size = f"({size_str} * {element_size})"

            ptr_type = Pointer(dtype)
            self.builder.add_container(tmp_name, ptr_type, False)
            self.symbol_table[tmp_name] = ptr_type
            # "shapes" holds the SDFG-side expressions used for sizing;
            # "shapes_runtime" holds the expressions built from the original
            # (non-materialized) bounds.
            self.array_info[tmp_name] = {
                "ndim": result_ndim,
                "shapes": result_shapes,
                "shapes_runtime": result_shapes_runtime,
            }

            alloc_debug = DebugInfo()
            block_alloc = self.builder.add_block(alloc_debug)
            t_malloc = self.builder.add_malloc(block_alloc, total_size)
            t_ptr = self.builder.add_access(block_alloc, tmp_name, alloc_debug)
            self.builder.add_memlet(
                block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, alloc_debug
            )

        # One loop per sliced dimension, each running from 0 to (stop - start).
        loop_vars = []
        debug_info = DebugInfo()
        for position, (axis, start, stop, step) in enumerate(slice_info):
            loop_var = f"_slice_loop_{position}_{self._get_unique_id()}"
            loop_vars.append((loop_var, axis, start, step))

            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(
                loop_var, "0", f"({stop} - {start})", "1", debug_info
            )

        # Source index per original dimension; destination index per loop.
        src_indices = [""] * ndim
        for axis, point in index_info:
            src_indices[axis] = point

        dst_indices = []
        for loop_var, axis, start, step in loop_vars:
            stride = loop_var if step == "1" else f"{loop_var} * {step}"
            src_indices[axis] = f"({start} + {stride})"
            dst_indices.append(loop_var)

        src_linear = self._compute_linear_index(src_indices, shapes, value_str, ndim)
        if result_ndim > 0:
            dst_linear = self._compute_linear_index(
                dst_indices, result_shapes, tmp_name, result_ndim
            )
        else:
            dst_linear = "0"

        # Innermost body: a single element copy through an assign tasklet.
        block = self.builder.add_block(debug_info)
        t_src = self.builder.add_access(block, value_str, debug_info)
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
        t_task = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
        )
        self.builder.add_memlet(
            block, t_src, "void", t_task, "_in", src_linear, None, debug_info
        )
        self.builder.add_memlet(
            block, t_task, "_out", t_dst, "void", dst_linear, None, debug_info
        )

        # Close the loops opened above.
        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
4✔
1660

1661
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
1662
        """Compute linear index from multi-dimensional indices."""
1663
        if ndim == 0:
4✔
1664
            return "0"
×
1665

1666
        linear_index = ""
4✔
1667
        for i in range(ndim):
4✔
1668
            term = str(indices[i])
4✔
1669
            for j in range(i + 1, ndim):
4✔
1670
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
4✔
1671
                term = f"(({term}) * {shape_val})"
4✔
1672

1673
            if i == 0:
4✔
1674
                linear_index = term
4✔
1675
            else:
1676
                linear_index = f"({linear_index} + {term})"
4✔
1677

1678
        return linear_index
4✔
1679

1680
    def _is_array_index(self, node):
4✔
1681
        """Check if a node represents an array that could be used as an index (gather).
1682

1683
        Returns True if the node is a Name referring to an array in array_info.
1684
        """
1685
        if isinstance(node, ast.Name):
4✔
1686
            return node.id in self.array_info
4✔
1687
        return False
4✔
1688

1689
    def _handle_gather(self, value_str, index_node, debug_info=None):
4✔
1690
        """Handle gather operation: x[indices] where indices is an array.
1691

1692
        Creates a temporary array and generates a loop to gather elements
1693
        from the source array using the index array.
1694

1695
        This is the canonical SDFG pattern for gather operations:
1696
        - Create a loop over the index array
1697
        - Load the index value using a tasklet+memlets
1698
        - Use that index in the memlet subset for the source array
1699
        """
1700
        if debug_info is None:
4✔
1701
            debug_info = DebugInfo()
4✔
1702

1703
        # Get the index array name
1704
        if isinstance(index_node, ast.Name):
4✔
1705
            idx_array_name = index_node.id
4✔
1706
        else:
1707
            # Visit the index to get its name (handles slices like cols)
1708
            idx_array_name = self.visit(index_node)
×
1709

1710
        if idx_array_name not in self.array_info:
4✔
1711
            raise ValueError(f"Gather index must be an array, got {idx_array_name}")
×
1712

1713
        # Get shapes
1714
        idx_shapes = self.array_info[idx_array_name].get("shapes", [])
4✔
1715
        src_ndim = self.array_info[value_str]["ndim"]
4✔
1716
        idx_ndim = self.array_info[idx_array_name]["ndim"]
4✔
1717

1718
        if idx_ndim != 1:
4✔
1719
            raise NotImplementedError("Only 1D index arrays supported for gather")
×
1720

1721
        # Result array has same shape as index array
1722
        result_shape = idx_shapes[0] if idx_shapes else f"_{idx_array_name}_shape_0"
4✔
1723

1724
        # Determine element type from source array
1725
        dtype = Scalar(PrimitiveType.Double)
4✔
1726
        if value_str in self.symbol_table:
4✔
1727
            t = self.symbol_table[value_str]
4✔
1728
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1729
                dtype = t.pointee_type
4✔
1730

1731
        # Determine index type from index array
1732
        idx_dtype = Scalar(PrimitiveType.Int64)
4✔
1733
        if idx_array_name in self.symbol_table:
4✔
1734
            t = self.symbol_table[idx_array_name]
4✔
1735
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1736
                idx_dtype = t.pointee_type
4✔
1737

1738
        # Create result array
1739
        tmp_name = self._get_temp_name("_gather_")
4✔
1740

1741
        # Calculate size for malloc
1742
        element_size = self.builder.get_sizeof(dtype)
4✔
1743
        total_size = f"({result_shape} * {element_size})"
4✔
1744

1745
        # Create pointer for result
1746
        ptr_type = Pointer(dtype)
4✔
1747
        self.builder.add_container(tmp_name, ptr_type, False)
4✔
1748
        self.symbol_table[tmp_name] = ptr_type
4✔
1749
        self.array_info[tmp_name] = {"ndim": 1, "shapes": [result_shape]}
4✔
1750

1751
        # Malloc for the result array
1752
        block_alloc = self.builder.add_block(debug_info)
4✔
1753
        t_malloc = self.builder.add_malloc(block_alloc, total_size)
4✔
1754
        t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
4✔
1755
        self.builder.add_memlet(
4✔
1756
            block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
1757
        )
1758

1759
        # Create loop variable
1760
        loop_var = f"_gather_i_{self._get_unique_id()}"
4✔
1761
        self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1762
        self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1763

1764
        # Create variable to hold the loaded index
1765
        idx_var = f"_gather_idx_{self._get_unique_id()}"
4✔
1766
        self.builder.add_container(idx_var, idx_dtype, False)
4✔
1767
        self.symbol_table[idx_var] = idx_dtype
4✔
1768

1769
        # Begin loop
1770
        self.builder.begin_for(loop_var, "0", str(result_shape), "1", debug_info)
4✔
1771

1772
        # Block 1: Load the index from index array using tasklet+memlets
1773
        block_load_idx = self.builder.add_block(debug_info)
4✔
1774
        idx_arr_access = self.builder.add_access(
4✔
1775
            block_load_idx, idx_array_name, debug_info
1776
        )
1777
        idx_var_access = self.builder.add_access(block_load_idx, idx_var, debug_info)
4✔
1778
        tasklet_load = self.builder.add_tasklet(
4✔
1779
            block_load_idx, TaskletCode.assign, ["_in"], ["_out"], debug_info
1780
        )
1781
        self.builder.add_memlet(
4✔
1782
            block_load_idx,
1783
            idx_arr_access,
1784
            "void",
1785
            tasklet_load,
1786
            "_in",
1787
            loop_var,
1788
            None,
1789
            debug_info,
1790
        )
1791
        self.builder.add_memlet(
4✔
1792
            block_load_idx,
1793
            tasklet_load,
1794
            "_out",
1795
            idx_var_access,
1796
            "void",
1797
            "",
1798
            None,
1799
            debug_info,
1800
        )
1801

1802
        # Block 2: Use the loaded index to gather from source array
1803
        block_gather = self.builder.add_block(debug_info)
4✔
1804
        src_access = self.builder.add_access(block_gather, value_str, debug_info)
4✔
1805
        dst_access = self.builder.add_access(block_gather, tmp_name, debug_info)
4✔
1806
        tasklet_gather = self.builder.add_tasklet(
4✔
1807
            block_gather, TaskletCode.assign, ["_in"], ["_out"], debug_info
1808
        )
1809

1810
        # Use the symbolic variable name (idx_var) in the memlet subset - this is key!
1811
        self.builder.add_memlet(
4✔
1812
            block_gather,
1813
            src_access,
1814
            "void",
1815
            tasklet_gather,
1816
            "_in",
1817
            idx_var,
1818
            None,
1819
            debug_info,
1820
        )
1821
        self.builder.add_memlet(
4✔
1822
            block_gather,
1823
            tasklet_gather,
1824
            "_out",
1825
            dst_access,
1826
            "void",
1827
            loop_var,
1828
            None,
1829
            debug_info,
1830
        )
1831

1832
        # End loop
1833
        self.builder.end_for()
4✔
1834

1835
        return tmp_name
4✔
1836

1837
    def visit_Subscript(self, node):
4✔
1838
        value_str = self.visit(node.value)
4✔
1839

1840
        if value_str.startswith("_shape_proxy_"):
4✔
1841
            array_name = value_str[len("_shape_proxy_") :]
4✔
1842
            if isinstance(node.slice, ast.Constant):
4✔
1843
                idx = node.slice.value
4✔
1844
            elif isinstance(node.slice, ast.Index):
×
1845
                idx = node.slice.value.value
×
1846
            else:
1847
                try:
×
1848
                    idx = int(self.visit(node.slice))
×
1849
                except:
×
1850
                    raise NotImplementedError(
×
1851
                        "Dynamic shape indexing not fully supported yet"
1852
                    )
1853

1854
            if (
4✔
1855
                array_name in self.array_info
1856
                and "shapes" in self.array_info[array_name]
1857
            ):
1858
                return self.array_info[array_name]["shapes"][idx]
4✔
1859

1860
            return f"_{array_name}_shape_{idx}"
×
1861

1862
        if value_str in self.array_info:
4✔
1863
            ndim = self.array_info[value_str]["ndim"]
4✔
1864
            shapes = self.array_info[value_str].get("shapes", [])
4✔
1865

1866
            indices = []
4✔
1867
            if isinstance(node.slice, ast.Tuple):
4✔
1868
                indices_nodes = node.slice.elts
4✔
1869
            else:
1870
                indices_nodes = [node.slice]
4✔
1871

1872
            # Check if all indices are full slices (e.g., path[:] or path[:, :])
1873
            # In this case, return just the array name since it's the full array
1874
            all_full_slices = True
4✔
1875
            for idx in indices_nodes:
4✔
1876
                if isinstance(idx, ast.Slice):
4✔
1877
                    # A full slice has no lower, upper bounds or only None
1878
                    if idx.lower is not None or idx.upper is not None:
4✔
1879
                        all_full_slices = False
4✔
1880
                        break
4✔
1881
                else:
1882
                    all_full_slices = False
4✔
1883
                    break
4✔
1884

1885
            # path[:] on an nD array returns the full array
1886
            # So if we have a single full slice, it covers all dimensions
1887
            if all_full_slices:
4✔
1888
                # This is path[:] or path[:,:] - return the array name
1889
                return value_str
4✔
1890

1891
            # Check if there are any slices in the indices
1892
            has_slices = any(isinstance(idx, ast.Slice) for idx in indices_nodes)
4✔
1893
            if has_slices:
4✔
1894
                # Handle mixed slicing (e.g., arr[1:, :, k] or arr[:-1, :, k+1])
1895
                return self._handle_expression_slicing(
4✔
1896
                    node, value_str, indices_nodes, shapes, ndim
1897
                )
1898

1899
            # Check for gather operation: x[indices_array] where indices_array is an array
1900
            # This happens when we have a 1D source array and a 1D index array
1901
            if len(indices_nodes) == 1 and self._is_array_index(indices_nodes[0]):
4✔
1902
                if self.builder:
4✔
1903
                    return self._handle_gather(value_str, indices_nodes[0])
4✔
1904

1905
            if isinstance(node.slice, ast.Tuple):
4✔
1906
                indices = [self.visit(elt) for elt in node.slice.elts]
4✔
1907
            else:
1908
                indices = [self.visit(node.slice)]
4✔
1909

1910
            if len(indices) != ndim:
4✔
1911
                raise ValueError(
×
1912
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
1913
                )
1914

1915
            # Normalize negative indices
1916
            normalized_indices = []
4✔
1917
            for i, idx_str in enumerate(indices):
4✔
1918
                shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
4✔
1919
                # Check if index is negative (starts with "-" or "(-")
1920
                if isinstance(idx_str, str) and (
4✔
1921
                    idx_str.startswith("-") or idx_str.startswith("(-")
1922
                ):
1923
                    # Normalize: size + negative_index
1924
                    normalized_indices.append(f"({shape_val} + {idx_str})")
×
1925
                else:
1926
                    normalized_indices.append(idx_str)
4✔
1927

1928
            linear_index = ""
4✔
1929
            for i in range(ndim):
4✔
1930
                term = normalized_indices[i]
4✔
1931
                for j in range(i + 1, ndim):
4✔
1932
                    shape_val = shapes[j] if j < len(shapes) else None
4✔
1933
                    shape_sym = (
4✔
1934
                        shape_val
1935
                        if shape_val is not None
1936
                        else f"_{value_str}_shape_{j}"
1937
                    )
1938
                    term = f"(({term}) * {shape_sym})"
4✔
1939

1940
                if i == 0:
4✔
1941
                    linear_index = term
4✔
1942
                else:
1943
                    linear_index = f"({linear_index} + {term})"
4✔
1944

1945
            access_str = f"{value_str}({linear_index})"
4✔
1946

1947
            if self.builder and isinstance(node.ctx, ast.Load):
4✔
1948
                dtype = Scalar(PrimitiveType.Double)
4✔
1949
                if value_str in self.symbol_table:
4✔
1950
                    t = self.symbol_table[value_str]
4✔
1951
                    if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
1952
                        et = t.element_type
×
1953
                        if callable(et):
×
1954
                            et = et()
×
1955
                        dtype = et
×
1956
                    elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
4✔
1957
                        et = t.pointee_type
4✔
1958
                        if callable(et):
4✔
1959
                            et = et()
×
1960
                        dtype = et
4✔
1961

1962
                tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1963
                self.builder.add_container(tmp_name, dtype, False)
4✔
1964

1965
                block = self.builder.add_block()
4✔
1966
                t_src = self.builder.add_access(block, value_str)
4✔
1967
                t_dst = self.builder.add_access(block, tmp_name)
4✔
1968
                t_task = self.builder.add_tasklet(
4✔
1969
                    block, TaskletCode.assign, ["_in"], ["_out"]
1970
                )
1971

1972
                self.builder.add_memlet(
4✔
1973
                    block, t_src, "void", t_task, "_in", linear_index
1974
                )
1975
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
4✔
1976

1977
                self.symbol_table[tmp_name] = dtype
4✔
1978
                return tmp_name
4✔
1979

1980
            return access_str
4✔
1981

1982
        slice_val = self.visit(node.slice)
×
1983
        access_str = f"{value_str}({slice_val})"
×
1984

1985
        if (
×
1986
            self.builder
1987
            and isinstance(node.ctx, ast.Load)
1988
            and value_str in self.array_info
1989
        ):
1990
            tmp_name = f"_tmp_{self._get_unique_id()}"
×
1991
            self.builder.add_container(tmp_name, Scalar(PrimitiveType.Double), False)
×
1992
            self.builder.add_assignment(tmp_name, access_str)
×
1993
            self.symbol_table[tmp_name] = Scalar(PrimitiveType.Double)
×
1994
            return tmp_name
×
1995

1996
        return access_str
×
1997

1998
    def visit_Add(self, node):
        """Return the operator token ``+`` for an ``ast.Add`` node."""
        return "+"
4✔
2000

2001
    def visit_Sub(self, node):
        """Return the operator token ``-`` for an ``ast.Sub`` node."""
        return "-"
4✔
2003

2004
    def visit_Mult(self, node):
        """Return the operator token ``*`` for an ``ast.Mult`` node."""
        return "*"
4✔
2006

2007
    def visit_Div(self, node):
        """Return the operator token ``/`` for an ``ast.Div`` node."""
        return "/"
4✔
2009

2010
    def visit_FloorDiv(self, node):
        """Return the operator token ``//`` for an ``ast.FloorDiv`` node."""
        return "//"
4✔
2012

2013
    def visit_Mod(self, node):
        """Return the operator token ``%`` for an ``ast.Mod`` node."""
        return "%"
4✔
2015

2016
    def visit_Pow(self, node):
        """Return the operator token ``**`` for an ``ast.Pow`` node."""
        return "**"
4✔
2018

2019
    def visit_Eq(self, node):
        """Return the comparison token ``==`` for an ``ast.Eq`` node."""
        return "=="
×
2021

2022
    def visit_NotEq(self, node):
        """Return the comparison token ``!=`` for an ``ast.NotEq`` node."""
        return "!="
×
2024

2025
    def visit_Lt(self, node):
        """Return the comparison token ``<`` for an ``ast.Lt`` node."""
        return "<"
4✔
2027

2028
    def visit_LtE(self, node):
        """Return the comparison token ``<=`` for an ``ast.LtE`` node."""
        return "<="
×
2030

2031
    def visit_Gt(self, node):
        """Return the comparison token ``>`` for an ``ast.Gt`` node."""
        return ">"
4✔
2033

2034
    def visit_GtE(self, node):
        """Return the comparison token ``>=`` for an ``ast.GtE`` node."""
        return ">="
×
2036

2037
    def visit_And(self, node):
        """Return ``&`` for an ``ast.And`` node (same token as ``visit_BitAnd``)."""
        return "&"
4✔
2039

2040
    def visit_Or(self, node):
        """Return ``|`` for an ``ast.Or`` node (same token as ``visit_BitOr``)."""
        return "|"
4✔
2042

2043
    def visit_BitAnd(self, node):
        """Return the operator token ``&`` for an ``ast.BitAnd`` node."""
        return "&"
×
2045

2046
    def visit_BitOr(self, node):
        """Return the operator token ``|`` for an ``ast.BitOr`` node."""
        return "|"
4✔
2048

2049
    def visit_BitXor(self, node):
        """Return the operator token ``^`` for an ``ast.BitXor`` node."""
        return "^"
4✔
2051

2052
    def visit_LShift(self, node):
        """Return the operator token ``<<`` for an ``ast.LShift`` node."""
        return "<<"
×
2054

2055
    def visit_RShift(self, node):
        """Return the operator token ``>>`` for an ``ast.RShift`` node."""
        return ">>"
×
2057

2058
    def visit_Not(self, node):
        """Return the unary token ``!`` for an ``ast.Not`` node."""
        return "!"
4✔
2060

2061
    def visit_USub(self, node):
        """Return the unary token ``-`` for an ``ast.USub`` node."""
        return "-"
4✔
2063

2064
    def visit_UAdd(self, node):
        """Return the unary token ``+`` for an ``ast.UAdd`` node."""
        return "+"
×
2066

2067
    def visit_Invert(self, node):
        """Return the unary token ``~`` for an ``ast.Invert`` node."""
        return "~"
×
2069

2070
    def _get_dtype(self, name):
        """Resolve the scalar type associated with ``name``.

        Lookup order:

        1. ``name`` is in ``self.symbol_table`` with a ``Scalar`` type: that
           type is returned.
        2. Its type has a ``pointee_type`` and then an ``element_type``
           attribute (called first when callable): the first one that yields a
           ``Scalar`` is returned.
        3. ``self._is_int(name)`` is true: ``Scalar(PrimitiveType.Int64)``.
        4. Otherwise ``Scalar(PrimitiveType.Double)``.
        """
        if name in self.symbol_table:
            declared = self.symbol_table[name]
            if isinstance(declared, Scalar):
                return declared

            # Pointer-like first, then array-like element types.
            for attr in ("pointee_type", "element_type"):
                if not hasattr(declared, attr):
                    continue
                inner = getattr(declared, attr)
                if callable(inner):
                    inner = inner()
                if isinstance(inner, Scalar):
                    return inner

        if self._is_int(name):
            return Scalar(PrimitiveType.Int64)

        return Scalar(PrimitiveType.Double)
4✔
2094

2095
    def _promote_dtypes(self, dtype_left, dtype_right):
        """Pick the higher-ranked of two dtypes.

        Ranking by ``primitive_type``: Double > Float > Int64 > Int32; any
        other primitive type ranks lowest. On a tie the left operand's dtype
        is returned.
        """
        rank = {
            PrimitiveType.Double: 4,
            PrimitiveType.Float: 3,
            PrimitiveType.Int64: 2,
            PrimitiveType.Int32: 1,
        }
        left_rank = rank.get(dtype_left.primitive_type, 0)
        right_rank = rank.get(dtype_right.primitive_type, 0)
        return dtype_left if left_rank >= right_rank else dtype_right
4✔
2110

2111
    def _create_array_temp(
        self, shape, dtype, zero_init=False, ones_init=False, shapes_runtime=None
    ):
        """Create a fresh temporary container and return its name.

        Args:
            shape: Sequence of dimension expressions. An empty (or falsy)
                shape produces a plain scalar container instead of an array.
            dtype: Element type of the temporary.
            zero_init: Initialise the temporary with zeros. Takes precedence
                over ``ones_init``.
            ones_init: Initialise the temporary with ones.
            shapes_runtime: Optional runtime shape expressions; stored under
                ``"shapes_runtime"`` in ``array_info`` for array temporaries.

        Returns:
            The generated ``_tmp_<id>`` name. The name is registered in the
            builder, in ``symbol_table`` and in ``array_info``.
        """
        tmp_name = f"_tmp_{self._get_unique_id()}"

        # A 0-D array is represented by a scalar container.
        if not shape or len(shape) == 0:
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}

            if zero_init or ones_init:
                is_double = dtype.primitive_type == PrimitiveType.Double
                if zero_init:
                    literal = "0.0" if is_double else "0"
                else:
                    literal = "1.0" if is_double else "1"
                self.builder.add_assignment(tmp_name, literal)

            return tmp_name

        # Number of elements as a nested product expression, then bytes.
        element_count = "1"
        for extent in shape:
            element_count = f"({element_count} * {extent})"
        byte_count = f"({element_count} * {self.builder.get_sizeof(dtype)})"

        # Register the pointer container and its shape metadata.
        pointer_type = Pointer(dtype)
        self.builder.add_container(tmp_name, pointer_type, False)
        self.symbol_table[tmp_name] = pointer_type
        info = {"ndim": len(shape), "shapes": shape}
        if shapes_runtime is not None:
            info["shapes_runtime"] = shapes_runtime
        self.array_info[tmp_name] = info

        # Allocation block: malloc result flows into the new pointer.
        alloc_block = self.builder.add_block()
        malloc_node = self.builder.add_malloc(alloc_block, byte_count)
        alloc_target = self.builder.add_access(alloc_block, tmp_name)
        self.builder.add_memlet(
            alloc_block, malloc_node, "_ret", alloc_target, "void", "", pointer_type
        )

        if zero_init:
            clear_block = self.builder.add_block()
            memset_node = self.builder.add_memset(clear_block, "0", byte_count)
            clear_target = self.builder.add_access(clear_block, tmp_name)
            self.builder.add_memlet(
                clear_block, memset_node, "_ptr", clear_target, "void", "", pointer_type
            )
            return tmp_name

        if not ones_init:
            return tmp_name

        # Fill with ones via an explicit loop over the flattened elements.
        counter = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(counter):
            self.builder.add_container(counter, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[counter] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(counter, "0", element_count, "1")

        # Integer element types get an integer literal, everything else "1.0".
        one_literal = (
            "1"
            if dtype.primitive_type
            in (
                PrimitiveType.Int64,
                PrimitiveType.Int32,
                PrimitiveType.Int8,
                PrimitiveType.Int16,
                PrimitiveType.UInt64,
                PrimitiveType.UInt32,
                PrimitiveType.UInt8,
                PrimitiveType.UInt16,
            )
            else "1.0"
        )

        fill_block = self.builder.add_block()
        const_node = self.builder.add_constant(fill_block, one_literal, dtype)
        array_node = self.builder.add_access(fill_block, tmp_name)
        assign_node = self.builder.add_tasklet(
            fill_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            fill_block, const_node, "void", assign_node, "_in", "", dtype
        )
        self.builder.add_memlet(
            fill_block, assign_node, "_out", array_node, "void", counter
        )

        self.builder.end_for()

        return tmp_name
2206

2207
    def _handle_array_unary_op(self, op_type, operand):
        """Apply a unary operation to ``operand`` and return the result temp.

        The output has the operand's shape (from ``array_info``; empty when
        the operand is not registered there) and the dtype reported by
        ``_get_dtype``. Operands with an empty shape are lowered to a single
        CMath node; all others go through ``add_elementwise_unary_op``.

        Raises:
            KeyError: On the empty-shape path, if ``op_type`` is not one of
                ``sqrt``, ``abs``, ``absolute``, ``exp`` or ``tanh``.
        """
        shape = self.array_info[operand]["shapes"] if operand in self.array_info else []
        dtype = self._get_dtype(operand)

        # Both paths write into a fresh temporary of the same shape/dtype.
        result_name = self._create_array_temp(shape, dtype)

        if shape and len(shape) != 0:
            self.builder.add_elementwise_unary_op(op_type, operand, result_name, shape)
            return result_name

        # 0-D case: use an intrinsic (CMath node) rather than a library node.
        cmath_by_op = {
            "sqrt": CMathFunction.sqrt,
            "abs": CMathFunction.fabs,
            "absolute": CMathFunction.fabs,
            "exp": CMathFunction.exp,
            "tanh": CMathFunction.tanh,
        }

        block = self.builder.add_block()
        source_node = self.builder.add_access(block, operand)
        target_node = self.builder.add_access(block, result_name)
        cmath_node = self.builder.add_cmath(block, cmath_by_op[op_type])

        # CMath nodes take inputs on "_in1", "_in2", ... and write to "_out".
        self.builder.add_memlet(block, source_node, "void", cmath_node, "_in1", "", dtype)
        self.builder.add_memlet(block, cmath_node, "_out", target_node, "void", "", dtype)

        return result_name
2246

2247
    def _handle_array_binary_op(self, op_type, left, right):
        """Apply an elementwise binary operation and return the result temp.

        Steps, in order:

        1. Compute the broadcast output shape from both operands' shapes.
        2. Promote the two dtypes via ``_promote_dtypes``.
        3. Copy any scalar operand (one not in ``array_info``) whose dtype
           differs from the promoted dtype into a new scalar of that dtype.
        4. Broadcast any array operand whose shape differs from the output.
        5. Emit the elementwise operation into a fresh temporary.
        """
        left_shape = self.array_info[left]["shapes"] if left in self.array_info else []
        right_shape = (
            self.array_info[right]["shapes"] if right in self.array_info else []
        )

        # Output shape follows NumPy-style broadcasting of the two inputs.
        shape = self._compute_broadcast_shape(left_shape, right_shape)

        dtype_left = self._get_dtype(left)
        dtype_right = self._get_dtype(right)
        dtype = self._promote_dtypes(dtype_left, dtype_right)

        left_is_scalar = left not in self.array_info
        right_is_scalar = right not in self.array_info

        def cast_to_promoted(source):
            """Copy scalar ``source`` into a new scalar of the promoted dtype."""
            cast_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(cast_name, dtype, False)
            self.symbol_table[cast_name] = dtype

            cast_block = self.builder.add_block()
            src_node, src_subset = self._add_read(cast_block, source)
            dst_node = self.builder.add_access(cast_block, cast_name)
            assign_node = self.builder.add_tasklet(
                cast_block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                cast_block, src_node, "void", assign_node, "_in", src_subset
            )
            self.builder.add_memlet(
                cast_block, assign_node, "_out", dst_node, "void", ""
            )
            return cast_name

        # Scalar casts happen first (left, then right) ...
        real_left = left
        if left_is_scalar and dtype_left.primitive_type != dtype.primitive_type:
            real_left = cast_to_promoted(left)

        real_right = right
        if right_is_scalar and dtype_right.primitive_type != dtype.primitive_type:
            real_right = cast_to_promoted(right)

        # ... followed by array broadcasts (left, then right).
        if not left_is_scalar and self._needs_broadcast(left_shape, shape):
            real_left = self._broadcast_array(real_left, left_shape, shape, dtype)

        if not right_is_scalar and self._needs_broadcast(right_shape, shape):
            real_right = self._broadcast_array(real_right, right_shape, shape, dtype)

        result_name = self._create_array_temp(shape, dtype)
        self.builder.add_elementwise_op(
            op_type, real_left, real_right, result_name, shape
        )
        return result_name
2322

2323
    def _compute_broadcast_shape(self, shape_a, shape_b):
4✔
2324
        """Compute the broadcast output shape following NumPy broadcasting rules."""
2325
        if not shape_a:
4✔
2326
            return shape_b
4✔
2327
        if not shape_b:
4✔
2328
            return shape_a
4✔
2329

2330
        # Pad shorter shape with 1s on the left
2331
        max_ndim = max(len(shape_a), len(shape_b))
4✔
2332
        padded_a = ["1"] * (max_ndim - len(shape_a)) + [str(s) for s in shape_a]
4✔
2333
        padded_b = ["1"] * (max_ndim - len(shape_b)) + [str(s) for s in shape_b]
4✔
2334

2335
        result = []
4✔
2336
        for a, b in zip(padded_a, padded_b):
4✔
2337
            if a == "1":
4✔
NEW
2338
                result.append(b)
×
2339
            elif b == "1":
4✔
2340
                result.append(a)
4✔
2341
            elif a == b:
4✔
2342
                result.append(a)
4✔
2343
            else:
2344
                # For symbolic dimensions, use max (assume compatible)
2345
                result.append(a)
4✔
2346

2347
        return result
4✔
2348

2349
    def _needs_broadcast(self, input_shape, output_shape):
4✔
2350
        """Check if input shape needs broadcasting to match output shape."""
2351
        if len(input_shape) != len(output_shape):
4✔
2352
            return True
4✔
2353
        for in_dim, out_dim in zip(input_shape, output_shape):
4✔
2354
            if str(in_dim) != str(out_dim):
4✔
2355
                return True
4✔
2356
        return False
4✔
2357

2358
    def _broadcast_array(self, arr_name, input_shape, output_shape, dtype):
        """Materialise ``arr_name`` broadcast to ``output_shape``.

        A new temporary of shape ``output_shape`` and type ``dtype`` is
        created, and a broadcast operation from ``arr_name`` into it is added
        to the builder. The input shape is left-padded with ``"1"`` entries
        to the output rank; both shapes are passed as lists of strings.

        Returns:
            The name of the temporary holding the broadcast result.
        """
        target_name = self._create_array_temp(output_shape, dtype)

        # Align ranks by prepending unit dimensions to the input shape.
        source_dims = ["1"] * (len(output_shape) - len(input_shape))
        source_dims.extend(str(dim) for dim in input_shape)
        target_dims = list(map(str, output_shape))

        self.builder.add_broadcast(arr_name, target_name, source_dims, target_dims)

        return target_name
2378

2379
    def _shape_to_runtime_expr(self, shape_node):
4✔
2380
        """Convert a shape expression AST node to a runtime-evaluable string.
2381

2382
        This converts the AST to a string expression that can be evaluated
2383
        at runtime using only input arrays and shape symbols (_s0, _s1, etc.).
2384
        It does NOT visit the node (which would create SDFG variables).
2385
        """
2386
        if isinstance(shape_node, ast.Constant):
4✔
2387
            return str(shape_node.value)
4✔
2388
        elif isinstance(shape_node, ast.Name):
4✔
2389
            return shape_node.id
4✔
2390
        elif isinstance(shape_node, ast.BinOp):
4✔
2391
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2392
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2393
            op = self.visit(shape_node.op)
4✔
2394
            return f"({left} {op} {right})"
4✔
2395
        elif isinstance(shape_node, ast.UnaryOp):
4✔
2396
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
2397
            if isinstance(shape_node.op, ast.USub):
×
2398
                return f"(-{operand})"
×
2399
            elif isinstance(shape_node.op, ast.UAdd):
×
2400
                return operand
×
2401
            else:
2402
                # Fall back to visit for other unary ops
2403
                return self.visit(shape_node)
×
2404
        elif isinstance(shape_node, ast.Subscript):
4✔
2405
            # Handle arr.shape[0] -> arr.shape[0] for runtime eval
2406
            # or _shape_proxy_arr[0] -> _s<idx>
2407
            val = shape_node.value
4✔
2408
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2409
                # arr.shape[0] -> use the shape symbol
2410
                if isinstance(val.value, ast.Name):
4✔
2411
                    arr_name = val.value.id
4✔
2412
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2413
                        idx = shape_node.slice.value
4✔
2414
                        # Get the shape symbol for this array dimension
2415
                        if arr_name in self.array_info:
4✔
2416
                            shapes = self.array_info[arr_name].get("shapes", [])
4✔
2417
                            if idx < len(shapes):
4✔
2418
                                return shapes[idx]
4✔
2419
                        return f"{arr_name}.shape[{idx}]"
×
2420
            # Fall back to visit
2421
            return self.visit(shape_node)
×
2422
        elif isinstance(shape_node, ast.Tuple):
×
2423
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2424
        elif isinstance(shape_node, ast.List):
×
2425
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2426
        else:
2427
            # Fall back to visit for complex expressions
2428
            return self.visit(shape_node)
×
2429

2430
    def _handle_numpy_alloc(self, node, func_name):
        """Lower an allocation call (first argument is the shape) to a temp.

        The shape argument may be a tuple/list of dimension expressions or a
        single expression. When the single expression visits to a
        ``_shape_proxy_<array>`` name and that array is known, its recorded
        shapes (and runtime shapes, if any) are reused.

        The dtype is taken from the second positional argument or, with
        precedence, from a ``dtype=`` keyword, and mapped through
        ``_map_numpy_dtype``. ``func_name`` selects the initialisation:
        ``"zeros"`` zero-fills, ``"ones"`` one-fills, anything else leaves the
        memory uninitialised.

        Returns:
            The name of the created temporary.
        """
        shape_arg = node.args[0]

        if isinstance(shape_arg, (ast.Tuple, ast.List)):
            # Visit every element first, then derive the runtime expressions.
            dims = [self.visit(elt) for elt in shape_arg.elts]
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
        else:
            visited = self.visit(shape_arg)
            runtime_expr = self._shape_to_runtime_expr(shape_arg)
            dims = [visited]
            dims_runtime = [runtime_expr]

            proxy_prefix = "_shape_proxy_"
            if visited.startswith(proxy_prefix):
                source_array = visited[len(proxy_prefix) :]
                if source_array in self.array_info:
                    source_info = self.array_info[source_array]
                    dims = source_info["shapes"]
                    dims_runtime = source_info.get("shapes_runtime", dims)

        # Positional dtype first; a dtype keyword overrides it.
        dtype_arg = node.args[1] if len(node.args) > 1 else None
        dtype_arg = next(
            (kw.value for kw in node.keywords if kw.arg == "dtype"), dtype_arg
        )

        return self._create_array_temp(
            dims,
            self._map_numpy_dtype(dtype_arg),
            zero_init=(func_name == "zeros"),
            ones_init=(func_name == "ones"),
            shapes_runtime=dims_runtime,
        )
2477

2478
    def _handle_numpy_empty_like(self, node, func_name):
        """Create an uninitialised temp shaped like the prototype argument.

        The shape comes from the prototype's ``array_info`` entry (empty if
        unknown). The element type is, in order of preference: the mapped
        ``dtype`` argument (keyword overrides positional), the prototype's
        pointee type when it is a ``Pointer`` with a pointee, else ``Double``.

        Returns:
            The name of the created temporary.
        """
        prototype_name = self.visit(node.args[0])

        known = prototype_name in self.array_info
        dims = self.array_info[prototype_name]["shapes"] if known else []

        # Positional dtype first; a dtype keyword overrides it.
        dtype_arg = node.args[1] if len(node.args) > 1 else None
        dtype_arg = next(
            (kw.value for kw in node.keywords if kw.arg == "dtype"), dtype_arg
        )

        element_type = None
        if dtype_arg:
            element_type = self._map_numpy_dtype(dtype_arg)
        elif prototype_name in self.symbol_table:
            proto_type = self.symbol_table[prototype_name]
            if isinstance(proto_type, Pointer) and proto_type.has_pointee_type():
                element_type = proto_type.pointee_type

        if element_type is None:
            element_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            dims, element_type, zero_init=False, ones_init=False
        )
2515

2516
    def _handle_numpy_zeros_like(self, node, func_name):
        """Create a zero-filled temp shaped like the prototype argument.

        The shape comes from the prototype's ``array_info`` entry (empty if
        unknown). The element type is, in order of preference: the mapped
        ``dtype`` argument (keyword overrides positional), the prototype's
        pointee type when it is a ``Pointer`` with a pointee, else ``Double``.

        Returns:
            The name of the created temporary.
        """
        prototype_name = self.visit(node.args[0])

        dims = []
        if prototype_name in self.array_info:
            dims = self.array_info[prototype_name]["shapes"]

        # Positional dtype first; a dtype keyword overrides it.
        dtype_arg = node.args[1] if len(node.args) > 1 else None
        dtype_arg = next(
            (kw.value for kw in node.keywords if kw.arg == "dtype"), dtype_arg
        )

        element_type = None
        if dtype_arg:
            element_type = self._map_numpy_dtype(dtype_arg)
        elif prototype_name in self.symbol_table:
            proto_type = self.symbol_table[prototype_name]
            if isinstance(proto_type, Pointer) and proto_type.has_pointee_type():
                element_type = proto_type.pointee_type

        if element_type is None:
            element_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            dims, element_type, zero_init=True, ones_init=False
        )
2553

2554
    def _handle_numpy_eye(self, node, func_name):
        """Lower ``np.eye(N, M=None, k=0, dtype=...)`` to a temporary array.

        A zero-initialised ``N x M`` temporary is created, then a loop over
        the rows writes a one at column ``i + k`` wherever that column lies
        inside ``[0, M)``.

        Arguments are accepted positionally (``N, M, k, dtype``) or by
        keyword (``M``, ``k``, ``dtype``); keywords override positionals.
        An ``M`` that visits to ``"None"`` means a square matrix, in both
        the positional and the keyword form.

        Returns:
            The name of the temporary holding the result.
        """
        N_str = self.visit(node.args[0])

        # Positional M; a literal None means "square", same as the keyword form.
        M_str = N_str
        if len(node.args) > 1:
            M_str = self.visit(node.args[1])
            if M_str == "None":
                M_str = N_str

        # Positional k (diagonal offset).
        k_str = "0"
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])

        # Positional dtype is the fourth argument of np.eye.
        dtype_arg = node.args[3] if len(node.args) > 3 else None

        # Keywords override whatever was given positionally.
        for kw in node.keywords:
            if kw.arg == "M":
                M_str = self.visit(kw.value)
                if M_str == "None":
                    M_str = N_str
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        element_type = self._map_numpy_dtype(dtype_arg)

        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)

        # Loop over rows to set the diagonal.
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Only write when the target column is in range: 0 <= i + k < M.
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        # Integer element types get an integer literal, everything else "1.0".
        val = "1.0"
        if element_type.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]:
            val = "1"

        # Assignment A[i, i + k] = 1, addressed through the flat index
        # i * M + i + k.
        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"

        t_task = self.builder.add_tasklet(
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(
            block_assign, t_task, "_out", t_arr, "void", flat_index
        )

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
2629

2630
    def _handle_numpy_binary_op(self, node, func_name):
4✔
2631
        args = [self.visit(arg) for arg in node.args]
4✔
2632
        if len(args) != 2:
4✔
2633
            raise NotImplementedError(
×
2634
                f"Numpy function {func_name} requires 2 arguments"
2635
            )
2636

2637
        op_map = {
4✔
2638
            "add": "add",
2639
            "subtract": "sub",
2640
            "multiply": "mul",
2641
            "divide": "div",
2642
            "power": "pow",
2643
            "minimum": "min",
2644
            "maximum": "max",
2645
        }
2646
        return self._handle_array_binary_op(op_map[func_name], args[0], args[1])
4✔
2647

2648
    def _handle_numpy_where(self, node, func_name):
4✔
2649
        """Handle np.where(condition, x, y) - elementwise ternary selection.
2650

2651
        Returns an array where elements are taken from x where condition is True,
2652
        and from y where condition is False.
2653
        """
2654
        if len(node.args) != 3:
4✔
2655
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")
×
2656

2657
        # Visit all arguments
2658
        cond_name = self.visit(node.args[0])
4✔
2659
        x_name = self.visit(node.args[1])
4✔
2660
        y_name = self.visit(node.args[2])
4✔
2661

2662
        # Determine output shape from the array arguments
2663
        # Priority: condition > y > x (since x might be scalar 0)
2664
        shape = []
4✔
2665
        dtype = Scalar(PrimitiveType.Double)
4✔
2666

2667
        # Check condition shape
2668
        if cond_name in self.array_info:
4✔
2669
            shape = self.array_info[cond_name]["shapes"]
4✔
2670

2671
        # If condition is scalar, check y
2672
        if not shape and y_name in self.array_info:
4✔
2673
            shape = self.array_info[y_name]["shapes"]
×
2674

2675
        # If y is scalar, check x
2676
        if not shape and x_name in self.array_info:
4✔
2677
            shape = self.array_info[x_name]["shapes"]
×
2678

2679
        if not shape:
4✔
2680
            raise NotImplementedError("np.where requires at least one array argument")
×
2681

2682
        # Determine dtype from y (since x might be scalar 0)
2683
        if y_name in self.symbol_table:
4✔
2684
            y_type = self.symbol_table[y_name]
4✔
2685
            if isinstance(y_type, Pointer) and y_type.has_pointee_type():
4✔
2686
                dtype = y_type.pointee_type
4✔
2687
            elif isinstance(y_type, Scalar):
×
2688
                dtype = y_type
×
2689

2690
        # Create output array
2691
        tmp_name = self._create_array_temp(shape, dtype)
4✔
2692

2693
        # Generate nested loops for the shape
2694
        loop_vars = []
4✔
2695
        for i, dim in enumerate(shape):
4✔
2696
            loop_var = f"_where_i{i}_{self._get_unique_id()}"
4✔
2697
            if not self.builder.exists(loop_var):
4✔
2698
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
2699
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
2700
            loop_vars.append(loop_var)
4✔
2701
            self.builder.begin_for(loop_var, "0", str(dim), "1")
4✔
2702

2703
        # Compute linear index
2704
        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))
4✔
2705

2706
        # Read condition value
2707
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
4✔
2708
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
4✔
2709
        self.symbol_table[cond_tmp] = Scalar(PrimitiveType.Bool)
4✔
2710

2711
        block_cond = self.builder.add_block()
4✔
2712
        if cond_name in self.array_info:
4✔
2713
            t_cond_arr = self.builder.add_access(block_cond, cond_name)
4✔
2714
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
4✔
2715
            t_cond_task = self.builder.add_tasklet(
4✔
2716
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
2717
            )
2718
            self.builder.add_memlet(
4✔
2719
                block_cond, t_cond_arr, "void", t_cond_task, "_in", linear_idx
2720
            )
2721
            self.builder.add_memlet(
4✔
2722
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
2723
            )
2724
        else:
2725
            # Scalar condition - just use it directly
2726
            t_cond_src, cond_sub = self._add_read(block_cond, cond_name)
×
2727
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
×
2728
            t_cond_task = self.builder.add_tasklet(
×
2729
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
2730
            )
2731
            self.builder.add_memlet(
×
2732
                block_cond, t_cond_src, "void", t_cond_task, "_in", cond_sub
2733
            )
2734
            self.builder.add_memlet(
×
2735
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
2736
            )
2737

2738
        # If-else based on condition
2739
        self.builder.begin_if(f"{cond_tmp} == true")
4✔
2740

2741
        # True branch: assign x
2742
        block_true = self.builder.add_block()
4✔
2743
        t_out_true = self.builder.add_access(block_true, tmp_name)
4✔
2744
        if x_name in self.array_info:
4✔
2745
            # x is an array
2746
            t_x = self.builder.add_access(block_true, x_name)
4✔
2747
            t_task_true = self.builder.add_tasklet(
4✔
2748
                block_true, TaskletCode.assign, ["_in"], ["_out"]
2749
            )
2750
            self.builder.add_memlet(
4✔
2751
                block_true, t_x, "void", t_task_true, "_in", linear_idx
2752
            )
2753
        else:
2754
            # x is a scalar
2755
            t_x, x_sub = self._add_read(block_true, x_name)
4✔
2756
            t_task_true = self.builder.add_tasklet(
4✔
2757
                block_true, TaskletCode.assign, ["_in"], ["_out"]
2758
            )
2759
            self.builder.add_memlet(block_true, t_x, "void", t_task_true, "_in", x_sub)
4✔
2760
        self.builder.add_memlet(
4✔
2761
            block_true, t_task_true, "_out", t_out_true, "void", linear_idx
2762
        )
2763

2764
        self.builder.begin_else()
4✔
2765

2766
        # False branch: assign y
2767
        block_false = self.builder.add_block()
4✔
2768
        t_out_false = self.builder.add_access(block_false, tmp_name)
4✔
2769
        if y_name in self.array_info:
4✔
2770
            # y is an array
2771
            t_y = self.builder.add_access(block_false, y_name)
4✔
2772
            t_task_false = self.builder.add_tasklet(
4✔
2773
                block_false, TaskletCode.assign, ["_in"], ["_out"]
2774
            )
2775
            self.builder.add_memlet(
4✔
2776
                block_false, t_y, "void", t_task_false, "_in", linear_idx
2777
            )
2778
        else:
2779
            # y is a scalar
2780
            t_y, y_sub = self._add_read(block_false, y_name)
4✔
2781
            t_task_false = self.builder.add_tasklet(
4✔
2782
                block_false, TaskletCode.assign, ["_in"], ["_out"]
2783
            )
2784
            self.builder.add_memlet(
4✔
2785
                block_false, t_y, "void", t_task_false, "_in", y_sub
2786
            )
2787
        self.builder.add_memlet(
4✔
2788
            block_false, t_task_false, "_out", t_out_false, "void", linear_idx
2789
        )
2790

2791
        self.builder.end_if()
4✔
2792

2793
        # Close all loops
2794
        for _ in loop_vars:
4✔
2795
            self.builder.end_for()
4✔
2796

2797
        return tmp_name
4✔
2798

2799
    def _handle_numpy_clip(self, node, func_name):
        """Lower ``np.clip(a, a_min, a_max)`` to two elementwise operations.

        The call is rewritten as ``min(max(a, a_min), a_max)`` by chaining two
        calls to ``_handle_array_binary_op``, so every value ends up inside
        ``[a_min, a_max]``.

        Args:
            node: The ``ast.Call`` node of the ``np.clip`` call.
            func_name: Name of the dispatched function (not used here).

        Returns:
            The value returned by the final ``min`` step.

        Raises:
            NotImplementedError: If the call does not have exactly three
                positional arguments.
        """
        operands = node.args
        if len(operands) != 3:
            raise NotImplementedError("np.clip requires 3 arguments (a, a_min, a_max)")

        # Evaluate left to right: the data first, then the two bounds
        # (each bound may be a scalar or an array).
        data, lower, upper = [self.visit(operand) for operand in operands]

        # Raise everything below the lower bound up to it ...
        floored = self._handle_array_binary_op("max", data, lower)
        # ... then cut everything above the upper bound down to it.
        return self._handle_array_binary_op("min", floored, upper)
4✔
2823

2824
    def _handle_numpy_matmul_op(self, left_node, right_node):
        """Lower a matrix product given its two operand nodes.

        Thin delegate: both nodes are forwarded unchanged to
        ``_handle_matmul_helper`` and its result is returned as is.
        """
        product = self._handle_matmul_helper(left_node, right_node)
        return product
4✔
2826

2827
    def _handle_numpy_matmul(self, node, func_name):
        """Lower a two-argument matmul/dot style call.

        Args:
            node: The ``ast.Call`` node; its two positional arguments are the
                operands.
            func_name: Name of the dispatched function (not used here).

        Returns:
            The result of ``_handle_matmul_helper`` for the two operands.

        Raises:
            NotImplementedError: If the call does not have exactly two
                positional arguments.
        """
        operands = node.args
        if len(operands) == 2:
            lhs, rhs = operands
            return self._handle_matmul_helper(lhs, rhs)
        raise NotImplementedError("matmul/dot requires 2 arguments")
4✔
2831

2832
    def _handle_numpy_outer(self, node, func_name):
        """Lower ``np.outer(a, b)`` through the linear-algebra handler.

        Each operand is resolved with ``la_handler.parse_arg``. When that does
        not yield a name, the operand is evaluated with ``self.visit`` and the
        resulting name is parsed again. A temporary of shape ``[M, N]`` is then
        created, where ``M`` and ``N`` are the products of the operands' shape
        entries, and ``la_handler.handle_outer`` is invoked with a call node
        that carries the (possibly replaced) operands.

        Args:
            node: The ``ast.Call`` node of the outer-product call.
            func_name: Name of the dispatched function (not used here).

        Returns:
            Name of the temporary array that receives the result.

        Raises:
            NotImplementedError: If the call does not have exactly two
                positional arguments, or an operand cannot be resolved.
            RuntimeError: If no linear-algebra handler is attached.
        """
        if len(node.args) != 2:
            raise NotImplementedError("outer requires 2 arguments")

        operands = [node.args[0], node.args[1]]

        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        parsed = [self.la_handler.parse_arg(operand) for operand in operands]

        # An unresolved operand is most likely a complex expression: evaluate
        # it into a named value first, then parse that name instead.
        for position, operand in enumerate(tuple(operands)):
            if parsed[position][0]:
                continue
            materialized = self.visit(operand)
            operands[position] = ast.Name(id=materialized)
            parsed[position] = self.la_handler.parse_arg(operands[position])

        # parse_arg results are 4-tuples; only the name and the per-dimension
        # sizes are needed here.
        name_a, _, shape_a, _ = parsed[0]
        name_b, _, shape_b, _ = parsed[1]

        if not (name_a and name_b):
            raise NotImplementedError("Could not resolve outer operands")

        def flattened_size(dims):
            # Symbolic product of all dimension sizes; "1" for no dimensions.
            expr = "1"
            for dim in dims:
                expr = str(dim) if expr == "1" else f"({expr} * {str(dim)})"
            return expr

        rows = flattened_size(shape_a)
        cols = flattened_size(shape_b)

        # Result element type is the promotion of both operand types.
        dtype = self._promote_dtypes(self._get_dtype(name_a), self._get_dtype(name_b))

        # The temp helper also registers the array (symbol table, array info).
        tmp_name = self._create_array_temp([rows, cols], dtype)

        rewritten_call = ast.Call(
            func=node.func, args=list(operands), keywords=node.keywords
        )
        self.la_handler.handle_outer(tmp_name, rewritten_call)

        return tmp_name
4✔
2891

2892
    def _handle_ufunc_outer(self, node, ufunc_name):
        """Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.

        These compute the outer operation for the given ufunc:
        - np.add.outer(a, b) -> a[:, np.newaxis] + b (outer sum)
        - np.subtract.outer(a, b) -> a[:, np.newaxis] - b (outer difference)
        - np.multiply.outer(a, b) -> a[:, np.newaxis] * b (same as np.outer)

        ``multiply`` is delegated to ``_handle_numpy_outer``. Every other
        supported ufunc is lowered to a two-level loop nest that writes
        ``C[i * N + j] = A[i] <op> B[j]`` into a temporary of shape ``[M, N]``,
        where ``M`` and ``N`` are the flattened operand sizes.

        Args:
            node: The ``ast.Call`` node; its two positional arguments are the
                operands.
            ufunc_name: Name of the ufunc (``"add"``, ``"subtract"``,
                ``"multiply"``, ``"divide"``, ``"minimum"``, ``"maximum"``).

        Returns:
            Name of the temporary array that receives the result.

        Raises:
            NotImplementedError: Wrong argument count, unsupported ufunc,
                unresolvable operand, or an operand that slices more than one
                dimension (a single loop variable cannot address it).
            RuntimeError: If no linear-algebra handler is attached.
        """
        if len(node.args) != 2:
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")

        # For np.multiply.outer, use the existing GEMM-based outer handler
        if ufunc_name == "multiply":
            return self._handle_numpy_outer(node, "outer")

        # ufunc name -> (floating-point opcode, integer opcode)
        opcodes = {
            "add": (TaskletCode.fp_add, TaskletCode.int_add),
            "subtract": (TaskletCode.fp_sub, TaskletCode.int_sub),
            "divide": (TaskletCode.fp_div, TaskletCode.int_sdiv),
            "minimum": (CMathFunction.fmin, TaskletCode.int_smin),
            "maximum": (CMathFunction.fmax, TaskletCode.int_smax),
        }

        if ufunc_name not in opcodes:
            raise NotImplementedError(f"{ufunc_name}.outer not supported")

        fp_opcode, int_opcode = opcodes[ufunc_name]

        # Use la_handler.parse_arg to properly handle sliced arrays
        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        arg0 = node.args[0]
        arg1 = node.args[1]

        res_a = self.la_handler.parse_arg(arg0)
        res_b = self.la_handler.parse_arg(arg1)

        # If parse_arg fails for complex expressions, evaluate the operand
        # into a named value and parse that name instead.
        if not res_a[0]:
            arg0 = ast.Name(id=self.visit(arg0))
            res_a = self.la_handler.parse_arg(arg0)

        if not res_b[0]:
            arg1 = ast.Name(id=self.visit(arg1))
            res_b = self.la_handler.parse_arg(arg1)

        name_a, subset_a, shape_a, indices_a = res_a
        name_b, subset_b, shape_b, indices_b = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve ufunc outer operands")

        def get_flattened_size_expr(shapes):
            # Outer treats inputs as 1-D: the size is the product of all dims.
            if not shapes:
                return "1"
            size_expr = str(shapes[0])
            for s in shapes[1:]:
                size_expr = f"({size_expr} * {str(s)})"
            return size_expr

        m_expr = get_flattened_size_expr(shape_a)
        n_expr = get_flattened_size_expr(shape_b)

        # Output dtype is the promotion of both operand dtypes
        dtype_left = self._get_dtype(name_a)
        dtype_right = self._get_dtype(name_b)
        dtype = self._promote_dtypes(dtype_left, dtype_right)

        # Integer results select the integer opcode, everything else the
        # floating-point one.
        is_int = dtype.primitive_type in (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        )

        # Create output array with shape (M, N)
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)

        # Generate unique loop variable names and make sure they exist
        i_var = self._get_temp_name("_outer_i_")
        j_var = self._get_temp_name("_outer_j_")

        for loop_var in (i_var, j_var):
            if not self.builder.exists(loop_var):
                self.builder.add_container(
                    loop_var, Scalar(PrimitiveType.Int64), False
                )
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        def compute_linear_index(name, subset, indices, loop_var):
            """
            Compute linear index for accessing element loop_var of a sliced array.

            For array A with shape (N, M):
            - A[:, k] (column k): linear_index = loop_var * M + k
            - A[k, :] (row k): linear_index = k * M + loop_var
            - A[k] (row k, trailing dims implicit): linear_index = k * M + loop_var
            - A[:] (1D array): linear_index = loop_var

            The indices list contains AST nodes showing which dims are sliced vs fixed.
            subset contains start indices for each dimension.
            """
            if not indices:
                # Whole array, no slicing
                return loop_var

            info = self.array_info.get(name, {})
            shapes = info.get("shapes", [])
            ndim = info.get("ndim", len(shapes))

            if ndim == 0:
                return loop_var

            # One loop variable can walk at most one sliced dimension. Using
            # it for several would address the diagonal of the slice, so the
            # case is rejected instead of emitting wrong indices.
            sliced_dims = sum(isinstance(idx, ast.Slice) for idx in indices)
            if sliced_dims > 1:
                raise NotImplementedError(
                    f"{ufunc_name}.outer operand '{name}' slices more than "
                    "one dimension"
                )

            # Compute strides (row-major order), innermost dimension first
            strides = []
            current_stride = "1"
            for dim in range(ndim - 1, -1, -1):
                strides.insert(0, current_stride)
                if dim > 0:
                    dim_size = (
                        shapes[dim] if dim < len(shapes) else f"_{name}_shape_{dim}"
                    )
                    if current_stride == "1":
                        current_stride = str(dim_size)
                    else:
                        current_stride = f"({current_stride} * {dim_size})"

            # Build linear index from subset and indices info
            terms = []
            loop_var_used = False

            for pos, idx in enumerate(indices):
                stride = strides[pos] if pos < len(strides) else "1"
                start = subset[pos] if pos < len(subset) else "0"

                if isinstance(idx, ast.Slice):
                    # This dimension is sliced - it is walked by loop_var
                    offset = f"({start} + {loop_var})"
                    term = offset if stride == "1" else f"({offset} * {stride})"
                    loop_var_used = True
                else:
                    # This dimension has a fixed index
                    term = start if stride == "1" else f"({start} * {stride})"

                terms.append(term)

            # Only fixed indices were given, so the elements being walked live
            # in the trailing dimensions. In row-major order they are
            # contiguous, which makes the loop variable itself the offset.
            if not loop_var_used:
                terms.append(loop_var)

            result = terms[0]
            for term in terms[1:]:
                result = f"({result} + {term})"

            return result

        # Resolve both operand indices before opening any loop, so that an
        # unsupported operand is rejected before the loop nest is emitted.
        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)

        # Create nested for loops: for i in range(M): for j in range(N): C[i,j] = A[i] op B[j]
        self.builder.begin_for(i_var, "0", m_expr, "1")
        self.builder.begin_for(j_var, "0", n_expr, "1")

        # Create the assignment block: C[i, j] = A[i] op B[j]
        block = self.builder.add_block()

        # Add access nodes
        t_a = self.builder.add_access(block, name_a)
        t_b = self.builder.add_access(block, name_b)
        t_c = self.builder.add_access(block, tmp_name)

        if ufunc_name in ("minimum", "maximum") and not is_int:
            # Floating-point min/max go through the cmath node
            t_task = self.builder.add_cmath(block, fp_opcode)
        else:
            # Integer min/max and all arithmetic ops use a regular tasklet
            tasklet_code = int_opcode if is_int else fp_opcode
            t_task = self.builder.add_tasklet(
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
            )

        # Connect A[a_index] and B[b_index] -> tasklet
        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)

        # Connect tasklet -> C[i * N + j] (linear index for 2D output)
        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)

        self.builder.end_for()  # end j loop
        self.builder.end_for()  # end i loop

        return tmp_name
4✔
3110

3111
    def _op_symbol(self, op_name):
        """Translate an internal operation name into its display symbol.

        ``add``/``sub``/``mul``/``div`` map to the arithmetic operator
        characters. ``min`` and ``max`` map to themselves (they are calls, not
        infix operators). Any other name is returned unchanged.
        """
        known = dict(add="+", sub="-", mul="*", div="/", min="min", max="max")
        if op_name in known:
            return known[op_name]
        return op_name
×
3122

3123
    def _handle_matmul_helper(self, left_node, right_node):
        """Lower a matrix product of two operand nodes.

        Operands are resolved with ``la_handler.parse_arg``. An operand that
        does not resolve to a name is evaluated with ``self.visit`` and the
        resulting name is parsed again. The operand ranks select the lowering:

        - 1-D x 1-D: scalar temporary, ``la_handler.handle_dot``.
        - 2-D x 2-D, 2-D x 1-D, 1-D x 2-D: array temporary,
          ``la_handler.handle_gemm``.
        - equal ranks above 2: one loop per leading dimension of the left
          operand, with ``handle_gemm`` applied to the innermost 2-D slices.

        Args:
            left_node: AST node of the left operand.
            right_node: AST node of the right operand.

        Returns:
            Name of the temporary that receives the result.

        Raises:
            RuntimeError: If no linear-algebra handler is attached.
            NotImplementedError: If an operand cannot be resolved or the rank
                combination is not supported.
        """
        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        parsed_left = self.la_handler.parse_arg(left_node)
        parsed_right = self.la_handler.parse_arg(right_node)

        # Unresolved operands are evaluated into a named value and re-parsed.
        if not parsed_left[0]:
            left_node = ast.Name(id=self.visit(left_node))
            parsed_left = self.la_handler.parse_arg(left_node)

        if not parsed_right[0]:
            right_node = ast.Name(id=self.visit(right_node))
            parsed_right = self.la_handler.parse_arg(right_node)

        # parse_arg results are 4-tuples; only name and shape are used here.
        name_a, _, shape_a, _ = parsed_left
        name_b, _, shape_b, _ = parsed_right

        if not (name_a and name_b):
            raise NotImplementedError("Could not resolve matmul operands")

        rank_a = len(shape_a)
        rank_b = len(shape_b)
        ranks = (rank_a, rank_b)
        is_scalar = ranks == (1, 1)
        batched = rank_a > 2 or rank_b > 2

        if is_scalar:
            output_shape = []
        elif ranks == (2, 2):
            output_shape = [shape_a[0], shape_b[1]]
        elif ranks == (2, 1):
            output_shape = [shape_a[0]]
        elif ranks == (1, 2):
            output_shape = [shape_b[1]]
        elif batched:
            if rank_a != rank_b:
                raise NotImplementedError(
                    "Broadcasting with different ranks not fully supported yet"
                )
            output_shape = [*shape_a[:-2], shape_a[-2], shape_b[-1]]
        else:
            raise NotImplementedError(
                f"Matmul with ranks {rank_a} and {rank_b} not supported"
            )

        # Result element type is the promotion of both operand types.
        dtype = self._promote_dtypes(self._get_dtype(name_a), self._get_dtype(name_b))

        if is_scalar:
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
        else:
            tmp_name = self._create_array_temp(output_shape, dtype)

        if not batched:
            product = ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node)
            lower = self.la_handler.handle_dot if is_scalar else self.la_handler.handle_gemm
            lower(tmp_name, product)
            return tmp_name

        # Batched case: open one loop per leading dimension; the loop bounds
        # come from the left operand's shape.
        batch_rank = rank_a - 2
        loop_vars = []
        for axis in range(batch_rank):
            loop_var = f"_i{self._get_unique_id()}"
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(shape_a[axis]), "1")

        def batch_slice(array_name):
            # array_name[<loop vars...>, :, :] built from fresh AST nodes.
            selectors = [ast.Name(id=var) for var in loop_vars]
            selectors.extend([ast.Slice(), ast.Slice()])
            return ast.Subscript(
                value=ast.Name(id=array_name),
                slice=ast.Tuple(elts=selectors),
                ctx=ast.Load(),
            )

        slice_a = batch_slice(name_a)
        slice_b = batch_slice(name_b)
        slice_c = batch_slice(tmp_name)

        self.la_handler.handle_gemm(
            slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
        )

        for _ in range(batch_rank):
            self.builder.end_for()

        return tmp_name
4✔
3239

3240
    def _handle_numpy_unary_op(self, node, func_name):
        """Lower a one-argument elementwise numpy call.

        All positional arguments are evaluated before the count is checked.
        ``absolute`` is treated as an alias of ``abs``; every other function
        name is passed through unchanged to ``_handle_array_unary_op``.

        Raises:
            NotImplementedError: If the call does not have exactly one
                positional argument.
        """
        operands = list(map(self.visit, node.args))
        if len(operands) != 1:
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")

        (operand,) = operands
        op_name = "abs" if func_name == "absolute" else func_name
        return self._handle_array_unary_op(op_name, operand)
4✔
3250

3251
    def _handle_numpy_reduce(self, node, func_name):
        """Lower a numpy reduction call (``func_name(array, axis=..., keepdims=...)``).

        The axis may be given as the second positional argument or as the
        ``axis`` keyword. Accepted forms are: omitted or ``None`` (reduce over
        all axes), an integer literal (positive or negative), a tuple of such
        literals, or any expression whose visited value converts with
        ``int()``. Negative axes count from the end.

        Args:
            node: The ``ast.Call`` node; its first positional argument is the
                array to reduce.
            func_name: Reduction name, forwarded to ``builder.add_reduce_op``.

        Returns:
            Name of the temporary holding the result: a scalar container when
            every axis is reduced without ``keepdims``, otherwise an array
            temporary.

        Raises:
            ValueError: If the input is not a known array, or an axis is out
                of bounds for the array's rank.
            NotImplementedError: If an axis cannot be determined statically.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        array_name = self.visit(args[0])

        if array_name not in self.array_info:
            raise ValueError(f"Reduction input must be an array, got {array_name}")

        input_shape = self.array_info[array_name]["shapes"]
        ndim = len(input_shape)

        def static_axis(axis_node):
            # Resolve one axis expression to a normalized, in-range index.
            if isinstance(axis_node, ast.Constant) and isinstance(
                axis_node.value, int
            ):
                value = axis_node.value
            elif (
                isinstance(axis_node, ast.UnaryOp)
                and isinstance(axis_node.op, ast.USub)
                and isinstance(axis_node.operand, ast.Constant)
                and isinstance(axis_node.operand.value, int)
            ):
                # A literal such as ``-1`` parses as USub applied to 1.
                value = -axis_node.operand.value
            else:
                # Try to evaluate a simple expression
                try:
                    value = int(self.visit(axis_node))
                except Exception as exc:
                    raise NotImplementedError("Dynamic axis not supported") from exc

            normalized = value + ndim if value < 0 else value
            if not 0 <= normalized < ndim:
                raise ValueError(
                    f"axis {value} is out of bounds for array of dimension {ndim}"
                )
            return normalized

        axis = args[1] if len(args) > 1 else keywords.get("axis")

        if axis is None or (isinstance(axis, ast.Constant) and axis.value is None):
            # No axis, or an explicit ``axis=None``: reduce over everything.
            axes = list(range(ndim))
        elif isinstance(axis, ast.Tuple):
            # Every element must resolve; dropping one would silently reduce
            # over fewer axes than requested.
            axes = [static_axis(elt) for elt in axis.elts]
        else:
            axes = [static_axis(axis)]

        # Only a literal keepdims is honoured; anything else leaves it False.
        keepdims = False
        keepdims_node = keywords.get("keepdims")
        if isinstance(keepdims_node, ast.Constant):
            keepdims = bool(keepdims_node.value)

        # Reduced axes disappear, or collapse to size 1 with keepdims.
        output_shape = []
        for dim, extent in enumerate(input_shape):
            if dim not in axes:
                output_shape.append(extent)
            elif keepdims:
                output_shape.append("1")

        dtype = self._get_dtype(array_name)

        if not output_shape:
            # Full reduction: the result is a scalar container, registered as
            # a zero-dimensional array.
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
        else:
            tmp_name = self._create_array_temp(output_shape, dtype)

        self.builder.add_reduce_op(
            func_name, array_name, tmp_name, input_shape, axes, keepdims
        )

        return tmp_name
4✔
3334

3335
    def _handle_numpy_astype(self, node, array_name):
        """Lower ``array.astype(dtype)`` to an elementwise cast.

        Args:
            node: The ``ast.Call`` node; its first positional argument names
                the target dtype.
            array_name: Name of the array being converted.

        Returns:
            Name of a new temporary array with the source shape and the
            target element type.

        Raises:
            ValueError: If no dtype argument is given or ``array_name`` is not
                a known array.
        """
        if not node.args:
            raise ValueError("astype requires at least one argument (dtype)")

        # The dtype is resolved before the array lookup.
        target_dtype = self._map_numpy_dtype(node.args[0])

        known_arrays = self.array_info
        if array_name not in known_arrays:
            raise ValueError(f"Array {array_name} not found in array_info")
        source_shape = known_arrays[array_name]["shapes"]

        # The destination has the same shape, only the element type differs.
        result_name = self._create_array_temp(source_shape, target_dtype)
        self.builder.add_cast_op(
            array_name, result_name, source_shape, target_dtype.primitive_type
        )
        return result_name
4✔
3358

3359
    def _handle_numpy_copy(self, node, array_name):
        """Handle numpy array.copy() method calls using memcpy.

        A temporary with the source's shape and element type is created and
        filled by a single memcpy of ``element_count * sizeof(element)``
        bytes.

        Args:
            node: The ``ast.Call`` node of the ``copy()`` call (not used).
            array_name: Name of the array being copied.

        Returns:
            Name of the temporary array holding the copy.

        Raises:
            ValueError: If ``array_name`` is not a known array.
        """
        if array_name not in self.array_info:
            raise ValueError(f"Array {array_name} not found in array_info")

        input_shape = self.array_info[array_name]["shapes"]

        # Element type comes from the array's pointer type when it is known;
        # otherwise double is assumed.
        element_type = Scalar(PrimitiveType.Double)
        if array_name in self.symbol_table:
            sym_type = self.symbol_table[array_name]
            if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                element_type = sym_type.pointee_type

        # Create output array with same dtype
        tmp_name = self._create_array_temp(input_shape, element_type)

        # Bytes to copy = total_elements * sizeof(element_type). An empty
        # shape means a single element; without the fallback the expression
        # would be the malformed "() * (size)".
        total_elements = " * ".join(f"({extent})" for extent in input_shape) or "1"
        element_size = self.builder.get_sizeof(element_type)
        count_expr = f"({total_elements}) * ({element_size})"

        # Both memlets move whole buffers, so they are typed as pointers.
        ptr_type = Pointer(element_type)

        # Add memcpy operation
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, array_name)
        t_dst = self.builder.add_access(block, tmp_name)
        t_memcpy = self.builder.add_memcpy(block, count_expr)

        # Connect source and destination
        self.builder.add_memlet(block, t_src, "void", t_memcpy, "_src", "", ptr_type)
        self.builder.add_memlet(block, t_memcpy, "_dst", t_dst, "void", "", ptr_type)

        return tmp_name
4✔
3396

3397
    def _handle_scipy_softmax(self, node, func_name):
        """Lower a softmax call (``func_name(array, axis=...)``).

        The axis may be given as the second positional argument or as the
        ``axis`` keyword. Accepted forms are: omitted or ``None`` (all axes),
        an integer literal (positive or negative), a tuple of such literals,
        or any expression whose visited value converts with ``int()``.
        Negative axes count from the end.

        Args:
            node: The ``ast.Call`` node; its first positional argument is the
                input array.
            func_name: Operation name, forwarded to ``builder.add_reduce_op``.

        Returns:
            Name of a temporary array with the same shape as the input.

        Raises:
            ValueError: If the input is not a known array, or an axis is out
                of bounds for the array's rank.
            NotImplementedError: If an axis cannot be determined statically.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        array_name = self.visit(args[0])

        if array_name not in self.array_info:
            raise ValueError(f"Softmax input must be an array, got {array_name}")

        input_shape = self.array_info[array_name]["shapes"]
        ndim = len(input_shape)

        def static_axis(axis_node):
            # Resolve one axis expression to a normalized, in-range index.
            if isinstance(axis_node, ast.Constant) and isinstance(
                axis_node.value, int
            ):
                value = axis_node.value
            elif (
                isinstance(axis_node, ast.UnaryOp)
                and isinstance(axis_node.op, ast.USub)
                and isinstance(axis_node.operand, ast.Constant)
                and isinstance(axis_node.operand.value, int)
            ):
                # A literal such as ``-1`` parses as USub applied to 1.
                value = -axis_node.operand.value
            else:
                # Try to evaluate a simple expression
                try:
                    value = int(self.visit(axis_node))
                except Exception as exc:
                    raise NotImplementedError("Dynamic axis not supported") from exc

            normalized = value + ndim if value < 0 else value
            if not 0 <= normalized < ndim:
                raise ValueError(
                    f"axis {value} is out of bounds for array of dimension {ndim}"
                )
            return normalized

        axis = args[1] if len(args) > 1 else keywords.get("axis")

        if axis is None or (isinstance(axis, ast.Constant) and axis.value is None):
            # No axis, or an explicit ``axis=None``: normalize over everything.
            axes = list(range(ndim))
        elif isinstance(axis, ast.Tuple):
            # Every element must resolve; dropping one would silently change
            # the set of axes the softmax runs over.
            axes = [static_axis(elt) for elt in axis.elts]
        else:
            axes = [static_axis(axis)]

        # Create output array
        # Assume double for now, or infer from input
        dtype = Scalar(PrimitiveType.Double)  # TODO: infer

        # The output keeps the full input shape.
        tmp_name = self._create_array_temp(input_shape, dtype)

        self.builder.add_reduce_op(
            func_name, array_name, tmp_name, input_shape, axes, False
        )

        return tmp_name
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc