• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 21746009216

06 Feb 2026 09:43AM UTC coverage: 66.359% (-0.1%) from 66.484%
21746009216

push

github

web-flow
Merge pull request #506 from daisytuner/npbench-crc16

adds crc16 npbench benchmark

53 of 130 new or added lines in 3 files covered. (40.77%)

1 existing line in 1 file now uncovered.

23114 of 34832 relevant lines covered (66.36%)

375.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.07
/python/docc/python/expression_visitor.py
1
import ast
4✔
2
import inspect
4✔
3
import textwrap
4✔
4
from docc.sdfg import (
4✔
5
    Scalar,
6
    PrimitiveType,
7
    Pointer,
8
    Type,
9
    DebugInfo,
10
    Structure,
11
    TaskletCode,
12
    CMathFunction,
13
)
14

15

16
class ExpressionVisitor(ast.NodeVisitor):
4✔
17
    def __init__(
4✔
18
        self,
19
        array_info=None,
20
        builder=None,
21
        symbol_table=None,
22
        globals_dict=None,
23
        inliner=None,
24
        unique_counter_ref=None,
25
        structure_member_info=None,
26
    ):
27
        self.array_info = array_info if array_info is not None else {}
4✔
28
        self.builder = builder
4✔
29
        self.symbol_table = symbol_table if symbol_table is not None else {}
4✔
30
        self.globals_dict = globals_dict if globals_dict is not None else {}
4✔
31
        self.inliner = inliner
4✔
32
        self._unique_counter_ref = (
4✔
33
            unique_counter_ref if unique_counter_ref is not None else [0]
34
        )
35
        self._access_cache = {}
4✔
36
        self.la_handler = None
4✔
37
        self.structure_member_info = (
4✔
38
            structure_member_info if structure_member_info is not None else {}
39
        )
40
        self._init_numpy_handlers()
4✔
41

42
    def _get_unique_id(self):
4✔
43
        self._unique_counter_ref[0] += 1
4✔
44
        return self._unique_counter_ref[0]
4✔
45

46
    def _get_temp_name(self, prefix="_tmp_"):
4✔
47
        if hasattr(self.builder, "find_new_name"):
4✔
48
            return self.builder.find_new_name(prefix)
×
49
        return f"{prefix}{self._get_unique_id()}"
4✔
50

51
    def _is_indirect_access(self, node):
4✔
52
        """Check if a node represents an indirect array access (e.g., A[B[i]]).
53

54
        Returns True if the node is a subscript where the index itself is a subscript
55
        into an array (indirect access pattern).
56
        """
57
        if not isinstance(node, ast.Subscript):
×
58
            return False
×
59
        # Check if value is a subscripted array access
60
        if isinstance(node.value, ast.Name):
×
61
            arr_name = node.value.id
×
62
            if arr_name in self.array_info:
×
63
                # Check if slice/index is itself an array access
64
                if isinstance(node.slice, ast.Subscript):
×
65
                    if isinstance(node.slice.value, ast.Name):
×
66
                        idx_arr_name = node.slice.value.id
×
67
                        if idx_arr_name in self.array_info:
×
68
                            return True
×
69
        return False
×
70

71
    def _contains_indirect_access(self, node):
4✔
72
        """Check if an AST node contains any indirect array access.
73

74
        Used to detect expressions like A_row[i] that would be used as slice bounds.
75
        """
76
        if isinstance(node, ast.Subscript):
4✔
77
            if isinstance(node.value, ast.Name):
4✔
78
                arr_name = node.value.id
4✔
79
                if arr_name in self.array_info:
4✔
80
                    return True
4✔
81
        elif isinstance(node, ast.BinOp):
4✔
82
            return self._contains_indirect_access(
4✔
83
                node.left
84
            ) or self._contains_indirect_access(node.right)
85
        elif isinstance(node, ast.UnaryOp):
4✔
86
            return self._contains_indirect_access(node.operand)
4✔
87
        return False
4✔
88

89
    def _materialize_indirect_access(
4✔
90
        self, node, debug_info=None, return_original_expr=False
91
    ):
92
        """Materialize an array access into a scalar variable using tasklet+memlets.
93

94
        For indirect memory access patterns in SDFGs, we need to:
95
        1. Create a scalar container for the result
96
        2. Create a tasklet that performs the assignment
97
        3. Use memlets to read from the array and write to the scalar
98
        4. Return the scalar name (which can be used as a symbolic expression)
99

100
        This is the canonical SDFG pattern for indirect access.
101

102
        If return_original_expr is True, also returns the original array access
103
        expression using parentheses notation (e.g., "A_row(0)") which is consistent
104
        with SDFG subset notation. The runtime evaluator will convert this to
105
        bracket notation for Python evaluation.
106
        """
107
        if not self.builder:
4✔
108
            # Without builder, just return the expression string
109
            expr = self.visit(node)
×
110
            return (expr, expr) if return_original_expr else expr
×
111

112
        if debug_info is None:
4✔
113
            debug_info = DebugInfo()
4✔
114

115
        if not isinstance(node, ast.Subscript):
4✔
116
            expr = self.visit(node)
×
117
            return (expr, expr) if return_original_expr else expr
×
118

119
        if not isinstance(node.value, ast.Name):
4✔
120
            expr = self.visit(node)
×
121
            return (expr, expr) if return_original_expr else expr
×
122

123
        arr_name = node.value.id
4✔
124
        if arr_name not in self.array_info:
4✔
125
            expr = self.visit(node)
×
126
            return (expr, expr) if return_original_expr else expr
×
127

128
        # Determine the element type
129
        dtype = Scalar(PrimitiveType.Int64)  # Default for indices
4✔
130
        if arr_name in self.symbol_table:
4✔
131
            t = self.symbol_table[arr_name]
4✔
132
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
133
                dtype = t.pointee_type
4✔
134

135
        # Create scalar container for the result
136
        tmp_name = self._get_temp_name("_idx_")
4✔
137
        self.builder.add_container(tmp_name, dtype, False)
4✔
138
        self.symbol_table[tmp_name] = dtype
4✔
139

140
        # Get the index expression
141
        ndim = self.array_info[arr_name]["ndim"]
4✔
142
        shapes = self.array_info[arr_name].get("shapes", [])
4✔
143

144
        # Compute linear index from the subscript
145
        if isinstance(node.slice, ast.Tuple):
4✔
146
            indices = [self.visit(elt) for elt in node.slice.elts]
×
147
        else:
148
            indices = [self.visit(node.slice)]
4✔
149

150
        # Handle cases where we need recursive materialization
151
        materialized_indices = []
4✔
152
        for i, idx_str in enumerate(indices):
4✔
153
            # Check if the index itself needs materialization (nested indirect)
154
            # This happens when idx_str looks like an array access e.g., "arr(i)"
155
            if "(" in idx_str and idx_str.endswith(")"):
4✔
156
                # This is an array access, it should already be a valid symbolic expression
157
                # or a scalar variable name
158
                materialized_indices.append(idx_str)
×
159
            else:
160
                materialized_indices.append(idx_str)
4✔
161

162
        # Compute linear index
163
        linear_index = self._compute_linear_index(
4✔
164
            materialized_indices, shapes, arr_name, ndim
165
        )
166

167
        # Create block with tasklet and memlets
168
        block = self.builder.add_block(debug_info)
4✔
169
        t_src = self.builder.add_access(block, arr_name, debug_info)
4✔
170
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
171
        t_task = self.builder.add_tasklet(
4✔
172
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
173
        )
174

175
        self.builder.add_memlet(
4✔
176
            block, t_src, "void", t_task, "_in", linear_index, None, debug_info
177
        )
178
        self.builder.add_memlet(
4✔
179
            block, t_task, "_out", t_dst, "void", "", None, debug_info
180
        )
181

182
        if return_original_expr:
4✔
183
            # Return both the materialized variable name and the original array access expression
184
            # Use parentheses notation which is consistent with SDFG subset syntax
185
            original_expr = f"{arr_name}({linear_index})"
4✔
186
            return (tmp_name, original_expr)
4✔
187

188
        return tmp_name
×
189

190
    def _init_numpy_handlers(self):
4✔
191
        self.numpy_handlers = {
4✔
192
            "empty": self._handle_numpy_alloc,
193
            "empty_like": self._handle_numpy_empty_like,
194
            "zeros": self._handle_numpy_alloc,
195
            "zeros_like": self._handle_numpy_zeros_like,
196
            "ones": self._handle_numpy_alloc,
197
            "ndarray": self._handle_numpy_alloc,  # np.ndarray() constructor
198
            "eye": self._handle_numpy_eye,
199
            "add": self._handle_numpy_binary_op,
200
            "subtract": self._handle_numpy_binary_op,
201
            "multiply": self._handle_numpy_binary_op,
202
            "divide": self._handle_numpy_binary_op,
203
            "power": self._handle_numpy_binary_op,
204
            "exp": self._handle_numpy_unary_op,
205
            "abs": self._handle_numpy_unary_op,
206
            "absolute": self._handle_numpy_unary_op,
207
            "sqrt": self._handle_numpy_unary_op,
208
            "tanh": self._handle_numpy_unary_op,
209
            "sum": self._handle_numpy_reduce,
210
            "max": self._handle_numpy_reduce,
211
            "min": self._handle_numpy_reduce,
212
            "mean": self._handle_numpy_reduce,
213
            "std": self._handle_numpy_reduce,
214
            "matmul": self._handle_numpy_matmul,
215
            "dot": self._handle_numpy_matmul,
216
            "matvec": self._handle_numpy_matmul,
217
            "outer": self._handle_numpy_outer,
218
            "minimum": self._handle_numpy_binary_op,
219
            "maximum": self._handle_numpy_binary_op,
220
            "where": self._handle_numpy_where,
221
            "clip": self._handle_numpy_clip,
222
        }
223

224
    def generic_visit(self, node):
4✔
225
        return super().generic_visit(node)
×
226

227
    def visit_Constant(self, node):
4✔
228
        if isinstance(node.value, bool):
4✔
229
            return "true" if node.value else "false"
×
230
        return str(node.value)
4✔
231

232
    def visit_Name(self, node):
4✔
233
        name = node.id
4✔
234
        # Check if it's a global constant (not a local variable/array)
235
        if name not in self.symbol_table and self.globals_dict is not None:
4✔
236
            if name in self.globals_dict:
4✔
237
                val = self.globals_dict[name]
4✔
238
                # Only substitute simple numeric constants
239
                if isinstance(val, (int, float)):
4✔
240
                    return str(val)
4✔
241
        return name
4✔
242

243
    def _map_numpy_dtype(self, dtype_node):
        """Translate a ``dtype`` argument (an AST node) into an SDFG type.

        Supported spellings:
          * omitted (``None``)                      -> Double
          * builtins ``float`` / ``int`` / ``bool`` -> Double / Int64 / Bool
          * ``arr.dtype`` where ``arr`` is a pointer-typed symbol with a
            pointee type                            -> that pointee type
          * ``np.<name>`` / ``numpy.<name>`` for NumPy fixed-width scalar
            types: ``float64``, ``float32``, ``int8``..``int64``,
            ``uint8``..``uint64``, and ``bool_`` / ``bool``

        Any other expression falls back to Double.
        """
        if dtype_node is None:
            return Scalar(PrimitiveType.Double)

        if isinstance(dtype_node, ast.Name):
            builtin_types = {
                "float": PrimitiveType.Double,
                "int": PrimitiveType.Int64,
                "bool": PrimitiveType.Bool,
            }
            if dtype_node.id in builtin_types:
                return Scalar(builtin_types[dtype_node.id])

        if isinstance(dtype_node, ast.Attribute) and isinstance(
            dtype_node.value, ast.Name
        ):
            owner = dtype_node.value.id
            attr = dtype_node.attr

            # arr.dtype: reuse the element type of an already-known array.
            if attr == "dtype" and owner in self.symbol_table:
                sym_type = self.symbol_table[owner]
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                    return sym_type.pointee_type

            if owner in ("numpy", "np"):
                numpy_types = {
                    "float64": PrimitiveType.Double,
                    "float32": PrimitiveType.Float,
                    "int64": PrimitiveType.Int64,
                    "int32": PrimitiveType.Int32,
                    "int16": PrimitiveType.Int16,
                    "int8": PrimitiveType.Int8,
                    "uint64": PrimitiveType.UInt64,
                    "uint32": PrimitiveType.UInt32,
                    "uint16": PrimitiveType.UInt16,
                    "uint8": PrimitiveType.UInt8,
                    "bool_": PrimitiveType.Bool,
                    # NumPy >= 2.0 also exposes the boolean scalar as np.bool
                    "bool": PrimitiveType.Bool,
                }
                if attr in numpy_types:
                    return Scalar(numpy_types[attr])

        # Fallback for unrecognised dtype expressions
        return Scalar(PrimitiveType.Double)
284

285
    def _is_int(self, operand):
4✔
286
        try:
4✔
287
            if operand.lstrip("-").isdigit():
4✔
288
                return True
4✔
289
        except ValueError:
×
290
            pass
×
291

292
        name = operand
4✔
293
        if "(" in operand and operand.endswith(")"):
4✔
294
            name = operand.split("(")[0]
4✔
295

296
        if name in self.symbol_table:
4✔
297
            t = self.symbol_table[name]
4✔
298

299
            def is_int_ptype(pt):
4✔
300
                return pt in [
4✔
301
                    PrimitiveType.Int64,
302
                    PrimitiveType.Int32,
303
                    PrimitiveType.Int8,
304
                    PrimitiveType.Int16,
305
                    PrimitiveType.UInt64,
306
                    PrimitiveType.UInt32,
307
                    PrimitiveType.UInt8,
308
                    PrimitiveType.UInt16,
309
                ]
310

311
            if isinstance(t, Scalar):
4✔
312
                return is_int_ptype(t.primitive_type)
4✔
313

314
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
315
                et = t.element_type
×
316
                if callable(et):
×
317
                    et = et()
×
318
                if isinstance(et, Scalar):
×
319
                    return is_int_ptype(et.primitive_type)
×
320

321
            if type(t).__name__ == "Pointer":
4✔
322
                if hasattr(t, "pointee_type"):
4✔
323
                    et = t.pointee_type
4✔
324
                    if callable(et):
4✔
325
                        et = et()
×
326
                    if isinstance(et, Scalar):
4✔
327
                        return is_int_ptype(et.primitive_type)
4✔
328
                # Fallback: check if it has element_type (maybe alias?)
329
                if hasattr(t, "element_type"):
×
330
                    et = t.element_type
×
331
                    if callable(et):
×
332
                        et = et()
×
333
                    if isinstance(et, Scalar):
×
334
                        return is_int_ptype(et.primitive_type)
×
335

336
        return False
4✔
337

338
    def _add_read(self, block, expr_str, debug_info=None):
4✔
339
        # Try to reuse access node
340
        try:
4✔
341
            if (block, expr_str) in self._access_cache:
4✔
342
                return self._access_cache[(block, expr_str)]
4✔
343
        except TypeError:
×
344
            # block might not be hashable
345
            pass
×
346

347
        if debug_info is None:
4✔
348
            debug_info = DebugInfo()
4✔
349

350
        if "(" in expr_str and expr_str.endswith(")"):
4✔
351
            name = expr_str.split("(")[0]
4✔
352
            subset = expr_str[expr_str.find("(") + 1 : -1]
4✔
353
            access = self.builder.add_access(block, name, debug_info)
4✔
354
            try:
4✔
355
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
356
            except TypeError:
×
357
                pass
×
358
            return access, subset
4✔
359

360
        if self.builder.exists(expr_str):
4✔
361
            access = self.builder.add_access(block, expr_str, debug_info)
4✔
362
            # For pointer types representing 0-D arrays, dereference with "0"
363
            subset = ""
4✔
364
            if expr_str in self.symbol_table:
4✔
365
                sym_type = self.symbol_table[expr_str]
4✔
366
                if isinstance(sym_type, Pointer):
4✔
367
                    # Check if it's a 0-D array (scalar wrapped in pointer)
368
                    if expr_str in self.array_info:
×
369
                        ndim = self.array_info[expr_str].get("ndim", 0)
×
370
                        if ndim == 0:
×
371
                            subset = "0"
×
372
                    else:
373
                        # Pointer without array_info is treated as 0-D
374
                        subset = "0"
×
375
            try:
4✔
376
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
377
            except TypeError:
×
378
                pass
×
379
            return access, subset
4✔
380

381
        dtype = Scalar(PrimitiveType.Double)
4✔
382
        if self._is_int(expr_str):
4✔
383
            dtype = Scalar(PrimitiveType.Int64)
4✔
384
        elif expr_str == "true" or expr_str == "false":
4✔
385
            dtype = Scalar(PrimitiveType.Bool)
×
386

387
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
4✔
388
        try:
4✔
389
            self._access_cache[(block, expr_str)] = (const_node, "")
4✔
390
        except TypeError:
×
391
            pass
×
392
        return const_node, ""
4✔
393

394
    def _handle_min_max(self, node, func_name):
4✔
395
        args = [self.visit(arg) for arg in node.args]
4✔
396
        if len(args) != 2:
4✔
397
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")
×
398

399
        # Check types
400
        is_float = False
4✔
401
        arg_types = []
4✔
402

403
        for arg in args:
4✔
404
            name = arg
4✔
405
            if "(" in arg and arg.endswith(")"):
4✔
406
                name = arg.split("(")[0]
×
407

408
            if name in self.symbol_table:
4✔
409
                t = self.symbol_table[name]
4✔
410
                if isinstance(t, Pointer):
4✔
411
                    t = t.base_type
×
412

413
                if t.primitive_type == PrimitiveType.Double:
4✔
414
                    is_float = True
4✔
415
                    arg_types.append(PrimitiveType.Double)
4✔
416
                else:
417
                    arg_types.append(PrimitiveType.Int64)
4✔
418
            elif self._is_int(arg):
×
419
                arg_types.append(PrimitiveType.Int64)
×
420
            else:
421
                # Assume float constant
422
                is_float = True
×
423
                arg_types.append(PrimitiveType.Double)
×
424

425
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)
4✔
426

427
        tmp_name = self._get_temp_name("_tmp_")
4✔
428
        self.builder.add_container(tmp_name, dtype, False)
4✔
429
        self.symbol_table[tmp_name] = dtype
4✔
430

431
        if is_float:
4✔
432
            # Cast args if necessary
433
            casted_args = []
4✔
434
            for i, arg in enumerate(args):
4✔
435
                if arg_types[i] != PrimitiveType.Double:
4✔
436
                    # Create temp double
437
                    tmp_cast = self._get_temp_name("_cast_")
4✔
438
                    self.builder.add_container(
4✔
439
                        tmp_cast, Scalar(PrimitiveType.Double), False
440
                    )
441
                    self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
4✔
442

443
                    # Assign int to double (implicit cast)
444
                    self.builder.add_assignment(tmp_cast, arg)
4✔
445
                    casted_args.append(tmp_cast)
4✔
446
                else:
447
                    casted_args.append(arg)
4✔
448

449
            block = self.builder.add_block()
4✔
450
            t_out = self.builder.add_access(block, tmp_name)
4✔
451

452
            intrinsic_name = (
4✔
453
                CMathFunction.fmax if func_name == "max" else CMathFunction.fmin
454
            )
455
            t_task = self.builder.add_cmath(block, intrinsic_name)
4✔
456

457
            for i, arg in enumerate(casted_args):
4✔
458
                t_arg, arg_sub = self._add_read(block, arg)
4✔
459
                self.builder.add_memlet(
4✔
460
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
461
                )
462
        else:
463
            block = self.builder.add_block()
4✔
464
            t_out = self.builder.add_access(block, tmp_name)
4✔
465

466
            # Use int_smax/int_smin tasklet
467
            opcode = None
4✔
468
            if func_name == "max":
4✔
469
                opcode = TaskletCode.int_smax
4✔
470
            else:
471
                opcode = TaskletCode.int_smin
4✔
472
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])
4✔
473

474
            for i, arg in enumerate(args):
4✔
475
                t_arg, arg_sub = self._add_read(block, arg)
4✔
476
                self.builder.add_memlet(
4✔
477
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
478
                )
479

480
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
4✔
481
        return tmp_name
4✔
482

483
    def _handle_python_cast(self, node, func_name):
4✔
484
        """Handle Python type casts: int(), float(), bool()"""
485
        if len(node.args) != 1:
4✔
486
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")
×
487

488
        arg = self.visit(node.args[0])
4✔
489

490
        # Determine target type based on cast function
491
        if func_name == "int":
4✔
492
            target_dtype = Scalar(PrimitiveType.Int64)
4✔
493
        elif func_name == "float":
4✔
494
            target_dtype = Scalar(PrimitiveType.Double)
4✔
495
        elif func_name == "bool":
4✔
496
            target_dtype = Scalar(PrimitiveType.Bool)
4✔
497
        else:
498
            raise NotImplementedError(f"Cast to {func_name} not supported")
×
499

500
        # Determine source type
501
        source_dtype = None
4✔
502
        name = arg
4✔
503
        if "(" in arg and arg.endswith(")"):
4✔
504
            name = arg.split("(")[0]
×
505

506
        if name in self.symbol_table:
4✔
507
            source_dtype = self.symbol_table[name]
4✔
508
            if isinstance(source_dtype, Pointer):
4✔
509
                source_dtype = source_dtype.base_type
×
510
        elif self._is_int(arg):
×
511
            source_dtype = Scalar(PrimitiveType.Int64)
×
512
        elif arg == "true" or arg == "false":
×
513
            source_dtype = Scalar(PrimitiveType.Bool)
×
514
        else:
515
            # Assume float constant
516
            source_dtype = Scalar(PrimitiveType.Double)
×
517

518
        # Create temporary variable for result
519
        tmp_name = self._get_temp_name("_tmp_")
4✔
520
        self.builder.add_container(tmp_name, target_dtype, False)
4✔
521
        self.symbol_table[tmp_name] = target_dtype
4✔
522

523
        # Use tasklet assign opcode for casting (as specified in problem statement)
524
        block = self.builder.add_block()
4✔
525
        t_src, src_sub = self._add_read(block, arg)
4✔
526
        t_dst = self.builder.add_access(block, tmp_name)
4✔
527
        t_task = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
4✔
528
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
4✔
529
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
4✔
530

531
        return tmp_name
4✔
532

533
    def visit_Call(self, node):
4✔
534
        func_name = ""
4✔
535
        module_name = ""
4✔
536
        if isinstance(node.func, ast.Attribute):
4✔
537
            if isinstance(node.func.value, ast.Name):
4✔
538
                if node.func.value.id == "math":
4✔
539
                    module_name = "math"
4✔
540
                    func_name = node.func.attr
4✔
541
                elif node.func.value.id in ["numpy", "np"]:
4✔
542
                    module_name = "numpy"
4✔
543
                    func_name = node.func.attr
4✔
544
                else:
545
                    # Check if it's a method call on an array (e.g., arr.astype(...), arr.copy())
546
                    array_name = node.func.value.id
4✔
547
                    method_name = node.func.attr
4✔
548
                    if array_name in self.array_info and method_name == "astype":
4✔
549
                        return self._handle_numpy_astype(node, array_name)
4✔
550
                    elif array_name in self.array_info and method_name == "copy":
4✔
551
                        return self._handle_numpy_copy(node, array_name)
4✔
552
            elif isinstance(node.func.value, ast.Attribute):
4✔
553
                if (
4✔
554
                    isinstance(node.func.value.value, ast.Name)
555
                    and node.func.value.value.id == "scipy"
556
                    and node.func.value.attr == "special"
557
                ):
558
                    if node.func.attr == "softmax":
4✔
559
                        return self._handle_scipy_softmax(node, "softmax")
4✔
560
                # Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.
561
                elif (
4✔
562
                    isinstance(node.func.value.value, ast.Name)
563
                    and node.func.value.value.id in ["numpy", "np"]
564
                    and node.func.attr == "outer"
565
                ):
566
                    ufunc_name = node.func.value.attr  # "add", "subtract", etc.
4✔
567
                    return self._handle_ufunc_outer(node, ufunc_name)
4✔
568

569
        elif isinstance(node.func, ast.Name):
4✔
570
            func_name = node.func.id
4✔
571

572
        if module_name == "numpy":
4✔
573
            if func_name in self.numpy_handlers:
4✔
574
                return self.numpy_handlers[func_name](node, func_name)
4✔
575

576
        if func_name in ["max", "min"]:
4✔
577
            return self._handle_min_max(node, func_name)
4✔
578

579
        # Handle Python type casts (int, float, bool)
580
        if func_name in ["int", "float", "bool"]:
4✔
581
            return self._handle_python_cast(node, func_name)
4✔
582

583
        math_funcs = {
4✔
584
            # Trigonometric functions
585
            "sin": CMathFunction.sin,
586
            "cos": CMathFunction.cos,
587
            "tan": CMathFunction.tan,
588
            "asin": CMathFunction.asin,
589
            "acos": CMathFunction.acos,
590
            "atan": CMathFunction.atan,
591
            "atan2": CMathFunction.atan2,
592
            # Hyperbolic functions
593
            "sinh": CMathFunction.sinh,
594
            "cosh": CMathFunction.cosh,
595
            "tanh": CMathFunction.tanh,
596
            "asinh": CMathFunction.asinh,
597
            "acosh": CMathFunction.acosh,
598
            "atanh": CMathFunction.atanh,
599
            # Exponential and logarithmic functions
600
            "exp": CMathFunction.exp,
601
            "exp2": CMathFunction.exp2,
602
            "expm1": CMathFunction.expm1,
603
            "log": CMathFunction.log,
604
            "log2": CMathFunction.log2,
605
            "log10": CMathFunction.log10,
606
            "log1p": CMathFunction.log1p,
607
            # Power functions
608
            "pow": CMathFunction.pow,
609
            "sqrt": CMathFunction.sqrt,
610
            "cbrt": CMathFunction.cbrt,
611
            "hypot": CMathFunction.hypot,
612
            # Rounding and remainder functions
613
            "abs": CMathFunction.fabs,
614
            "fabs": CMathFunction.fabs,
615
            "ceil": CMathFunction.ceil,
616
            "floor": CMathFunction.floor,
617
            "trunc": CMathFunction.trunc,
618
            "fmod": CMathFunction.fmod,
619
            "remainder": CMathFunction.remainder,
620
            # Floating-point manipulation functions
621
            "copysign": CMathFunction.copysign,
622
            # Other functions
623
            "fma": CMathFunction.fma,
624
        }
625

626
        if func_name in math_funcs:
4✔
627
            args = [self.visit(arg) for arg in node.args]
4✔
628

629
            tmp_name = self._get_temp_name("_tmp_")
4✔
630
            dtype = Scalar(PrimitiveType.Double)
4✔
631
            self.builder.add_container(tmp_name, dtype, False)
4✔
632
            self.symbol_table[tmp_name] = dtype
4✔
633

634
            block = self.builder.add_block()
4✔
635
            t_out = self.builder.add_access(block, tmp_name)
4✔
636

637
            t_task = self.builder.add_cmath(block, math_funcs[func_name])
4✔
638

639
            for i, arg in enumerate(args):
4✔
640
                t_arg, arg_sub = self._add_read(block, arg)
4✔
641
                self.builder.add_memlet(
4✔
642
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
643
                )
644

645
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
4✔
646
            return tmp_name
4✔
647

648
        if func_name in self.globals_dict:
4✔
649
            obj = self.globals_dict[func_name]
4✔
650
            if inspect.isfunction(obj):
4✔
651
                return self._handle_inline_call(node, obj)
4✔
652

653
        raise NotImplementedError(f"Function call {func_name} not supported")
×
654

655
    def _handle_inline_call(self, node, func_obj):
4✔
656
        # 1. Parse function source
657
        try:
4✔
658
            source_lines, start_line = inspect.getsourcelines(func_obj)
4✔
659
            source = textwrap.dedent("".join(source_lines))
4✔
660
            tree = ast.parse(source)
4✔
661
            func_def = tree.body[0]
4✔
662
        except Exception as e:
×
663
            raise NotImplementedError(
×
664
                f"Could not parse function {func_obj.__name__}: {e}"
665
            )
666

667
        # 2. Evaluate arguments
668
        arg_vars = [self.visit(arg) for arg in node.args]
4✔
669

670
        if len(arg_vars) != len(func_def.args.args):
4✔
671
            raise NotImplementedError(
×
672
                f"Argument count mismatch for {func_obj.__name__}"
673
            )
674

675
        # 3. Generate unique suffix
676
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
4✔
677
        res_name = f"_res{suffix}"
4✔
678

679
        # Assume Int64 for now as match returns 0/1
680
        dtype = Scalar(PrimitiveType.Int64)
4✔
681
        self.builder.add_container(res_name, dtype, False)
4✔
682
        self.symbol_table[res_name] = dtype
4✔
683

684
        # 4. Rename variables
685
        class VariableRenamer(ast.NodeTransformer):
4✔
686
            # Builtins that should not be renamed
687
            BUILTINS = {
4✔
688
                "range",
689
                "len",
690
                "int",
691
                "float",
692
                "bool",
693
                "str",
694
                "list",
695
                "dict",
696
                "tuple",
697
                "set",
698
                "print",
699
                "abs",
700
                "min",
701
                "max",
702
                "sum",
703
                "enumerate",
704
                "zip",
705
                "map",
706
                "filter",
707
                "sorted",
708
                "reversed",
709
                "True",
710
                "False",
711
                "None",
712
            }
713

714
            def __init__(self, suffix, globals_dict):
4✔
715
                self.suffix = suffix
4✔
716
                self.globals_dict = globals_dict
4✔
717

718
            def visit_Name(self, node):
4✔
719
                # Don't rename builtins or globals
720
                if node.id in self.globals_dict or node.id in self.BUILTINS:
4✔
721
                    return node
4✔
722
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
4✔
723

724
            def visit_Return(self, node):
4✔
725
                if node.value:
4✔
726
                    val = self.visit(node.value)
4✔
727
                    return ast.Assign(
4✔
728
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
729
                        value=val,
730
                    )
731
                return node
×
732

733
        renamer = VariableRenamer(suffix, self.globals_dict)
4✔
734
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
4✔
735

736
        # 5. Assign arguments to parameters
737
        param_assignments = []
4✔
738
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
4✔
739
            param_name = f"{arg_def.arg}{suffix}"
4✔
740

741
            # Infer type and create container
742
            if arg_val in self.symbol_table:
4✔
743
                self.symbol_table[param_name] = self.symbol_table[arg_val]
4✔
744
                self.builder.add_container(
4✔
745
                    param_name, self.symbol_table[arg_val], False
746
                )
747
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
4✔
748
            elif self._is_int(arg_val):
×
749
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
750
                self.builder.add_container(
×
751
                    param_name, Scalar(PrimitiveType.Int64), False
752
                )
753
                val_node = ast.Constant(value=int(arg_val))
×
754
            else:
755
                # Assume float constant
756
                try:
×
757
                    val = float(arg_val)
×
758
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
759
                    self.builder.add_container(
×
760
                        param_name, Scalar(PrimitiveType.Double), False
761
                    )
762
                    val_node = ast.Constant(value=val)
×
763
                except ValueError:
×
764
                    # Fallback to Name, might fail later if not in symbol table
765
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
766

767
            assign = ast.Assign(
4✔
768
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
769
            )
770
            param_assignments.append(assign)
4✔
771

772
        final_body = param_assignments + new_body
4✔
773

774
        # 6. Visit new body using ASTParser
775
        from .ast_parser import ASTParser
4✔
776

777
        parser = ASTParser(
4✔
778
            self.builder,
779
            self.array_info,
780
            self.symbol_table,
781
            globals_dict=self.globals_dict,
782
            unique_counter_ref=self._unique_counter_ref,
783
        )
784

785
        for stmt in final_body:
4✔
786
            parser.visit(stmt)
4✔
787

788
        return res_name
4✔
789

790
    def visit_BinOp(self, node):
        """Lower a binary operator expression and return the name holding its result.

        ``@`` is delegated to ``_handle_numpy_matmul_op`` and any operation with an
        array operand to ``_handle_array_binary_op`` (``+ - * / **`` only).

        For scalars a fresh temporary is allocated: Int64 when both operands are
        integers (per ``_is_int``) and the operator is neither ``/`` nor ``**``,
        Double otherwise. In the Double case integer operands are first copied
        into Double temporaries. ``%`` and ``//`` follow Python's floored
        semantics (``%`` takes the sign of the divisor, ``//`` rounds toward
        negative infinity), and ``>>`` is an arithmetic shift, as in Python.

        Raises:
            NotImplementedError: for an unsupported array operator, or a scalar
                operator with no tasklet mapping for the chosen type.
        """
        if isinstance(node.op, ast.MatMult):
            return self._handle_numpy_matmul_op(node.left, node.right)

        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)

        if left in self.array_info or right in self.array_info:
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
            if op in op_map:
                return self._handle_array_binary_op(op_map[op], left, right)
            raise NotImplementedError(f"Array operation {op} not supported")

        tmp_name = f"_tmp_{self._get_unique_id()}"

        left_is_int = self._is_int(left)
        right_is_int = self._is_int(right)
        if left_is_int and right_is_int and op not in ["/", "**"]:
            dtype = Scalar(PrimitiveType.Int64)
        else:
            dtype = Scalar(PrimitiveType.Double)
        is_int = dtype.primitive_type == PrimitiveType.Int64

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        real_left, real_right = left, right
        if not is_int:
            # Floating-point tasklets get homogeneous Double inputs.
            if left_is_int:
                real_left = self._binop_cast_to_double(left)
            if right_is_int:
                real_right = self._binop_cast_to_double(right)

        if op == "**":
            block = self.builder.add_block()
            t_left, left_sub = self._add_read(block, real_left)
            t_right, right_sub = self._add_read(block, real_right)
            t_out = self.builder.add_access(block, tmp_name)

            t_task = self.builder.add_cmath(block, CMathFunction.pow)
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
            return tmp_name

        if op in ("%", "//"):
            # Both need several tasklets to reproduce Python's floored semantics.
            block = self.builder.add_block()
            lhs = self._add_read(block, real_left)
            rhs = self._add_read(block, real_right)
            t_out = self.builder.add_access(block, tmp_name)
            if op == "%":
                self._binop_emit_python_mod(block, dtype, lhs, rhs, t_out)
            else:
                self._binop_emit_floordiv(block, dtype, lhs, rhs, t_out)
            return tmp_name

        if is_int:
            if op == "+":
                tasklet_code = TaskletCode.int_add
            elif op == "-":
                tasklet_code = TaskletCode.int_sub
            elif op == "*":
                tasklet_code = TaskletCode.int_mul
            elif op == "&":
                tasklet_code = TaskletCode.int_and
            elif op == "|":
                tasklet_code = TaskletCode.int_or
            elif op == "^":
                tasklet_code = TaskletCode.int_xor
            elif op == "<<":
                tasklet_code = TaskletCode.int_shl
            elif op == ">>":
                # Python's >> on signed integers is arithmetic (sign-extending),
                # so a logical shift would be wrong for negative values.
                tasklet_code = TaskletCode.int_ashr
            else:
                raise NotImplementedError(f"Operation {op} not supported for integers")
        else:
            if op == "+":
                tasklet_code = TaskletCode.fp_add
            elif op == "-":
                tasklet_code = TaskletCode.fp_sub
            elif op == "*":
                tasklet_code = TaskletCode.fp_mul
            elif op == "/":
                tasklet_code = TaskletCode.fp_div
            else:
                raise NotImplementedError(f"Operation {op} not supported for floats")

        block = self.builder.add_block()
        lhs = self._add_read(block, real_left)
        rhs = self._add_read(block, real_right)
        t_out = self.builder.add_access(block, tmp_name)
        self._binop_emit(block, tasklet_code, lhs, rhs, dtype, t_out)
        return tmp_name

    def _binop_cast_to_double(self, name):
        """Copy scalar ``name`` into a new Double temporary via an assign tasklet; return it."""
        cast_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(cast_name, Scalar(PrimitiveType.Double), False)
        self.symbol_table[cast_name] = Scalar(PrimitiveType.Double)

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, name)
        t_dst = self.builder.add_access(block, cast_name)
        t_task = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        return cast_name

    def _binop_emit(self, block, code, lhs, rhs, dtype, t_out=None):
        """Add a two-input tasklet ``code(lhs, rhs)`` to ``block``; return its output node.

        ``lhs``/``rhs`` are ``(access_node, subset)`` pairs. Without ``t_out`` the
        result goes to a new scalar container of ``dtype`` (not entered in the
        symbol table).
        """
        t_task = self.builder.add_tasklet(block, code, ["_in1", "_in2"], ["_out"])
        self.builder.add_memlet(block, lhs[0], "void", t_task, "_in1", lhs[1])
        self.builder.add_memlet(block, rhs[0], "void", t_task, "_in2", rhs[1])
        if t_out is None:
            out_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(out_name, dtype, False)
            t_out = self.builder.add_access(block, out_name)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
        return t_out

    def _binop_emit_python_mod(self, block, dtype, lhs, rhs, t_out=None):
        """Emit Python's ``a % b`` (sign of the divisor) as ``rem(rem(a, b) + b, b)``.

        ``int_srem``/``fp_rem`` are used for each remainder. The intermediate
        ``+ b`` shifts a result that has the dividend's sign into the divisor's
        sign range. Returns the output access node.
        """
        if dtype.primitive_type == PrimitiveType.Int64:
            rem_code, add_code = TaskletCode.int_srem, TaskletCode.int_add
        else:
            rem_code, add_code = TaskletCode.fp_rem, TaskletCode.fp_add
        t_rem = self._binop_emit(block, rem_code, lhs, rhs, dtype)
        t_sum = self._binop_emit(block, add_code, (t_rem, ""), rhs, dtype)
        return self._binop_emit(block, rem_code, (t_sum, ""), rhs, dtype, t_out)

    def _binop_emit_floordiv(self, block, dtype, lhs, rhs, t_out=None):
        """Emit Python's ``a // b`` using the identity ``a // b == (a - a % b) / b``.

        ``a % b`` is Python's floored modulo, so ``a - a % b`` is a multiple of
        ``b``. For integers the division is therefore exact and ``int_sdiv``
        yields the floor. For floats the quotient may be off by rounding error
        in edge cases, since no explicit ``floor`` is applied. Returns the
        output access node.
        """
        if dtype.primitive_type == PrimitiveType.Int64:
            sub_code, div_code = TaskletCode.int_sub, TaskletCode.int_sdiv
        else:
            sub_code, div_code = TaskletCode.fp_sub, TaskletCode.fp_div
        t_mod = self._binop_emit_python_mod(block, dtype, lhs, rhs)
        t_diff = self._binop_emit(block, sub_code, lhs, (t_mod, ""), dtype)
        return self._binop_emit(block, div_code, (t_diff, ""), rhs, dtype, t_out)
4✔
1018

1019
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit a new block that copies the literal ``value_str`` (typed ``dtype``)
        into the container ``target_name`` through an assign tasklet."""
        builder = self.builder
        blk = builder.add_block()
        literal_node = builder.add_constant(blk, value_str, dtype)
        target_node = builder.add_access(blk, target_name)
        copier = builder.add_tasklet(blk, TaskletCode.assign, ["_in"], ["_out"])
        builder.add_memlet(blk, literal_node, "void", copier, "_in", "")
        builder.add_memlet(blk, copier, "_out", target_node, "void", "")
4✔
1026

1027
    def visit_BoolOp(self, node):
        """Lower an ``and``/``or`` expression to a Bool temporary and return its name.

        Each operand's truthiness is tested as ``(<operand> != 0)``. The tests are
        joined with the visited operator into one condition, which drives an
        if/else that stores ``true`` or ``false``. All operands are visited
        before the branch is opened, so any code they generate is emitted
        unconditionally (no short-circuiting).
        """
        joiner = f" {self.visit(node.op)} "
        clauses = []
        for operand in node.values:
            clauses.append(f"({self.visit(operand)} != 0)")
        condition = joiner.join(clauses)

        result = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        self.builder.begin_if(condition)
        self._add_assign_constant(result, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result, "false", bool_type)
        self.builder.end_if()

        self.symbol_table[result] = bool_type
        return result
4✔
1045

1046
    def visit_Compare(self, node):
        """Lower a single (non-chained) comparison and return the result's name.

        If either side is a known array, the work is delegated to
        ``_handle_array_compare``. Otherwise a Bool scalar temporary is set to
        ``true``/``false`` by branching on the symbolic condition
        ``"<left> <op> <right>"``.

        Raises:
            NotImplementedError: for chained comparisons such as ``a < b < c``.
        """
        lhs = self.visit(node.left)
        if len(node.ops) > 1:
            raise NotImplementedError("Chained comparisons not supported yet")

        cmp_op = self.visit(node.ops[0])
        rhs = self.visit(node.comparators[0])

        lhs_array = lhs in self.array_info
        rhs_array = rhs in self.array_info
        if lhs_array or rhs_array:
            return self._handle_array_compare(lhs, cmp_op, rhs, lhs_array, rhs_array)

        flag = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(flag, bool_type, False)

        # Branch on the condition and write the literal outcome on each side.
        self.builder.begin_if(f"{lhs} {cmp_op} {rhs}")
        self.builder.add_transition(flag, "true")
        self.builder.begin_else()
        self.builder.add_transition(flag, "false")
        self.builder.end_if()

        self.symbol_table[flag] = bool_type
        return flag
4✔
1080

1081
    def visit_UnaryOp(self, node):
        """Lower a unary operator expression and return the name holding its result.

        * ``-<numeric literal>`` folds to the literal string ``"-<value>"``.
        * ``-arr`` on a known array is delegated to ``_handle_array_negate``.
        * Otherwise a scalar temporary is allocated and one block computes:

          - ``not x``: ``x == 0`` for Int32/Int64 operands (result is Bool);
            otherwise XOR with ``true``.
          - ``-x``: ``0 - x`` for Int32/Int64; ``fp_neg`` otherwise.
          - ``~x``: ``x XOR -1`` (all bits set).
          - any other operator: a plain copy.

        The operand type is taken from the symbol table (pointee type for
        pointers). Unknown operands are Int64 if ``_is_int`` accepts them, Bool
        under ``not``, and Double otherwise.

        Raises:
            NotImplementedError: for ``~`` applied to a Double operand.
        """
        if (
            isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float))
        ):
            return f"-{node.operand.value}"

        op = self.visit(node.op)
        operand = self.visit(node.operand)

        # Check if operand is an array - handle as array operation
        if operand in self.array_info and op == "-":
            return self._handle_array_negate(operand)

        tmp_name = f"_tmp_{self._get_unique_id()}"
        is_not = isinstance(node.op, ast.Not)

        operand_dtype = Scalar(PrimitiveType.Double)
        if operand in self.symbol_table:
            operand_dtype = self.symbol_table[operand]
            # If it's a pointer (array), use the element type
            if isinstance(operand_dtype, Pointer) and operand_dtype.has_pointee_type():
                operand_dtype = operand_dtype.pointee_type
        elif self._is_int(operand):
            operand_dtype = Scalar(PrimitiveType.Int64)
        elif is_not:
            operand_dtype = Scalar(PrimitiveType.Bool)

        operand_prim = getattr(operand_dtype, "primitive_type", None)
        operand_is_int = operand_prim in (PrimitiveType.Int32, PrimitiveType.Int64)

        if not is_not and op == "~" and operand_prim == PrimitiveType.Double:
            raise NotImplementedError(
                "Bitwise NOT is not supported for floating-point operands"
            )

        # Python's `not` always yields a bool; for integers it means `x == 0`.
        if is_not and operand_is_int:
            dtype = Scalar(PrimitiveType.Bool)
        else:
            dtype = operand_dtype

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)

        if is_not:
            if operand_is_int:
                # XOR with true only inverts 0/1 (e.g. 2 ^ 1 == 3 stays truthy),
                # so integers are compared against zero instead.
                t_const = self.builder.add_constant(block, "0", operand_dtype)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_eq, ["_in1", "_in2"], ["_out"]
                )
            else:
                t_const = self.builder.add_constant(
                    block, "true", Scalar(PrimitiveType.Bool)
                )
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
                )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")

        elif op == "-":
            if operand_is_int:
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_sub, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
            else:
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.fp_neg, ["_in"], ["_out"]
                )
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)

        elif op == "~":
            # Bitwise NOT: ~x = x XOR -1 (all bits set). The mask takes the
            # operand's own integer width so both XOR inputs agree.
            mask_type = dtype if operand_is_int else Scalar(PrimitiveType.Int64)
            t_const = self.builder.add_constant(block, "-1", mask_type)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")

        else:
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)

        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        return tmp_name
4✔
1162

1163
    def _handle_array_negate(self, operand):
        """Return a new array holding ``-operand``, computed as ``0 - operand``.

        A scalar zero of the array's element type ("0.0" for Double, "0"
        otherwise) is materialized first. It is then subtracted element-wise
        via ``add_elementwise_op("sub", ...)`` into a fresh temporary array of
        the same shape.
        """
        dims = self.array_info[operand]["shapes"]
        elem_type = self._get_dtype(operand)

        result = self._create_array_temp(dims, elem_type)

        zero = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(zero, elem_type, False)
        self.symbol_table[zero] = elem_type

        is_double = elem_type.primitive_type == PrimitiveType.Double
        self._add_assign_constant(zero, "0.0" if is_double else "0", elem_type)

        self.builder.add_elementwise_op("sub", zero, operand, result, dims)
        return result
4✔
1194

1195
    def _handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Handle elementwise comparison of arrays, returning a boolean array.

        Supports ``arr OP scalar``, ``scalar OP arr`` and ``arr OP arr`` for
        ``> >= < <= == !=``.

        * Shape: taken from the left operand when it is an array, otherwise
          from the right. Two array operands are indexed with the same linear
          index, so they are assumed to share that shape.
        * Comparison kind: Int32/Int64 element types use signed integer
          comparisons; all other element types use ordered floating-point ones.
        * Scalar promotion: for a floating-point comparison, a scalar operand
          that ``_is_int`` accepts is first copied into a Double temporary.

        Raises:
            NotImplementedError: if ``op`` is not a supported comparison operator.
        """
        # Determine shape from the array operand
        arr_name = left if left_is_array else right
        shape = self.array_info[arr_name]["shapes"]

        arr_dtype = self._get_dtype(arr_name)
        use_int_cmp = arr_dtype.primitive_type in (
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )

        # Create output boolean array
        dtype = Scalar(PrimitiveType.Bool)
        tmp_name = self._create_array_temp(shape, dtype)

        if use_int_cmp:
            cmp_ops = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            # Floating point ordered comparisons
            cmp_ops = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in cmp_ops:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )

        tasklet_code = cmp_ops[op]

        scalar_name = None
        if not left_is_array:
            scalar_name = left
        elif not right_is_array:
            scalar_name = right

        if scalar_name is not None and not use_int_cmp and self._is_int(scalar_name):
            # Promote via an assign tasklet reading the scalar. Building a
            # "<name>.0" constant string only works for integer literals; an
            # integer variable `n` would become the invalid constant "n.0".
            float_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(
                float_name, Scalar(PrimitiveType.Double), False
            )
            self.symbol_table[float_name] = Scalar(PrimitiveType.Double)

            block_conv = self.builder.add_block()
            t_src, src_sub = self._add_read(block_conv, scalar_name)
            t_float = self.builder.add_access(block_conv, float_name)
            t_assign = self.builder.add_tasklet(
                block_conv, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_conv, t_src, "void", t_assign, "_in", src_sub
            )
            self.builder.add_memlet(
                block_conv, t_assign, "_out", t_float, "void", ""
            )

            # Replace the scalar name with the converted float
            if not left_is_array:
                left = float_name
            else:
                right = float_name

        # Generate nested loops
        loop_vars = []
        for i, dim in enumerate(shape):
            loop_var = f"_cmp_i{i}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(dim), "1")

        # Compute linear index for array access
        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))

        block = self.builder.add_block()

        if left_is_array:
            t_left = self.builder.add_access(block, left)
            left_sub = linear_idx
        else:
            t_left, left_sub = self._add_read(block, left)

        if right_is_array:
            t_right = self.builder.add_access(block, right)
            right_sub = linear_idx
        else:
            t_right, right_sub = self._add_read(block, right)

        t_out = self.builder.add_access(block, tmp_name)

        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", linear_idx)

        # Close loops
        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
4✔
1334

1335
    def _parse_array_arg(self, node, simple_visitor):
4✔
1336
        if isinstance(node, ast.Name):
×
1337
            if node.id in self.array_info:
×
1338
                return node.id, [], self.array_info[node.id]["shapes"]
×
1339
        elif isinstance(node, ast.Subscript):
×
1340
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
×
1341
                name = node.value.id
×
1342
                ndim = self.array_info[name]["ndim"]
×
1343

1344
                indices = []
×
1345
                if isinstance(node.slice, ast.Tuple):
×
1346
                    indices = list(node.slice.elts)
×
1347
                else:
1348
                    indices = [node.slice]
×
1349

1350
                while len(indices) < ndim:
×
1351
                    indices.append(ast.Slice(lower=None, upper=None, step=None))
×
1352

1353
                start_indices = []
×
1354
                slice_shape = []
×
1355

1356
                for i, idx in enumerate(indices):
×
1357
                    if isinstance(idx, ast.Slice):
×
1358
                        start = "0"
×
1359
                        if idx.lower:
×
1360
                            start = simple_visitor.visit(idx.lower)
×
1361
                        start_indices.append(start)
×
1362

1363
                        shapes = self.array_info[name]["shapes"]
×
1364
                        dim_size = (
×
1365
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
1366
                        )
1367
                        stop = dim_size
×
1368
                        if idx.upper:
×
1369
                            stop = simple_visitor.visit(idx.upper)
×
1370

1371
                        size = f"({stop} - {start})"
×
1372
                        slice_shape.append(size)
×
1373
                    else:
1374
                        val = simple_visitor.visit(idx)
×
1375
                        start_indices.append(val)
×
1376

1377
                shapes = self.array_info[name]["shapes"]
×
1378
                linear_index = ""
×
1379
                for i in range(ndim):
×
1380
                    term = start_indices[i]
×
1381
                    for j in range(i + 1, ndim):
×
1382
                        shape_val = shapes[j] if j < len(shapes) else None
×
1383
                        shape_sym = (
×
1384
                            shape_val if shape_val is not None else f"_{name}_shape_{j}"
1385
                        )
1386
                        term = f"({term} * {shape_sym})"
×
1387

1388
                    if i == 0:
×
1389
                        linear_index = term
×
1390
                    else:
1391
                        linear_index = f"({linear_index} + {term})"
×
1392

1393
                return name, [linear_index], slice_shape
×
1394

1395
        return None, None, None
×
1396

1397
    def visit_Attribute(self, node):
        """Lower an attribute access ``<owner>.<attr>``.

        Supported forms:

        * ``arr.shape`` for an array registered in ``array_info`` returns the
          proxy name ``"_shape_proxy_<arr>"`` (resolved by subscript handling).
        * ``math.pi`` / ``math.e`` create a ``Double`` temporary assigned
          ``M_PI`` / ``M_E`` and return its name.
        * ``obj.member`` where ``obj`` is a pointer to a ``Structure`` creates a
          temporary of the member's type plus a block whose assign tasklet
          reads the member through the memlet subset ``"0,<member_index>"``;
          the temporary's name is returned.

        Raises:
            RuntimeError: if the structure has no registered member ``attr``.
            NotImplementedError: for every other attribute access.
        """
        owner_name = node.value.id if isinstance(node.value, ast.Name) else None
        attr_name = node.attr

        if (
            attr_name == "shape"
            and owner_name is not None
            and owner_name in self.array_info
        ):
            return f"_shape_proxy_{owner_name}"

        if owner_name == "math":
            constant = {"pi": "M_PI", "e": "M_E"}.get(attr_name)
            if constant is not None:
                tmp_name = f"_tmp_{self._get_unique_id()}"
                dtype = Scalar(PrimitiveType.Double)
                self.builder.add_container(tmp_name, dtype, False)
                self.symbol_table[tmp_name] = dtype
                self._add_assign_constant(tmp_name, constant, dtype)
                return tmp_name

        # Member access is only supported on pointers to Structure types.
        struct_type = None
        if owner_name is not None and owner_name in self.symbol_table:
            owner_type = self.symbol_table[owner_name]
            if isinstance(owner_type, Pointer) and owner_type.has_pointee_type():
                struct_type = owner_type.pointee_type
        if struct_type is None or not isinstance(struct_type, Structure):
            raise NotImplementedError(f"Attribute access {node.attr} not supported")

        struct_name = struct_type.name
        members = self.structure_member_info.get(struct_name, {})
        if attr_name not in members:
            # Only reachable if the structure was not registered properly.
            raise RuntimeError(
                f"Member '{attr_name}' not found in structure '{struct_name}'. "
                f"Available members: {list(members.keys())}"
            )
        member_index, member_type = members[attr_name]

        tmp_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(tmp_name, member_type, False)
        self.symbol_table[tmp_name] = member_type

        # The tasklet only passes the value through; the member is selected
        # by the memlet subset "0,<member_index>".
        block = self.builder.add_block()
        obj_access = self.builder.add_access(block, owner_name)
        tmp_access = self.builder.add_access(block, tmp_name)
        tasklet = self.builder.add_tasklet(
            block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block, obj_access, "", tasklet, "_in", f"0,{member_index}"
        )
        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")
        return tmp_name
1472

1473
    def _handle_expression_slicing(self, node, value_str, indices_nodes, shapes, ndim):
4✔
1474
        """Handle slicing in expressions (e.g., arr[1:, :, k+1]).
1475

1476
        Creates a temporary array, generates loops to copy sliced data,
1477
        and returns the temporary array name.
1478
        """
1479
        if not self.builder:
4✔
1480
            raise ValueError("Builder required for expression slicing")
×
1481

1482
        # Determine element type from source array
1483
        dtype = Scalar(PrimitiveType.Double)
4✔
1484
        if value_str in self.symbol_table:
4✔
1485
            t = self.symbol_table[value_str]
4✔
1486
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1487
                dtype = t.pointee_type
4✔
1488

1489
        # Analyze each dimension: is it a slice or an index?
1490
        # For slices, compute the resulting shape dimension
1491
        # For indices, that dimension is collapsed
1492
        result_shapes = []  # Shape of the resulting array (for SDFG)
4✔
1493
        result_shapes_runtime = []  # Shape expressions for runtime evaluation
4✔
1494
        slice_info = []  # List of (dim_idx, start_str, stop_str, step_str) for slices
4✔
1495
        index_info = []  # List of (dim_idx, index_str) for point indices
4✔
1496

1497
        for i, idx in enumerate(indices_nodes):
4✔
1498
            shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
4✔
1499

1500
            if isinstance(idx, ast.Slice):
4✔
1501
                # Parse slice bounds - check for indirect access patterns
1502
                start_str = "0"
4✔
1503
                start_str_runtime = "0"  # For runtime shape evaluation
4✔
1504
                if idx.lower is not None:
4✔
1505
                    # Check if lower bound contains indirect array access
1506
                    if self._contains_indirect_access(idx.lower):
4✔
1507
                        start_str, start_str_runtime = (
4✔
1508
                            self._materialize_indirect_access(
1509
                                idx.lower, return_original_expr=True
1510
                            )
1511
                        )
1512
                    else:
1513
                        start_str = self.visit(idx.lower)
4✔
1514
                        start_str_runtime = start_str
4✔
1515
                    # Handle negative indices
1516
                    if isinstance(start_str, str) and (
4✔
1517
                        start_str.startswith("-") or start_str.startswith("(-")
1518
                    ):
1519
                        start_str = f"({shape_val} + {start_str})"
×
1520
                        start_str_runtime = f"({shape_val} + {start_str_runtime})"
×
1521

1522
                stop_str = str(shape_val)
4✔
1523
                stop_str_runtime = str(shape_val)
4✔
1524
                if idx.upper is not None:
4✔
1525
                    # Check if upper bound contains indirect array access
1526
                    if self._contains_indirect_access(idx.upper):
4✔
1527
                        stop_str, stop_str_runtime = self._materialize_indirect_access(
4✔
1528
                            idx.upper, return_original_expr=True
1529
                        )
1530
                    else:
1531
                        stop_str = self.visit(idx.upper)
4✔
1532
                        stop_str_runtime = stop_str
4✔
1533
                    # Handle negative indices
1534
                    if isinstance(stop_str, str) and (
4✔
1535
                        stop_str.startswith("-") or stop_str.startswith("(-")
1536
                    ):
1537
                        stop_str = f"({shape_val} + {stop_str})"
4✔
1538
                        stop_str_runtime = f"({shape_val} + {stop_str_runtime})"
4✔
1539

1540
                step_str = "1"
4✔
1541
                if idx.step is not None:
4✔
1542
                    step_str = self.visit(idx.step)
×
1543

1544
                # Compute the size of this dimension in the result
1545
                dim_size = f"({stop_str} - {start_str})"
4✔
1546
                dim_size_runtime = f"({stop_str_runtime} - {start_str_runtime})"
4✔
1547
                result_shapes.append(dim_size)
4✔
1548
                result_shapes_runtime.append(dim_size_runtime)
4✔
1549
                slice_info.append((i, start_str, stop_str, step_str))
4✔
1550
            else:
1551
                # Point index - dimension is collapsed
1552
                # Check for indirect array access in the index
1553
                if self._contains_indirect_access(idx):
4✔
1554
                    index_str = self._materialize_indirect_access(idx)
×
1555
                else:
1556
                    index_str = self.visit(idx)
4✔
1557
                # Handle negative indices
1558
                if isinstance(index_str, str) and (
4✔
1559
                    index_str.startswith("-") or index_str.startswith("(-")
1560
                ):
1561
                    index_str = f"({shape_val} + {index_str})"
×
1562
                index_info.append((i, index_str))
4✔
1563

1564
        # Create temporary array for the result
1565
        tmp_name = self._get_temp_name("_slice_tmp_")
4✔
1566
        result_ndim = len(result_shapes)
4✔
1567

1568
        if result_ndim == 0:
4✔
1569
            # All dimensions indexed - result is a scalar
1570
            self.builder.add_container(tmp_name, dtype, False)
×
1571
            self.symbol_table[tmp_name] = dtype
×
1572
        else:
1573
            # Result is an array - use _create_array_temp to handle allocation
1574
            # Calculate size for malloc - use SDFG symbolic shapes
1575
            size_str = "1"
4✔
1576
            for dim in result_shapes:
4✔
1577
                size_str = f"({size_str} * {dim})"
4✔
1578

1579
            element_size = self.builder.get_sizeof(dtype)
4✔
1580
            total_size = f"({size_str} * {element_size})"
4✔
1581

1582
            # Create pointer
1583
            ptr_type = Pointer(dtype)
4✔
1584
            self.builder.add_container(tmp_name, ptr_type, False)
4✔
1585
            self.symbol_table[tmp_name] = ptr_type
4✔
1586
            # Store both SDFG shapes (for compilation) and runtime shapes (for evaluation)
1587
            # The "shapes" field uses SDFG symbolic variables for malloc sizing
1588
            # The "shapes_runtime" field uses original expressions for Python runtime evaluation
1589
            self.array_info[tmp_name] = {
4✔
1590
                "ndim": result_ndim,
1591
                "shapes": result_shapes,  # Uses materialized variables for SDFG
1592
                "shapes_runtime": result_shapes_runtime,  # Uses original expressions for runtime
1593
            }
1594

1595
            # Malloc for the temporary array
1596
            debug_info = DebugInfo()
4✔
1597
            block_alloc = self.builder.add_block(debug_info)
4✔
1598
            t_malloc = self.builder.add_malloc(block_alloc, total_size)
4✔
1599
            t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
4✔
1600
            self.builder.add_memlet(
4✔
1601
                block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
1602
            )
1603

1604
        # Generate loops to copy the sliced data
1605
        loop_vars = []
4✔
1606
        debug_info = DebugInfo()
4✔
1607

1608
        for dim_idx, (orig_dim, start_str, stop_str, step_str) in enumerate(slice_info):
4✔
1609
            loop_var = f"_slice_loop_{dim_idx}_{self._get_unique_id()}"
4✔
1610
            loop_vars.append((loop_var, orig_dim, start_str, step_str))
4✔
1611

1612
            if not self.builder.exists(loop_var):
4✔
1613
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1614
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1615

1616
            # Loop from 0 to (stop - start)
1617
            count_str = f"({stop_str} - {start_str})"
4✔
1618
            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
4✔
1619

1620
        # Build source and destination indices
1621
        src_indices = [""] * ndim
4✔
1622
        dst_indices = []
4✔
1623

1624
        # Fill in point indices for source
1625
        for orig_dim, index_str in index_info:
4✔
1626
            src_indices[orig_dim] = index_str
4✔
1627

1628
        # Fill in slice indices for source and build destination indices
1629
        for loop_var, orig_dim, start_str, step_str in loop_vars:
4✔
1630
            if step_str == "1":
4✔
1631
                src_indices[orig_dim] = f"({start_str} + {loop_var})"
4✔
1632
            else:
1633
                src_indices[orig_dim] = f"({start_str} + {loop_var} * {step_str})"
×
1634
            dst_indices.append(loop_var)
4✔
1635

1636
        # Compute linear indices
1637
        src_linear = self._compute_linear_index(src_indices, shapes, value_str, ndim)
4✔
1638
        if result_ndim > 0:
4✔
1639
            dst_linear = self._compute_linear_index(
4✔
1640
                dst_indices, result_shapes, tmp_name, result_ndim
1641
            )
1642
        else:
1643
            dst_linear = "0"
×
1644

1645
        # Create the copy block
1646
        block = self.builder.add_block(debug_info)
4✔
1647
        t_src = self.builder.add_access(block, value_str, debug_info)
4✔
1648
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
1649
        t_task = self.builder.add_tasklet(
4✔
1650
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
1651
        )
1652

1653
        self.builder.add_memlet(
4✔
1654
            block, t_src, "void", t_task, "_in", src_linear, None, debug_info
1655
        )
1656
        self.builder.add_memlet(
4✔
1657
            block, t_task, "_out", t_dst, "void", dst_linear, None, debug_info
1658
        )
1659

1660
        # Close all loops
1661
        for _ in loop_vars:
4✔
1662
            self.builder.end_for()
4✔
1663

1664
        return tmp_name
4✔
1665

1666
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
1667
        """Compute linear index from multi-dimensional indices."""
1668
        if ndim == 0:
4✔
1669
            return "0"
×
1670

1671
        linear_index = ""
4✔
1672
        for i in range(ndim):
4✔
1673
            term = str(indices[i])
4✔
1674
            for j in range(i + 1, ndim):
4✔
1675
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
4✔
1676
                term = f"(({term}) * {shape_val})"
4✔
1677

1678
            if i == 0:
4✔
1679
                linear_index = term
4✔
1680
            else:
1681
                linear_index = f"({linear_index} + {term})"
4✔
1682

1683
        return linear_index
4✔
1684

1685
    def _is_array_index(self, node):
4✔
1686
        """Check if a node represents an array that could be used as an index (gather).
1687

1688
        Returns True if the node is a Name referring to an array in array_info.
1689
        """
1690
        if isinstance(node, ast.Name):
4✔
1691
            return node.id in self.array_info
4✔
1692
        return False
4✔
1693

1694
    def _handle_gather(self, value_str, index_node, debug_info=None):
4✔
1695
        """Handle gather operation: x[indices] where indices is an array.
1696

1697
        Creates a temporary array and generates a loop to gather elements
1698
        from the source array using the index array.
1699

1700
        This is the canonical SDFG pattern for gather operations:
1701
        - Create a loop over the index array
1702
        - Load the index value using a tasklet+memlets
1703
        - Use that index in the memlet subset for the source array
1704
        """
1705
        if debug_info is None:
4✔
1706
            debug_info = DebugInfo()
4✔
1707

1708
        # Get the index array name
1709
        if isinstance(index_node, ast.Name):
4✔
1710
            idx_array_name = index_node.id
4✔
1711
        else:
1712
            # Visit the index to get its name (handles slices like cols)
1713
            idx_array_name = self.visit(index_node)
×
1714

1715
        if idx_array_name not in self.array_info:
4✔
1716
            raise ValueError(f"Gather index must be an array, got {idx_array_name}")
×
1717

1718
        # Get shapes
1719
        idx_shapes = self.array_info[idx_array_name].get("shapes", [])
4✔
1720
        src_ndim = self.array_info[value_str]["ndim"]
4✔
1721
        idx_ndim = self.array_info[idx_array_name]["ndim"]
4✔
1722

1723
        if idx_ndim != 1:
4✔
1724
            raise NotImplementedError("Only 1D index arrays supported for gather")
×
1725

1726
        # Result array has same shape as index array
1727
        result_shape = idx_shapes[0] if idx_shapes else f"_{idx_array_name}_shape_0"
4✔
1728

1729
        # Determine element type from source array
1730
        dtype = Scalar(PrimitiveType.Double)
4✔
1731
        if value_str in self.symbol_table:
4✔
1732
            t = self.symbol_table[value_str]
4✔
1733
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1734
                dtype = t.pointee_type
4✔
1735

1736
        # Determine index type from index array
1737
        idx_dtype = Scalar(PrimitiveType.Int64)
4✔
1738
        if idx_array_name in self.symbol_table:
4✔
1739
            t = self.symbol_table[idx_array_name]
4✔
1740
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1741
                idx_dtype = t.pointee_type
4✔
1742

1743
        # Create result array
1744
        tmp_name = self._get_temp_name("_gather_")
4✔
1745

1746
        # Calculate size for malloc
1747
        element_size = self.builder.get_sizeof(dtype)
4✔
1748
        total_size = f"({result_shape} * {element_size})"
4✔
1749

1750
        # Create pointer for result
1751
        ptr_type = Pointer(dtype)
4✔
1752
        self.builder.add_container(tmp_name, ptr_type, False)
4✔
1753
        self.symbol_table[tmp_name] = ptr_type
4✔
1754
        self.array_info[tmp_name] = {"ndim": 1, "shapes": [result_shape]}
4✔
1755

1756
        # Malloc for the result array
1757
        block_alloc = self.builder.add_block(debug_info)
4✔
1758
        t_malloc = self.builder.add_malloc(block_alloc, total_size)
4✔
1759
        t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
4✔
1760
        self.builder.add_memlet(
4✔
1761
            block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
1762
        )
1763

1764
        # Create loop variable
1765
        loop_var = f"_gather_i_{self._get_unique_id()}"
4✔
1766
        self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1767
        self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1768

1769
        # Create variable to hold the loaded index
1770
        idx_var = f"_gather_idx_{self._get_unique_id()}"
4✔
1771
        self.builder.add_container(idx_var, idx_dtype, False)
4✔
1772
        self.symbol_table[idx_var] = idx_dtype
4✔
1773

1774
        # Begin loop
1775
        self.builder.begin_for(loop_var, "0", str(result_shape), "1", debug_info)
4✔
1776

1777
        # Block 1: Load the index from index array using tasklet+memlets
1778
        block_load_idx = self.builder.add_block(debug_info)
4✔
1779
        idx_arr_access = self.builder.add_access(
4✔
1780
            block_load_idx, idx_array_name, debug_info
1781
        )
1782
        idx_var_access = self.builder.add_access(block_load_idx, idx_var, debug_info)
4✔
1783
        tasklet_load = self.builder.add_tasklet(
4✔
1784
            block_load_idx, TaskletCode.assign, ["_in"], ["_out"], debug_info
1785
        )
1786
        self.builder.add_memlet(
4✔
1787
            block_load_idx,
1788
            idx_arr_access,
1789
            "void",
1790
            tasklet_load,
1791
            "_in",
1792
            loop_var,
1793
            None,
1794
            debug_info,
1795
        )
1796
        self.builder.add_memlet(
4✔
1797
            block_load_idx,
1798
            tasklet_load,
1799
            "_out",
1800
            idx_var_access,
1801
            "void",
1802
            "",
1803
            None,
1804
            debug_info,
1805
        )
1806

1807
        # Block 2: Use the loaded index to gather from source array
1808
        block_gather = self.builder.add_block(debug_info)
4✔
1809
        src_access = self.builder.add_access(block_gather, value_str, debug_info)
4✔
1810
        dst_access = self.builder.add_access(block_gather, tmp_name, debug_info)
4✔
1811
        tasklet_gather = self.builder.add_tasklet(
4✔
1812
            block_gather, TaskletCode.assign, ["_in"], ["_out"], debug_info
1813
        )
1814

1815
        # Use the symbolic variable name (idx_var) in the memlet subset - this is key!
1816
        self.builder.add_memlet(
4✔
1817
            block_gather,
1818
            src_access,
1819
            "void",
1820
            tasklet_gather,
1821
            "_in",
1822
            idx_var,
1823
            None,
1824
            debug_info,
1825
        )
1826
        self.builder.add_memlet(
4✔
1827
            block_gather,
1828
            tasklet_gather,
1829
            "_out",
1830
            dst_access,
1831
            "void",
1832
            loop_var,
1833
            None,
1834
            debug_info,
1835
        )
1836

1837
        # End loop
1838
        self.builder.end_for()
4✔
1839

1840
        return tmp_name
4✔
1841

1842
    def visit_Subscript(self, node):
        """Lower a subscript expression ``value[...]``.

        Cases, checked in order:

        * ``arr.shape[k]``: ``node.value`` visits to ``"_shape_proxy_<arr>"``.
          The integer literal ``k`` (negative values count from the last
          dimension) resolves to the recorded symbolic size, or to
          ``"_<arr>_shape_<k>"`` when none is recorded.
        * ``arr[:]`` / ``arr[:, :]`` (no bounds): returns ``arr``.
        * Any other slicing: delegated to ``_handle_expression_slicing``.
        * ``arr[idx_arr]`` with an array index and a builder: delegated to
          ``_handle_gather``.
        * Point access ``arr[i, j, ...]``: the indices are flattened into a
          row-major linear index. With a builder in a load context, the
          element is copied into a new temporary whose name is returned.
          Otherwise the access string ``"arr(<linear_index>)"`` is returned.
        * ``value`` not registered in ``array_info``: returns
          ``"value(<slice>)"`` without emitting any code.

        Raises:
            NotImplementedError: for a non-literal ``.shape`` index, or a
                slice whose step is not the literal ``1``.
            IndexError: for a ``.shape`` index outside the array's rank.
            ValueError: if the number of point indices differs from ``ndim``.
        """
        value_str = self.visit(node.value)

        if value_str.startswith("_shape_proxy_"):
            array_name = value_str[len("_shape_proxy_") :]
            slice_node = node.slice
            # Python < 3.9 wraps plain subscripts in ast.Index.
            if type(slice_node).__name__ == "Index":
                slice_node = slice_node.value
            try:
                # Accepts both shape[1] and shape[-1] (a UnaryOp node).
                idx = ast.literal_eval(slice_node)
            except (ValueError, TypeError, SyntaxError):
                idx = None
            if not isinstance(idx, int):
                try:
                    idx = int(self.visit(slice_node))
                except (TypeError, ValueError) as err:
                    raise NotImplementedError(
                        "Dynamic shape indexing not fully supported yet"
                    ) from err

            info = self.array_info.get(array_name, {})
            shapes = info.get("shapes", [])
            rank = info.get("ndim", len(shapes) if "shapes" in info else None)
            if idx < 0 and rank is not None:
                idx += rank
            if idx < 0 or (rank is not None and idx >= rank):
                raise IndexError(f"Shape index out of range for array '{array_name}'")
            if idx < len(shapes) and shapes[idx] is not None:
                return shapes[idx]
            return f"_{array_name}_shape_{idx}"

        if value_str in self.array_info:
            ndim = self.array_info[value_str]["ndim"]
            shapes = self.array_info[value_str].get("shapes", [])
            indices_nodes = (
                node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
            )

            # Strided slices (a[::2], a[::-1]) cannot be lowered correctly by
            # the slicing helper, so fail loudly instead of returning wrong data.
            for idx in indices_nodes:
                if isinstance(idx, ast.Slice) and idx.step is not None:
                    if not (isinstance(idx.step, ast.Constant) and idx.step.value == 1):
                        raise NotImplementedError(
                            "Slicing with a step other than 1 is not supported in expressions"
                        )

            # arr[:] / arr[:, :] covers the whole array.
            if all(
                isinstance(idx, ast.Slice) and idx.lower is None and idx.upper is None
                for idx in indices_nodes
            ):
                return value_str

            # Mixed slicing, e.g. arr[1:, :, k] or arr[:-1, :, k+1].
            if any(isinstance(idx, ast.Slice) for idx in indices_nodes):
                return self._handle_expression_slicing(
                    node, value_str, indices_nodes, shapes, ndim
                )

            # Gather: x[indices_array].
            if (
                len(indices_nodes) == 1
                and self._is_array_index(indices_nodes[0])
                and self.builder
            ):
                return self._handle_gather(value_str, indices_nodes[0])

            indices = [self.visit(elt) for elt in indices_nodes]
            if len(indices) != ndim:
                raise ValueError(
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
                )

            # Symbolic size of every dimension, with a named fallback.
            dim_sizes = [
                (
                    shapes[i]
                    if i < len(shapes) and shapes[i] is not None
                    else f"_{value_str}_shape_{i}"
                )
                for i in range(ndim)
            ]

            # Negative indices count from the end: size + index.
            normalized_indices = []
            for idx_str, size in zip(indices, dim_sizes):
                if isinstance(idx_str, str) and idx_str.startswith(("-", "(-")):
                    idx_str = f"({size} + {idx_str})"
                normalized_indices.append(idx_str)

            linear_index = ""
            for i, term in enumerate(normalized_indices):
                for size in dim_sizes[i + 1 :]:
                    term = f"(({term}) * {size})"
                linear_index = term if i == 0 else f"({linear_index} + {term})"

            access_str = f"{value_str}({linear_index})"
            if not (self.builder and isinstance(node.ctx, ast.Load)):
                return access_str

            # Element type of the source (defaults to double).
            dtype = Scalar(PrimitiveType.Double)
            if value_str in self.symbol_table:
                t = self.symbol_table[value_str]
                if type(t).__name__ == "Array" and hasattr(t, "element_type"):
                    et = t.element_type
                    if callable(et):
                        et = et()
                    dtype = et
                elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
                    et = t.pointee_type
                    if callable(et):
                        et = et()
                    dtype = et

            # Copy the element into a fresh temporary via an assign tasklet.
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)

            block = self.builder.add_block()
            t_src = self.builder.add_access(block, value_str)
            t_dst = self.builder.add_access(block, tmp_name)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", linear_index)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

            self.symbol_table[tmp_name] = dtype
            return tmp_name

        # Values without array_info are never materialized; emit the symbolic
        # access string as-is.
        return f"{value_str}({self.visit(node.slice)})"
2002

2003
    # Operator visitors: each maps an ast operator node to the token used when
    # building expression strings. The node itself is not inspected.

    # -- arithmetic --

    def visit_Add(self, node):
        """Binary addition (``a + b``)."""
        return "+"

    def visit_Sub(self, node):
        """Binary subtraction (``a - b``)."""
        return "-"

    def visit_Mult(self, node):
        """Multiplication (``a * b``)."""
        return "*"

    def visit_Div(self, node):
        """True division (``a / b``)."""
        return "/"

    def visit_FloorDiv(self, node):
        """Floor division (``a // b``)."""
        return "//"

    def visit_Mod(self, node):
        """Modulo (``a % b``)."""
        return "%"

    def visit_Pow(self, node):
        """Exponentiation (``a ** b``)."""
        return "**"

    # -- comparisons --

    def visit_Eq(self, node):
        """Equality (``a == b``)."""
        return "=="

    def visit_NotEq(self, node):
        """Inequality (``a != b``)."""
        return "!="

    def visit_Lt(self, node):
        """Less-than (``a < b``)."""
        return "<"

    def visit_LtE(self, node):
        """Less-or-equal (``a <= b``)."""
        return "<="

    def visit_Gt(self, node):
        """Greater-than (``a > b``)."""
        return ">"

    def visit_GtE(self, node):
        """Greater-or-equal (``a >= b``)."""
        return ">="

    # -- boolean operators (emitted with the single-character tokens) --

    def visit_And(self, node):
        """Boolean ``and``, emitted as ``&``."""
        return "&"

    def visit_Or(self, node):
        """Boolean ``or``, emitted as ``|``."""
        return "|"

    # -- bitwise --

    def visit_BitAnd(self, node):
        """Bitwise and (``a & b``)."""
        return "&"

    def visit_BitOr(self, node):
        """Bitwise or (``a | b``)."""
        return "|"

    def visit_BitXor(self, node):
        """Bitwise xor (``a ^ b``)."""
        return "^"

    def visit_LShift(self, node):
        """Left shift (``a << b``)."""
        return "<<"

    def visit_RShift(self, node):
        """Right shift (``a >> b``)."""
        return ">>"

    # -- unary --

    def visit_Not(self, node):
        """Logical negation (``not a``), emitted as ``!``."""
        return "!"

    def visit_USub(self, node):
        """Unary minus (``-a``)."""
        return "-"

    def visit_UAdd(self, node):
        """Unary plus (``+a``)."""
        return "+"

    def visit_Invert(self, node):
        """Bitwise inversion (``~a``)."""
        return "~"
2074

2075
    def _get_dtype(self, name):
        """Return the scalar type associated with ``name``.

        Resolution order:

        1. A ``Scalar`` entry in ``symbol_table`` is returned unchanged.
        2. Otherwise the entry's ``pointee_type``, then its ``element_type``
           (each either a value or a zero-argument callable), is returned if
           it resolves to a ``Scalar``.
        3. Names for which ``self._is_int`` is true map to ``Int64``.
        4. Everything else defaults to ``Double``.
        """
        if name in self.symbol_table:
            entry = self.symbol_table[name]
            if isinstance(entry, Scalar):
                return entry
            for accessor in ("pointee_type", "element_type"):
                if not hasattr(entry, accessor):
                    continue
                inner = getattr(entry, accessor)
                if callable(inner):
                    inner = inner()
                if isinstance(inner, Scalar):
                    return inner

        fallback = PrimitiveType.Int64 if self._is_int(name) else PrimitiveType.Double
        return Scalar(fallback)
2099

2100
    def _promote_dtypes(self, dtype_left, dtype_right):
        """Return whichever of two scalar dtypes wins promotion.

        Ranking: Double > Float > Int64 > Int32. Any other primitive type
        ranks below all of these. Ties, including two unranked types, keep
        ``dtype_left``.
        """
        rank_of = {
            PrimitiveType.Int32: 1,
            PrimitiveType.Int64: 2,
            PrimitiveType.Float: 3,
            PrimitiveType.Double: 4,
        }
        left_rank = rank_of.get(dtype_left.primitive_type, 0)
        right_rank = rank_of.get(dtype_right.primitive_type, 0)
        return dtype_right if right_rank > left_rank else dtype_left
2115

2116
    def _create_array_temp(
        self, shape, dtype, zero_init=False, ones_init=False, shapes_runtime=None
    ):
        """Create a new temporary container of ``dtype`` and register it.

        A falsy ``shape`` (``None`` or empty) yields a plain scalar container
        that is optionally initialised with one assignment. Any other
        ``shape`` yields a ``Pointer(dtype)`` container backed by a malloc of
        ``prod(shape) * sizeof(dtype)`` bytes. The storage is optionally
        zero-filled with a memset or one-filled with a flat loop of assign
        tasklets.

        Args:
            shape: Dimension expressions; interpolated into size strings.
            dtype: Element type of the temporary.
            zero_init: Initialise every element to zero.
            ones_init: Initialise every element to one (only when
                ``zero_init`` is false).
            shapes_runtime: Optional runtime shape expressions, recorded under
                ``array_info[name]["shapes_runtime"]`` (array case only).

        Returns:
            The generated container name (``_tmp_<id>``).
        """
        name = f"_tmp_{self._get_unique_id()}"

        if not shape:
            # 0-D: a scalar container, no heap allocation needed.
            self.builder.add_container(name, dtype, False)
            self.symbol_table[name] = dtype
            self.array_info[name] = {"ndim": 0, "shapes": []}
            if zero_init or ones_init:
                is_double = dtype.primitive_type == PrimitiveType.Double
                if zero_init:
                    literal = "0.0" if is_double else "0"
                else:
                    literal = "1.0" if is_double else "1"
                self.builder.add_assignment(name, literal)
            return name

        # Flat element count and byte size as symbolic expressions.
        count_expr = "1"
        for extent in shape:
            count_expr = f"({count_expr} * {extent})"
        elem_bytes = self.builder.get_sizeof(dtype)
        byte_expr = f"({count_expr} * {elem_bytes})"

        ptr_type = Pointer(dtype)
        self.builder.add_container(name, ptr_type, False)
        self.symbol_table[name] = ptr_type
        meta = {"ndim": len(shape), "shapes": shape}
        if shapes_runtime is not None:
            meta["shapes_runtime"] = shapes_runtime
        self.array_info[name] = meta

        alloc_block = self.builder.add_block()
        alloc_node = self.builder.add_malloc(alloc_block, byte_expr)
        alloc_dst = self.builder.add_access(alloc_block, name)
        self.builder.add_memlet(
            alloc_block, alloc_node, "_ret", alloc_dst, "void", "", ptr_type
        )

        if zero_init:
            fill_block = self.builder.add_block()
            memset_node = self.builder.add_memset(fill_block, "0", byte_expr)
            fill_dst = self.builder.add_access(fill_block, name)
            self.builder.add_memlet(
                fill_block, memset_node, "_ptr", fill_dst, "void", "", ptr_type
            )
        elif ones_init:
            # Flat loop over every element writing the constant one.
            idx = f"_i_{self._get_unique_id()}"
            if not self.builder.exists(idx):
                self.builder.add_container(idx, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[idx] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(idx, "0", count_expr, "1")

            integral = (
                PrimitiveType.Int64,
                PrimitiveType.Int32,
                PrimitiveType.Int8,
                PrimitiveType.Int16,
                PrimitiveType.UInt64,
                PrimitiveType.UInt32,
                PrimitiveType.UInt8,
                PrimitiveType.UInt16,
            )
            one = "1" if dtype.primitive_type in integral else "1.0"

            body = self.builder.add_block()
            one_node = self.builder.add_constant(body, one, dtype)
            dst_node = self.builder.add_access(body, name)
            assign = self.builder.add_tasklet(
                body, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(body, one_node, "void", assign, "_in", "", dtype)
            self.builder.add_memlet(body, assign, "_out", dst_node, "void", idx)

            self.builder.end_for()

        return name
4✔
2211

2212
    def _handle_array_unary_op(self, op_type, operand):
        """Emit an elementwise unary operation on ``operand`` into a new temporary.

        An operand with no ``array_info`` entry or an empty shape is treated
        as 0-D. It is lowered to a single C-math node (``builder.add_cmath``).
        Array operands are lowered via ``builder.add_elementwise_unary_op``.

        Args:
            op_type: Operation name such as ``"sqrt"``, ``"abs"`` or ``"exp"``.
            operand: Name of the input container.

        Returns:
            Name of the temporary holding the result.

        Raises:
            NotImplementedError: If ``operand`` is 0-D and ``op_type`` has no
                C-math mapping. This is checked before any container is
                created, so a failed call leaves no dangling temporary.
        """
        shape = self.array_info[operand]["shapes"] if operand in self.array_info else []
        dtype = self._get_dtype(operand)

        if shape:
            tmp_name = self._create_array_temp(shape, dtype)
            self.builder.add_elementwise_unary_op(op_type, operand, tmp_name, shape)
            return tmp_name

        # 0-D operands use a C-math intrinsic instead of a library node.
        cmath_funcs = {
            "sqrt": CMathFunction.sqrt,
            "abs": CMathFunction.fabs,
            "absolute": CMathFunction.fabs,
            "exp": CMathFunction.exp,
            "tanh": CMathFunction.tanh,
        }
        if op_type not in cmath_funcs:
            raise NotImplementedError(
                f"Unary operation '{op_type}' is not supported on scalar operands"
            )

        tmp_name = self._create_array_temp(shape, dtype)

        block = self.builder.add_block()
        t_src = self.builder.add_access(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)
        t_task = self.builder.add_cmath(block, cmath_funcs[op_type])

        # CMathNode uses _in1, _in2, ... for inputs and _out for the output.
        self.builder.add_memlet(block, t_src, "void", t_task, "_in1", "", dtype)
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "", dtype)

        return tmp_name
4✔
2251

2252
    def _handle_array_binary_op(self, op_type, left, right):
        """Emit ``left <op_type> right`` elementwise into a new temporary.

        The output shape is the operand shape with more dimensions (the left
        one on ties). The output dtype is ``_promote_dtypes`` of both operand
        dtypes. A scalar operand (not in ``array_info``) whose dtype differs
        from the promoted dtype is first copied into a freshly created scalar
        of the promoted dtype.

        Args:
            op_type: Elementwise op name forwarded to ``add_elementwise_op``.
            left: Name of the left operand container.
            right: Name of the right operand container.

        Returns:
            Name of the temporary holding the result.
        """
        lhs_shape = self.array_info[left]["shapes"] if left in self.array_info else []
        rhs_shape = (
            self.array_info[right]["shapes"] if right in self.array_info else []
        )
        out_shape = rhs_shape if len(rhs_shape) > len(lhs_shape) else lhs_shape

        lhs_dtype = self._get_dtype(left)
        rhs_dtype = self._get_dtype(right)
        out_dtype = self._promote_dtypes(lhs_dtype, rhs_dtype)

        def cast_scalar(name, src_dtype):
            # Arrays and already-matching scalars are used as-is.
            if (
                name in self.array_info
                or src_dtype.primitive_type == out_dtype.primitive_type
            ):
                return name

            cast_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(cast_name, out_dtype, False)
            self.symbol_table[cast_name] = out_dtype

            blk = self.builder.add_block()
            src_node, src_subset = self._add_read(blk, name)
            dst_node = self.builder.add_access(blk, cast_name)
            assign = self.builder.add_tasklet(
                blk, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(blk, src_node, "void", assign, "_in", src_subset)
            self.builder.add_memlet(blk, assign, "_out", dst_node, "void", "")
            return cast_name

        lhs = cast_scalar(left, lhs_dtype)
        rhs = cast_scalar(right, rhs_dtype)

        result = self._create_array_temp(out_shape, out_dtype)
        self.builder.add_elementwise_op(op_type, lhs, rhs, result, out_shape)
        return result
4✔
2318

2319
    def _shape_to_runtime_expr(self, shape_node):
4✔
2320
        """Convert a shape expression AST node to a runtime-evaluable string.
2321

2322
        This converts the AST to a string expression that can be evaluated
2323
        at runtime using only input arrays and shape symbols (_s0, _s1, etc.).
2324
        It does NOT visit the node (which would create SDFG variables).
2325
        """
2326
        if isinstance(shape_node, ast.Constant):
4✔
2327
            return str(shape_node.value)
4✔
2328
        elif isinstance(shape_node, ast.Name):
4✔
2329
            return shape_node.id
4✔
2330
        elif isinstance(shape_node, ast.BinOp):
4✔
2331
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2332
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2333
            op = self.visit(shape_node.op)
4✔
2334
            return f"({left} {op} {right})"
4✔
2335
        elif isinstance(shape_node, ast.UnaryOp):
4✔
2336
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
2337
            if isinstance(shape_node.op, ast.USub):
×
2338
                return f"(-{operand})"
×
2339
            elif isinstance(shape_node.op, ast.UAdd):
×
2340
                return operand
×
2341
            else:
2342
                # Fall back to visit for other unary ops
2343
                return self.visit(shape_node)
×
2344
        elif isinstance(shape_node, ast.Subscript):
4✔
2345
            # Handle arr.shape[0] -> arr.shape[0] for runtime eval
2346
            # or _shape_proxy_arr[0] -> _s<idx>
2347
            val = shape_node.value
4✔
2348
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2349
                # arr.shape[0] -> use the shape symbol
2350
                if isinstance(val.value, ast.Name):
4✔
2351
                    arr_name = val.value.id
4✔
2352
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2353
                        idx = shape_node.slice.value
4✔
2354
                        # Get the shape symbol for this array dimension
2355
                        if arr_name in self.array_info:
4✔
2356
                            shapes = self.array_info[arr_name].get("shapes", [])
4✔
2357
                            if idx < len(shapes):
4✔
2358
                                return shapes[idx]
4✔
2359
                        return f"{arr_name}.shape[{idx}]"
×
2360
            # Fall back to visit
2361
            return self.visit(shape_node)
×
2362
        elif isinstance(shape_node, ast.Tuple):
×
2363
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2364
        elif isinstance(shape_node, ast.List):
×
2365
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2366
        else:
2367
            # Fall back to visit for complex expressions
2368
            return self.visit(shape_node)
×
2369

2370
    def _handle_numpy_alloc(self, node, func_name):
4✔
2371
        # Parse shape
2372
        shape_arg = node.args[0]
4✔
2373
        dims = []
4✔
2374
        dims_runtime = []  # Runtime-evaluable shape expressions
4✔
2375
        if isinstance(shape_arg, ast.Tuple):
4✔
2376
            dims = [self.visit(elt) for elt in shape_arg.elts]
4✔
2377
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
4✔
2378
        elif isinstance(shape_arg, ast.List):
4✔
2379
            dims = [self.visit(elt) for elt in shape_arg.elts]
×
2380
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
×
2381
        else:
2382
            val = self.visit(shape_arg)
4✔
2383
            runtime_val = self._shape_to_runtime_expr(shape_arg)
4✔
2384
            if val.startswith("_shape_proxy_"):
4✔
2385
                array_name = val[len("_shape_proxy_") :]
×
2386
                if array_name in self.array_info:
×
2387
                    dims = self.array_info[array_name]["shapes"]
×
2388
                    dims_runtime = self.array_info[array_name].get(
×
2389
                        "shapes_runtime", dims
2390
                    )
2391
                else:
2392
                    dims = [val]
×
2393
                    dims_runtime = [runtime_val]
×
2394
            else:
2395
                dims = [val]
4✔
2396
                dims_runtime = [runtime_val]
4✔
2397

2398
        # Parse dtype
2399
        dtype_arg = None
4✔
2400
        if len(node.args) > 1:
4✔
2401
            dtype_arg = node.args[1]
×
2402

2403
        for kw in node.keywords:
4✔
2404
            if kw.arg == "dtype":
4✔
2405
                dtype_arg = kw.value
4✔
2406
                break
4✔
2407

2408
        element_type = self._map_numpy_dtype(dtype_arg)
4✔
2409

2410
        return self._create_array_temp(
4✔
2411
            dims,
2412
            element_type,
2413
            zero_init=(func_name == "zeros"),
2414
            ones_init=(func_name == "ones"),
2415
            shapes_runtime=dims_runtime,
2416
        )
2417

2418
    def _handle_numpy_empty_like(self, node, func_name):
        """Lower ``np.empty_like(prototype[, dtype])`` to an uninitialised temporary.

        Shape and dtype are derived by ``_create_like_temp``.
        """
        return self._create_like_temp(node, zero_init=False)

    def _handle_numpy_zeros_like(self, node, func_name):
        """Lower ``np.zeros_like(prototype[, dtype])`` to a zero-filled temporary.

        Shape and dtype are derived by ``_create_like_temp``.
        """
        return self._create_like_temp(node, zero_init=True)

    def _create_like_temp(self, node, zero_init):
        """Allocate a temporary shaped like the prototype ``node.args[0]``.

        Shape: a copy of the prototype's ``array_info`` shapes (``[]``, i.e.
        0-D, when it has no entry). The prototype's ``shapes_runtime`` is
        forwarded when recorded, as ``_handle_numpy_alloc`` does.

        Dtype, in order of preference:
        * an explicit dtype (2nd positional argument or ``dtype`` keyword)
          other than the literal ``None``;
        * the prototype's pointee type (pointer) or its own type (scalar);
        * ``Double``.

        Args:
            node: The ``*_like`` call node.
            zero_init: Whether to zero-fill the new storage.

        Returns:
            Name of the new temporary.
        """
        prototype_name = self.visit(node.args[0])

        dims = []
        shapes_runtime = None
        if prototype_name in self.array_info:
            info = self.array_info[prototype_name]
            # Copy so the new entry does not alias the prototype's list.
            dims = list(info["shapes"])
            if info.get("shapes_runtime") is not None:
                shapes_runtime = list(info["shapes_runtime"])

        dtype_arg = node.args[1] if len(node.args) > 1 else None
        for kw in node.keywords:
            if kw.arg == "dtype":
                dtype_arg = kw.value
                break
        # ``dtype=None`` means "same as the prototype" in NumPy.
        if isinstance(dtype_arg, ast.Constant) and dtype_arg.value is None:
            dtype_arg = None

        element_type = None
        if dtype_arg is not None:
            element_type = self._map_numpy_dtype(dtype_arg)
        else:
            sym_type = self.symbol_table.get(prototype_name)
            if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                element_type = sym_type.pointee_type
            elif isinstance(sym_type, Scalar):
                element_type = sym_type

        if element_type is None:
            element_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            dims,
            element_type,
            zero_init=zero_init,
            ones_init=False,
            shapes_runtime=shapes_runtime,
        )
2493

2494
    def _handle_numpy_eye(self, node, func_name):
        """Lower ``np.eye(N, M=None, k=0, dtype=...)`` to a 2-D temporary.

        The method allocates a zero-filled ``N x M`` array. For every row
        ``i`` with ``0 <= i + k < M`` it writes one at flat index
        ``i * M + i + k``.

        ``M`` defaults to ``N`` when omitted or given as ``None``, whether
        positionally or by keyword. ``k`` may be positional (3rd argument) or
        a keyword. ``dtype`` may be positional (4th argument) or a keyword;
        it is mapped via ``_map_numpy_dtype``.

        Returns:
            Name of the new temporary.
        """
        N_str = self.visit(node.args[0])

        def dim_or_n(arg):
            # ``None`` (literal, or anything visiting to "None") means M = N.
            if isinstance(arg, ast.Constant) and arg.value is None:
                return N_str
            value = self.visit(arg)
            return N_str if value == "None" else value

        M_str = N_str
        k_str = "0"
        dtype_arg = None
        if len(node.args) > 1:
            M_str = dim_or_n(node.args[1])
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])
        if len(node.args) > 3:
            dtype_arg = node.args[3]

        for kw in node.keywords:
            if kw.arg == "M":
                M_str = dim_or_n(kw.value)
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        element_type = self._map_numpy_dtype(dtype_arg)

        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)

        # Loop over rows to set the (shifted) diagonal.
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Condition: 0 <= i + k < M
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        integral = (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        )
        val = "1" if element_type.primitive_type in integral else "1.0"

        # Assignment: A[i, i + k] = 1 (row-major flat index).
        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"

        t_task = self.builder.add_tasklet(
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", flat_index)

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
4✔
2569

2570
    def _handle_numpy_binary_op(self, node, func_name):
4✔
2571
        args = [self.visit(arg) for arg in node.args]
4✔
2572
        if len(args) != 2:
4✔
2573
            raise NotImplementedError(
×
2574
                f"Numpy function {func_name} requires 2 arguments"
2575
            )
2576

2577
        op_map = {
4✔
2578
            "add": "add",
2579
            "subtract": "sub",
2580
            "multiply": "mul",
2581
            "divide": "div",
2582
            "power": "pow",
2583
            "minimum": "min",
2584
            "maximum": "max",
2585
        }
2586
        return self._handle_array_binary_op(op_map[func_name], args[0], args[1])
4✔
2587

2588
    def _handle_numpy_where(self, node, func_name):
        """Handle np.where(condition, x, y) - elementwise ternary selection.

        Returns an array where elements are taken from x where condition is True,
        and from y where condition is False.

        Output shape: taken from the first argument with a non-empty
        ``array_info`` shape, in the order condition, y, x.

        Output dtype, in order of preference:
        * ``y``'s symbol-table type (pointer pointee or scalar);
        * if ``y`` has no symbol-table entry (e.g. a literal such as ``0``)
          and ``x`` has a scalar element type, that type promoted with
          ``self._get_dtype(y)``;
        * ``Double``.

        With the second rule, ``np.where(c, A, 0)`` keeps ``A``'s element
        type rather than silently becoming ``Double``.

        Raises:
            NotImplementedError: If the call does not have 3 arguments or if
                none of them is an array.
        """
        if len(node.args) != 3:
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")

        cond_name = self.visit(node.args[0])
        x_name = self.visit(node.args[1])
        y_name = self.visit(node.args[2])

        # Priority: condition > y > x (since x might be scalar 0).
        shape = []
        for candidate in (cond_name, y_name, x_name):
            if candidate in self.array_info:
                shape = self.array_info[candidate]["shapes"]
                if shape:
                    break

        if not shape:
            raise NotImplementedError("np.where requires at least one array argument")

        def element_type(name):
            # Element type from the symbol table, or None if unknown.
            sym = self.symbol_table.get(name)
            if isinstance(sym, Pointer) and sym.has_pointee_type():
                return sym.pointee_type
            if isinstance(sym, Scalar):
                return sym
            return None

        dtype = element_type(y_name)
        if dtype is None and y_name not in self.symbol_table:
            x_type = element_type(x_name)
            if isinstance(x_type, Scalar):
                dtype = self._promote_dtypes(x_type, self._get_dtype(y_name))
        if dtype is None:
            dtype = Scalar(PrimitiveType.Double)

        tmp_name = self._create_array_temp(shape, dtype)

        # One loop per output dimension.
        loop_vars = []
        for i, dim in enumerate(shape):
            loop_var = f"_where_i{i}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(dim), "1")

        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))

        # Materialise the condition element into a Bool scalar.
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
        self.symbol_table[cond_tmp] = Scalar(PrimitiveType.Bool)

        block_cond = self.builder.add_block()
        if cond_name in self.array_info:
            t_cond_src = self.builder.add_access(block_cond, cond_name)
            cond_sub = linear_idx
        else:
            # Scalar condition - read it directly.
            t_cond_src, cond_sub = self._add_read(block_cond, cond_name)
        t_cond_out = self.builder.add_access(block_cond, cond_tmp)
        t_cond_task = self.builder.add_tasklet(
            block_cond, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_cond, t_cond_src, "void", t_cond_task, "_in", cond_sub
        )
        self.builder.add_memlet(
            block_cond, t_cond_task, "_out", t_cond_out, "void", ""
        )

        def emit_select(src_name):
            # Copy src[linear_idx] (or the scalar src) into tmp[linear_idx].
            block = self.builder.add_block()
            t_out = self.builder.add_access(block, tmp_name)
            if src_name in self.array_info:
                t_src = self.builder.add_access(block, src_name)
                src_sub = linear_idx
            else:
                t_src, src_sub = self._add_read(block, src_name)
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", linear_idx)

        self.builder.begin_if(f"{cond_tmp} == true")
        emit_select(x_name)
        self.builder.begin_else()
        emit_select(y_name)
        self.builder.end_if()

        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
4✔
2738

2739
    def _handle_numpy_clip(self, node, func_name):
4✔
2740
        """Handle np.clip(a, a_min, a_max) - elementwise clipping.
2741

2742
        Clips array values to be within [a_min, a_max].
2743
        Implemented as: min(max(a, a_min), a_max)
2744

2745
        This uses the existing min/max elementwise operations.
2746
        """
2747
        if len(node.args) != 3:
4✔
2748
            raise NotImplementedError("np.clip requires 3 arguments (a, a_min, a_max)")
×
2749

2750
        # Visit the array argument
2751
        arr_name = self.visit(node.args[0])
4✔
2752
        # Visit the bound arguments (scalars or arrays)
2753
        a_min = self.visit(node.args[1])
4✔
2754
        a_max = self.visit(node.args[2])
4✔
2755

2756
        # First: tmp1 = max(arr, a_min) - ensures values are at least a_min
2757
        tmp1 = self._handle_array_binary_op("max", arr_name, a_min)
4✔
2758

2759
        # Second: result = min(tmp1, a_max) - ensures values are at most a_max
2760
        result = self._handle_array_binary_op("min", tmp1, a_max)
4✔
2761

2762
        return result
4✔
2763

2764
    def _handle_numpy_matmul_op(self, left_node, right_node):
4✔
2765
        return self._handle_matmul_helper(left_node, right_node)
4✔
2766

2767
    def _handle_numpy_matmul(self, node, func_name):
4✔
2768
        if len(node.args) != 2:
4✔
2769
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
2770
        return self._handle_matmul_helper(node.args[0], node.args[1])
4✔
2771

2772
    def _handle_numpy_outer(self, node, func_name):
        """Lower ``np.outer(a, b)`` through the linear-algebra handler.

        Each operand is resolved with ``la_handler.parse_arg``. Operands it
        cannot resolve directly (e.g. compound expressions) are first
        materialised with ``self.visit`` and re-parsed as a plain name. The
        flattened operand sizes ``m`` and ``n`` are the products of the dims
        reported by ``parse_arg``. The result is an ``m x n`` temporary of
        ``Double`` elements, filled by ``la_handler.handle_outer``.

        Raises:
            NotImplementedError: On a wrong argument count or unresolvable
                operands.
            RuntimeError: If ``self.la_handler`` is not set.
        """
        if len(node.args) != 2:
            raise NotImplementedError("outer requires 2 arguments")

        arg0, arg1 = node.args

        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        res_a = self.la_handler.parse_arg(arg0)
        res_b = self.la_handler.parse_arg(arg1)

        # Re-parse by name if parse_arg failed (likely a complex expression).
        # The synthesized Name carries a Load context so it is a well-formed
        # expression node.
        if not res_a[0]:
            arg0 = ast.Name(id=self.visit(arg0), ctx=ast.Load())
            res_a = self.la_handler.parse_arg(arg0)

        if not res_b[0]:
            arg1 = ast.Name(id=self.visit(arg1), ctx=ast.Load())
            res_b = self.la_handler.parse_arg(arg1)

        name_a, _, shape_a, _ = res_a
        name_b, _, shape_b, _ = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve outer operands")

        def flattened_size(dims):
            # Product of dims as a string; leading "1" factors are dropped.
            expr = "1"
            for d in dims:
                expr = str(d) if expr == "1" else f"({expr} * {str(d)})"
            return expr

        m_expr = flattened_size(shape_a)
        n_expr = flattened_size(shape_b)

        # Outer products default to double precision for now.
        dtype = Scalar(PrimitiveType.Double)
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)

        new_call_node = ast.Call(
            func=node.func, args=[arg0, arg1], keywords=node.keywords
        )
        self.la_handler.handle_outer(tmp_name, new_call_node)

        return tmp_name
4✔
2830

2831
    def _handle_ufunc_outer(self, node, ufunc_name):
        """Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.

        These compute the outer operation for the given ufunc:
        - np.add.outer(a, b) -> a[:, np.newaxis] + b (outer sum)
        - np.subtract.outer(a, b) -> a[:, np.newaxis] - b (outer difference)
        - np.multiply.outer(a, b) -> a[:, np.newaxis] * b (same as np.outer)

        Both operands are treated as flattened sequences of lengths M and N.
        ``multiply`` is delegated to ``_handle_numpy_outer``. ``add``,
        ``subtract``, ``divide``, ``minimum`` and ``maximum`` are lowered to an
        explicit ``for i in range(M): for j in range(N)`` loop nest that writes
        ``C[i * N + j] = A[i] op B[j]`` into a new temporary. The temporary's
        dtype is the promotion of the two input dtypes.

        Args:
            node: The ``ast.Call`` node of the ``<ufunc>.outer`` call.
            ufunc_name: Name of the ufunc, e.g. ``"add"``.

        Returns:
            Name of the temporary (M, N) array holding the result.

        Raises:
            NotImplementedError: On a wrong argument count, an unsupported
                ufunc, or unresolvable operands.
            RuntimeError: If no ``LinearAlgebraHandler`` is attached.
        """
        if len(node.args) != 2:
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")

        # For np.multiply.outer, use the existing GEMM-based outer handler
        if ufunc_name == "multiply":
            return self._handle_numpy_outer(node, "outer")

        # ufunc name -> (operation name, floating-point opcode, integer opcode)
        # NOTE(review): np.divide is true division in NumPy and returns floats
        # even for integer inputs, but integer inputs select int_sdiv here.
        # Confirm the truncating behaviour is intended.
        op_map = {
            "add": ("add", TaskletCode.fp_add, TaskletCode.int_add),
            "subtract": ("sub", TaskletCode.fp_sub, TaskletCode.int_sub),
            "divide": ("div", TaskletCode.fp_div, TaskletCode.int_sdiv),
            "minimum": ("min", CMathFunction.fmin, TaskletCode.int_smin),
            "maximum": ("max", CMathFunction.fmax, TaskletCode.int_smax),
        }

        if ufunc_name not in op_map:
            raise NotImplementedError(f"{ufunc_name}.outer not supported")

        _, fp_opcode, int_opcode = op_map[ufunc_name]

        # Use la_handler.parse_arg to properly handle sliced arrays
        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        arg0 = node.args[0]
        arg1 = node.args[1]

        res_a = self.la_handler.parse_arg(arg0)
        res_b = self.la_handler.parse_arg(arg1)

        # If parse_arg fails for complex expressions, evaluate them into a
        # temporary and re-parse that temporary by name.
        if not res_a[0]:
            arg0 = ast.Name(id=self.visit(arg0))
            res_a = self.la_handler.parse_arg(arg0)

        if not res_b[0]:
            arg1 = ast.Name(id=self.visit(arg1))
            res_b = self.la_handler.parse_arg(arg1)

        name_a, subset_a, shape_a, indices_a = res_a
        name_b, subset_b, shape_b, indices_b = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve ufunc outer operands")

        # Compute flattened sizes - outer treats inputs as 1D
        def get_flattened_size_expr(shapes):
            if not shapes:
                return "1"
            size_expr = str(shapes[0])
            for s in shapes[1:]:
                size_expr = f"({size_expr} * {str(s)})"
            return size_expr

        m_expr = get_flattened_size_expr(shape_a)
        n_expr = get_flattened_size_expr(shape_b)

        # Output dtype is the promotion of both input dtypes
        dtype_left = self._get_dtype(name_a)
        dtype_right = self._get_dtype(name_b)
        dtype = self._promote_dtypes(dtype_left, dtype_right)

        # Integer results select the integer opcode of the chosen op
        is_int = dtype.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]

        # Create output array with shape (M, N)
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)

        # Generate unique loop variable names
        i_var = self._get_temp_name("_outer_i_")
        j_var = self._get_temp_name("_outer_j_")

        # Ensure loop variables exist
        if not self.builder.exists(i_var):
            self.builder.add_container(i_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[i_var] = Scalar(PrimitiveType.Int64)
        if not self.builder.exists(j_var):
            self.builder.add_container(j_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[j_var] = Scalar(PrimitiveType.Int64)

        def compute_linear_index(name, subset, indices, loop_var):
            """
            Compute the linear index for element ``loop_var`` of a sliced array.

            For array A with shape (N, M):
            - A[:, k] (column k): linear_index = loop_var * M + k
            - A[k, :] (row k):    linear_index = k * M + loop_var
            - A[k]    (row k):    linear_index = k * M + loop_var
            - A[:] / A (1D or unsliced): linear_index = loop_var

            ``indices`` holds the AST index nodes (``ast.Slice`` for sliced
            dimensions). ``subset`` holds the start index of each dimension.
            NOTE(review): an operand with more than one sliced dimension reuses
            loop_var in every sliced dimension, which does not match a
            row-major flattening of the slice. Confirm such operands are
            rejected upstream.
            """
            if not indices:
                # Simple 1D array, no slicing
                return loop_var

            info = self.array_info.get(name, {})
            shapes = info.get("shapes", [])
            ndim = info.get("ndim", len(shapes))

            if ndim == 0:
                return loop_var

            # Compute strides (row-major order)
            strides = []
            current_stride = "1"
            for i in range(ndim - 1, -1, -1):
                strides.insert(0, current_stride)
                if i > 0:
                    dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
                    if current_stride == "1":
                        current_stride = str(dim_size)
                    else:
                        current_stride = f"({current_stride} * {dim_size})"

            # Build linear index from subset and indices info
            terms = []
            loop_var_used = False

            for i, idx in enumerate(indices):
                stride = strides[i] if i < len(strides) else "1"
                start = subset[i] if i < len(subset) else "0"

                if isinstance(idx, ast.Slice):
                    # This dimension is sliced - use loop_var
                    if stride == "1":
                        term = f"({start} + {loop_var})"
                    else:
                        term = f"(({start} + {loop_var}) * {stride})"
                    loop_var_used = True
                else:
                    # This dimension has a fixed index
                    if stride == "1":
                        term = start
                    else:
                        term = f"({start} * {stride})"

                terms.append(term)

            # Dimensions not covered by `indices` are implicit full slices.
            # In row-major order they form one contiguous run starting at the
            # offset built above, so the flattened loop variable indexes them
            # directly. Without this, A[k] on a 2-D A would read A[k, 0] for
            # every element.
            if not loop_var_used and len(indices) < ndim:
                terms.append(loop_var)

            # Sum all terms
            if not terms:
                return loop_var

            result = terms[0]
            for t in terms[1:]:
                result = f"({result} + {t})"

            return result

        # Create nested for loops: for i in range(M): for j in range(N): C[i,j] = A[i] op B[j]
        self.builder.begin_for(i_var, "0", m_expr, "1")
        self.builder.begin_for(j_var, "0", n_expr, "1")

        # Create the assignment block: C[i, j] = A[i] op B[j]
        block = self.builder.add_block()

        # Add access nodes
        t_a = self.builder.add_access(block, name_a)
        t_b = self.builder.add_access(block, name_b)
        t_c = self.builder.add_access(block, tmp_name)

        # Determine tasklet type based on operation
        if ufunc_name in ["minimum", "maximum"]:
            # Integer min/max are tasklets; float min/max go through cmath fmin/fmax
            if is_int:
                t_task = self.builder.add_tasklet(
                    block, int_opcode, ["_in1", "_in2"], ["_out"]
                )
            else:
                t_task = self.builder.add_cmath(block, fp_opcode)
        else:
            # Use regular tasklet for arithmetic ops
            tasklet_code = int_opcode if is_int else fp_opcode
            t_task = self.builder.add_tasklet(
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
            )

        # Linear indices for A[i] and B[j]
        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)

        # Connect A[i + offset_a] -> tasklet
        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)

        # Connect B[j + offset_b] -> tasklet
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)

        # Connect tasklet -> C[i * N + j] (linear index for 2D output)
        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)

        self.builder.end_for()  # end j loop
        self.builder.end_for()  # end i loop

        return tmp_name
3049

3050
    def _op_symbol(self, op_name):
4✔
3051
        """Convert operation name to symbol."""
3052
        symbols = {
×
3053
            "add": "+",
3054
            "sub": "-",
3055
            "mul": "*",
3056
            "div": "/",
3057
            "min": "min",  # Will need special handling
3058
            "max": "max",  # Will need special handling
3059
        }
3060
        return symbols.get(op_name, op_name)
×
3061

3062
    def _handle_matmul_helper(self, left_node, right_node):
        """Lower ``left @ right`` to linear-algebra library nodes.

        Operands are resolved with ``self.la_handler.parse_arg``. Unresolvable
        operands are first evaluated with ``self.visit`` and re-parsed by name.
        Supported rank combinations and result shapes:

        - (1, 1): dot product into a fresh ``double`` scalar (``handle_dot``)
        - (2, 2): ``[a0, b1]``; (2, 1): ``[a0]``; (1, 2): ``[b1]`` (``handle_gemm``)
        - (k, k) with k > 2: batched GEMM, with one loop per leading batch
          dimension of the left operand

        Returns:
            Name of the temporary holding the result.

        Raises:
            RuntimeError: If no ``LinearAlgebraHandler`` is attached.
            NotImplementedError: For unresolvable operands or unsupported ranks.
        """
        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        lhs_info = self.la_handler.parse_arg(left_node)
        rhs_info = self.la_handler.parse_arg(right_node)

        # Materialise operands parse_arg could not resolve, then re-parse them.
        if not lhs_info[0]:
            left_node = ast.Name(id=self.visit(left_node))
            lhs_info = self.la_handler.parse_arg(left_node)
        if not rhs_info[0]:
            right_node = ast.Name(id=self.visit(right_node))
            rhs_info = self.la_handler.parse_arg(right_node)

        lhs_name, _, lhs_shape, _ = lhs_info
        rhs_name, _, rhs_shape, _ = rhs_info
        if not lhs_name or not rhs_name:
            raise NotImplementedError("Could not resolve matmul operands")

        lhs_rank = len(lhs_shape)
        rhs_rank = len(rhs_shape)
        ranks = (lhs_rank, rhs_rank)
        scalar_result = ranks == (1, 1)
        batched = lhs_rank > 2 or rhs_rank > 2

        if scalar_result:
            out_shape = []
        elif ranks == (2, 2):
            out_shape = [lhs_shape[0], rhs_shape[1]]
        elif ranks == (2, 1):
            out_shape = [lhs_shape[0]]
        elif ranks == (1, 2):
            out_shape = [rhs_shape[1]]
        elif batched:
            if lhs_rank != rhs_rank:
                raise NotImplementedError(
                    "Broadcasting with different ranks not fully supported yet"
                )
            out_shape = list(lhs_shape[:-2]) + [lhs_shape[-2], rhs_shape[-1]]
        else:
            raise NotImplementedError(
                f"Matmul with ranks {lhs_rank} and {rhs_rank} not supported"
            )

        elem_type = Scalar(PrimitiveType.Double)
        if scalar_result:
            result = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(result, elem_type, False)
            self.symbol_table[result] = elem_type
        else:
            result = self._create_array_temp(out_shape, elem_type)

        if not batched:
            product = ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node)
            if scalar_result:
                self.la_handler.handle_dot(result, product)
            else:
                self.la_handler.handle_gemm(result, product)
            return result

        # Batched case: one loop per leading (batch) dimension of the lhs.
        loop_vars = []
        for dim in range(lhs_rank - 2):
            var = f"_i{self._get_unique_id()}"
            self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
            loop_vars.append(var)
            self.builder.begin_for(var, "0", str(lhs_shape[dim]), "1")

        def batch_slice(array):
            # array[i0, ..., ik, :, :] for the current batch loop indices
            elts = [ast.Name(id=var) for var in loop_vars]
            elts += [ast.Slice(), ast.Slice()]
            return ast.Subscript(
                value=ast.Name(id=array), slice=ast.Tuple(elts=elts), ctx=ast.Load()
            )

        lhs_slice = batch_slice(lhs_name)
        rhs_slice = batch_slice(rhs_name)
        out_slice = batch_slice(result)
        self.la_handler.handle_gemm(
            out_slice, ast.BinOp(left=lhs_slice, op=ast.MatMult(), right=rhs_slice)
        )

        for _ in loop_vars:
            self.builder.end_for()

        return result
3175

3176
    def _handle_numpy_unary_op(self, node, func_name):
        """Lower a one-argument element-wise NumPy call such as ``np.exp(x)``.

        All positional arguments are visited first, then exactly one is
        required. ``absolute`` is canonicalised to ``abs`` before dispatching
        to ``_handle_array_unary_op``.

        Raises:
            NotImplementedError: If the call does not have exactly one argument.
        """
        operands = list(map(self.visit, node.args))
        if len(operands) != 1:
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")

        canonical = "abs" if func_name == "absolute" else func_name
        return self._handle_array_unary_op(canonical, operands[0])
3186

3187
    def _handle_numpy_reduce(self, node, func_name):
        """Lower a NumPy reduction such as ``np.sum(a, axis=..., keepdims=...)``.

        The axis comes from the second positional argument or the ``axis``
        keyword. It may be omitted or ``None`` (reduce over every dimension),
        an integer (negative values count from the end), or a tuple of
        integers. Non-literal axes are evaluated with ``self.visit`` and must
        convert to ``int``. ``keepdims`` is honoured only when it is a literal
        constant.

        Args:
            node: The ``ast.Call`` node of the reduction.
            func_name: Reduction name forwarded to ``builder.add_reduce_op``.

        Returns:
            Name of the result temporary. It is a scalar container when every
            dimension is reduced without ``keepdims``, otherwise an array.

        Raises:
            ValueError: If the input is not a known array or an axis is out
                of range.
            NotImplementedError: If an axis cannot be determined statically.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        array_name = self.visit(args[0])

        if array_name not in self.array_info:
            raise ValueError(f"Reduction input must be an array, got {array_name}")

        input_shape = self.array_info[array_name]["shapes"]
        ndim = len(input_shape)

        axis = None
        if len(args) > 1:
            axis = args[1]
        elif "axis" in keywords:
            axis = keywords["axis"]

        keepdims = False
        keepdims_node = keywords.get("keepdims")
        if isinstance(keepdims_node, ast.Constant):
            keepdims = bool(keepdims_node.value)

        def normalize(val):
            # Map a possibly negative axis into [0, ndim) and bounds-check it.
            norm = val + ndim if val < 0 else val
            if not 0 <= norm < ndim:
                raise ValueError(
                    f"axis {val} is out of bounds for array of dimension {ndim}"
                )
            return norm

        def resolve_axis(expr):
            # `-k` parses as UnaryOp(USub, Constant(k)), so handle it explicitly.
            if isinstance(expr, ast.Constant):
                val = expr.value
            elif (
                isinstance(expr, ast.UnaryOp)
                and isinstance(expr.op, ast.USub)
                and isinstance(expr.operand, ast.Constant)
            ):
                val = -expr.operand.value
            else:
                # Try to evaluate a simple expression to an integer.
                try:
                    val = int(self.visit(expr))
                except Exception as err:
                    raise NotImplementedError("Dynamic axis not supported") from err
            return normalize(val)

        if axis is None or (isinstance(axis, ast.Constant) and axis.value is None):
            # Axis omitted or axis=None: reduce over every dimension.
            axes = list(range(ndim))
        elif isinstance(axis, ast.Tuple):
            # Every element is resolved (including negative ones), none dropped.
            axes = [resolve_axis(elt) for elt in axis.elts]
        else:
            axes = [resolve_axis(axis)]

        # Reduced dims disappear, or become "1" when keepdims is set.
        output_shape = []
        for i, dim in enumerate(input_shape):
            if i in axes:
                if keepdims:
                    output_shape.append("1")
            else:
                output_shape.append(dim)

        dtype = self._get_dtype(array_name)

        if not output_shape:
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
        else:
            tmp_name = self._create_array_temp(output_shape, dtype)

        self.builder.add_reduce_op(
            func_name, array_name, tmp_name, input_shape, axes, keepdims
        )

        return tmp_name
3270

3271
    def _handle_numpy_astype(self, node, array_name):
        """Lower ``array.astype(dtype)`` to a cast into a new array of the same shape.

        The target type is derived from the first positional argument via
        ``_map_numpy_dtype``. The cast itself is emitted with
        ``builder.add_cast_op``.

        Raises:
            ValueError: If no dtype argument is given or ``array_name`` is
                not in ``array_info``.
        """
        if not node.args:
            raise ValueError("astype requires at least one argument (dtype)")
        target_dtype = self._map_numpy_dtype(node.args[0])

        if array_name not in self.array_info:
            raise ValueError(f"Array {array_name} not found in array_info")
        shape = self.array_info[array_name]["shapes"]

        out_name = self._create_array_temp(shape, target_dtype)
        self.builder.add_cast_op(
            array_name, out_name, shape, target_dtype.primitive_type
        )
        return out_name
3294

3295
    def _handle_numpy_copy(self, node, array_name):
        """Lower ``array.copy()`` to a ``memcpy`` into a freshly allocated array.

        The element type is the pointee type of the array's ``Pointer`` entry
        in the symbol table, falling back to ``double`` otherwise. The byte
        count is the symbolic product of all dimensions times the element size
        reported by ``builder.get_sizeof``.
        NOTE(review): for an empty shape the element-count expression is "",
        which yields "() * (...)". Confirm 0-d inputs never reach this path.

        Raises:
            ValueError: If ``array_name`` is not in ``array_info``.
        """
        if array_name not in self.array_info:
            raise ValueError(f"Array {array_name} not found in array_info")
        shape = self.array_info[array_name]["shapes"]

        declared = self.symbol_table[array_name] if array_name in self.symbol_table else None
        if isinstance(declared, Pointer) and declared.has_pointee_type():
            elem_type = declared.pointee_type
        else:
            elem_type = Scalar(PrimitiveType.Double)

        out_name = self._create_array_temp(shape, elem_type)

        # Bytes to copy = (product of dims) * sizeof(element)
        n_elems = " * ".join(f"({dim})" for dim in shape)
        n_bytes = f"({n_elems}) * ({self.builder.get_sizeof(elem_type)})"

        elem_ptr = Pointer(elem_type)

        block = self.builder.add_block()
        src = self.builder.add_access(block, array_name)
        dst = self.builder.add_access(block, out_name)
        copy_node = self.builder.add_memcpy(block, n_bytes)

        # Wire source -> memcpy -> destination through pointer-typed memlets.
        self.builder.add_memlet(block, src, "void", copy_node, "_src", "", elem_ptr)
        self.builder.add_memlet(block, copy_node, "_dst", dst, "void", "", elem_ptr)

        return out_name
3332

3333
    def _handle_scipy_softmax(self, node, func_name):
        """Lower ``scipy.special.softmax(x, axis=...)`` via ``builder.add_reduce_op``.

        The axis comes from the second positional argument or the ``axis``
        keyword. It may be omitted or ``None`` (normalise over every
        dimension), an integer (negative values count from the end), or a
        tuple of integers. Non-literal axes are evaluated with ``self.visit``
        and must convert to ``int``. The output has the input's shape and is
        always ``double``.

        Args:
            node: The ``ast.Call`` node of the softmax call.
            func_name: Operation name forwarded to ``builder.add_reduce_op``.

        Returns:
            Name of the temporary array holding the result.

        Raises:
            ValueError: If the input is not a known array or an axis is out
                of range.
            NotImplementedError: If an axis cannot be determined statically.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        array_name = self.visit(args[0])

        if array_name not in self.array_info:
            raise ValueError(f"Softmax input must be an array, got {array_name}")

        input_shape = self.array_info[array_name]["shapes"]
        ndim = len(input_shape)

        axis = None
        if len(args) > 1:
            axis = args[1]
        elif "axis" in keywords:
            axis = keywords["axis"]

        def normalize(val):
            # Map a possibly negative axis into [0, ndim) and bounds-check it.
            norm = val + ndim if val < 0 else val
            if not 0 <= norm < ndim:
                raise ValueError(
                    f"axis {val} is out of bounds for array of dimension {ndim}"
                )
            return norm

        def resolve_axis(expr):
            # `-k` parses as UnaryOp(USub, Constant(k)), so handle it explicitly.
            if isinstance(expr, ast.Constant):
                val = expr.value
            elif (
                isinstance(expr, ast.UnaryOp)
                and isinstance(expr.op, ast.USub)
                and isinstance(expr.operand, ast.Constant)
            ):
                val = -expr.operand.value
            else:
                # Try to evaluate a simple expression to an integer.
                try:
                    val = int(self.visit(expr))
                except Exception as err:
                    raise NotImplementedError("Dynamic axis not supported") from err
            return normalize(val)

        if axis is None or (isinstance(axis, ast.Constant) and axis.value is None):
            # Axis omitted or axis=None: normalise over every dimension.
            axes = list(range(ndim))
        elif isinstance(axis, ast.Tuple):
            # Every element is resolved (including negative ones), none dropped.
            axes = [resolve_axis(elt) for elt in axis.elts]
        else:
            axes = [resolve_axis(axis)]

        # Output matches the input shape; dtype is fixed to double for now.
        dtype = Scalar(PrimitiveType.Double)  # TODO: infer from the input dtype

        tmp_name = self._create_array_temp(input_shape, dtype)

        self.builder.add_reduce_op(
            func_name, array_name, tmp_name, input_shape, axes, False
        )

        return tmp_name
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc