• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21469091322

29 Jan 2026 07:12AM UTC coverage: 65.843% (+0.1%) from 65.732%
21469091322

push

github

web-flow
Merge pull request #484 from daisytuner/python-gather

adds support for Python gather operations

190 of 240 new or added lines in 4 files covered. (79.17%)

1 existing line in 1 file now uncovered.

22407 of 34031 relevant lines covered (65.84%)

383.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.19
/python/docc/expression_visitor.py
1
import ast
4✔
2
import inspect
4✔
3
import textwrap
4✔
4
from ._sdfg import (
4✔
5
    Scalar,
6
    PrimitiveType,
7
    Pointer,
8
    Type,
9
    DebugInfo,
10
    Structure,
11
    TaskletCode,
12
    CMathFunction,
13
)
14

15

16
class ExpressionVisitor(ast.NodeVisitor):
4✔
17
    def __init__(
        self,
        array_info=None,
        builder=None,
        symbol_table=None,
        globals_dict=None,
        inliner=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Set up the visitor state.

        Mapping arguments left as ``None`` are replaced by a fresh empty dict
        and ``unique_counter_ref`` by a fresh ``[0]``. Objects that are passed
        in are stored as-is (not copied), so the caller keeps seeing updates
        made through this visitor.
        """
        if array_info is None:
            array_info = {}
        if symbol_table is None:
            symbol_table = {}
        if globals_dict is None:
            globals_dict = {}
        if unique_counter_ref is None:
            # One-element list so the counter can be shared by reference.
            unique_counter_ref = [0]
        if structure_member_info is None:
            structure_member_info = {}

        self.array_info = array_info
        self.builder = builder
        self.symbol_table = symbol_table
        self.globals_dict = globals_dict
        self.inliner = inliner
        self._unique_counter_ref = unique_counter_ref
        # Access nodes already created, keyed by (block, expression string).
        self._access_cache = {}
        self.la_handler = None
        self.structure_member_info = structure_member_info
        self._init_numpy_handlers()
41

42
    def _get_unique_id(self):
4✔
43
        self._unique_counter_ref[0] += 1
4✔
44
        return self._unique_counter_ref[0]
4✔
45

46
    def _get_temp_name(self, prefix="_tmp_"):
4✔
47
        if hasattr(self.builder, "find_new_name"):
4✔
48
            return self.builder.find_new_name(prefix)
×
49
        return f"{prefix}{self._get_unique_id()}"
4✔
50

51
    def _is_indirect_access(self, node):
4✔
52
        """Check if a node represents an indirect array access (e.g., A[B[i]]).
53

54
        Returns True if the node is a subscript where the index itself is a subscript
55
        into an array (indirect access pattern).
56
        """
NEW
57
        if not isinstance(node, ast.Subscript):
×
NEW
58
            return False
×
59
        # Check if value is a subscripted array access
NEW
60
        if isinstance(node.value, ast.Name):
×
NEW
61
            arr_name = node.value.id
×
NEW
62
            if arr_name in self.array_info:
×
63
                # Check if slice/index is itself an array access
NEW
64
                if isinstance(node.slice, ast.Subscript):
×
NEW
65
                    if isinstance(node.slice.value, ast.Name):
×
NEW
66
                        idx_arr_name = node.slice.value.id
×
NEW
67
                        if idx_arr_name in self.array_info:
×
NEW
68
                            return True
×
NEW
69
        return False
×
70

71
    def _contains_indirect_access(self, node):
4✔
72
        """Check if an AST node contains any indirect array access.
73

74
        Used to detect expressions like A_row[i] that would be used as slice bounds.
75
        """
76
        if isinstance(node, ast.Subscript):
4✔
77
            if isinstance(node.value, ast.Name):
4✔
78
                arr_name = node.value.id
4✔
79
                if arr_name in self.array_info:
4✔
80
                    return True
4✔
81
        elif isinstance(node, ast.BinOp):
4✔
82
            return self._contains_indirect_access(
4✔
83
                node.left
84
            ) or self._contains_indirect_access(node.right)
85
        elif isinstance(node, ast.UnaryOp):
4✔
86
            return self._contains_indirect_access(node.operand)
4✔
87
        return False
4✔
88

89
    def _materialize_indirect_access(
4✔
90
        self, node, debug_info=None, return_original_expr=False
91
    ):
92
        """Materialize an array access into a scalar variable using tasklet+memlets.
93

94
        For indirect memory access patterns in SDFGs, we need to:
95
        1. Create a scalar container for the result
96
        2. Create a tasklet that performs the assignment
97
        3. Use memlets to read from the array and write to the scalar
98
        4. Return the scalar name (which can be used as a symbolic expression)
99

100
        This is the canonical SDFG pattern for indirect access.
101

102
        If return_original_expr is True, also returns the original array access
103
        expression using parentheses notation (e.g., "A_row(0)") which is consistent
104
        with SDFG subset notation. The runtime evaluator will convert this to
105
        bracket notation for Python evaluation.
106
        """
107
        if not self.builder:
4✔
108
            # Without builder, just return the expression string
NEW
109
            expr = self.visit(node)
×
NEW
110
            return (expr, expr) if return_original_expr else expr
×
111

112
        if debug_info is None:
4✔
113
            debug_info = DebugInfo()
4✔
114

115
        if not isinstance(node, ast.Subscript):
4✔
NEW
116
            expr = self.visit(node)
×
NEW
117
            return (expr, expr) if return_original_expr else expr
×
118

119
        if not isinstance(node.value, ast.Name):
4✔
NEW
120
            expr = self.visit(node)
×
NEW
121
            return (expr, expr) if return_original_expr else expr
×
122

123
        arr_name = node.value.id
4✔
124
        if arr_name not in self.array_info:
4✔
NEW
125
            expr = self.visit(node)
×
NEW
126
            return (expr, expr) if return_original_expr else expr
×
127

128
        # Determine the element type
129
        dtype = Scalar(PrimitiveType.Int64)  # Default for indices
4✔
130
        if arr_name in self.symbol_table:
4✔
131
            t = self.symbol_table[arr_name]
4✔
132
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
133
                dtype = t.pointee_type
4✔
134

135
        # Create scalar container for the result
136
        tmp_name = self._get_temp_name("_idx_")
4✔
137
        self.builder.add_container(tmp_name, dtype, False)
4✔
138
        self.symbol_table[tmp_name] = dtype
4✔
139

140
        # Get the index expression
141
        ndim = self.array_info[arr_name]["ndim"]
4✔
142
        shapes = self.array_info[arr_name].get("shapes", [])
4✔
143

144
        # Compute linear index from the subscript
145
        if isinstance(node.slice, ast.Tuple):
4✔
NEW
146
            indices = [self.visit(elt) for elt in node.slice.elts]
×
147
        else:
148
            indices = [self.visit(node.slice)]
4✔
149

150
        # Handle cases where we need recursive materialization
151
        materialized_indices = []
4✔
152
        for i, idx_str in enumerate(indices):
4✔
153
            # Check if the index itself needs materialization (nested indirect)
154
            # This happens when idx_str looks like an array access e.g., "arr(i)"
155
            if "(" in idx_str and idx_str.endswith(")"):
4✔
156
                # This is an array access, it should already be a valid symbolic expression
157
                # or a scalar variable name
NEW
158
                materialized_indices.append(idx_str)
×
159
            else:
160
                materialized_indices.append(idx_str)
4✔
161

162
        # Compute linear index
163
        linear_index = self._compute_linear_index(
4✔
164
            materialized_indices, shapes, arr_name, ndim
165
        )
166

167
        # Create block with tasklet and memlets
168
        block = self.builder.add_block(debug_info)
4✔
169
        t_src = self.builder.add_access(block, arr_name, debug_info)
4✔
170
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
171
        t_task = self.builder.add_tasklet(
4✔
172
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
173
        )
174

175
        self.builder.add_memlet(
4✔
176
            block, t_src, "void", t_task, "_in", linear_index, None, debug_info
177
        )
178
        self.builder.add_memlet(
4✔
179
            block, t_task, "_out", t_dst, "void", "", None, debug_info
180
        )
181

182
        if return_original_expr:
4✔
183
            # Return both the materialized variable name and the original array access expression
184
            # Use parentheses notation which is consistent with SDFG subset syntax
185
            original_expr = f"{arr_name}({linear_index})"
4✔
186
            return (tmp_name, original_expr)
4✔
187

NEW
188
        return tmp_name
×
189

190
    def _init_numpy_handlers(self):
        """Build ``self.numpy_handlers``, mapping NumPy function names to handlers."""
        alloc = self._handle_numpy_alloc
        binary_op = self._handle_numpy_binary_op
        unary_op = self._handle_numpy_unary_op
        reduce_op = self._handle_numpy_reduce
        matmul = self._handle_numpy_matmul

        # Kept as ordered pairs so the resulting dict has the same key order.
        handler_table = (
            ("empty", alloc),
            ("empty_like", self._handle_numpy_empty_like),
            ("zeros", alloc),
            ("zeros_like", self._handle_numpy_zeros_like),
            ("ones", alloc),
            ("ndarray", alloc),  # np.ndarray() constructor
            ("eye", self._handle_numpy_eye),
            ("add", binary_op),
            ("subtract", binary_op),
            ("multiply", binary_op),
            ("divide", binary_op),
            ("power", binary_op),
            ("exp", unary_op),
            ("abs", unary_op),
            ("absolute", unary_op),
            ("sqrt", unary_op),
            ("tanh", unary_op),
            ("sum", reduce_op),
            ("max", reduce_op),
            ("min", reduce_op),
            ("mean", reduce_op),
            ("std", reduce_op),
            ("matmul", matmul),
            ("dot", matmul),
            ("matvec", matmul),
            ("outer", self._handle_numpy_outer),
            ("minimum", binary_op),
            ("maximum", binary_op),
            ("where", self._handle_numpy_where),
        )
        self.numpy_handlers = dict(handler_table)
222

223
    def generic_visit(self, node):
4✔
224
        return super().generic_visit(node)
×
225

226
    def visit_Constant(self, node):
4✔
227
        if isinstance(node.value, bool):
4✔
228
            return "true" if node.value else "false"
×
229
        return str(node.value)
4✔
230

231
    def visit_Name(self, node):
4✔
232
        name = node.id
4✔
233
        # Check if it's a global constant (not a local variable/array)
234
        if name not in self.symbol_table and self.globals_dict is not None:
4✔
235
            if name in self.globals_dict:
4✔
236
                val = self.globals_dict[name]
4✔
237
                # Only substitute simple numeric constants
238
                if isinstance(val, (int, float)):
4✔
239
                    return str(val)
4✔
240
        return name
4✔
241

242
    def _map_numpy_dtype(self, dtype_node):
        """Translate a dtype AST node into an SDFG type.

        Recognized forms are the builtin names ``float``/``int``/``bool``,
        ``<array>.dtype`` for an array whose symbol-table entry is a pointer
        with a known pointee type, and ``np.``/``numpy.`` followed by
        ``float64``, ``float32``, ``int64``, ``int32`` or ``bool_``. Anything
        else, including ``None``, maps to a Double scalar.
        """
        if isinstance(dtype_node, ast.Name):
            builtin_types = {
                "float": PrimitiveType.Double,
                "int": PrimitiveType.Int64,
                "bool": PrimitiveType.Bool,
            }
            if dtype_node.id in builtin_types:
                return Scalar(builtin_types[dtype_node.id])
        elif isinstance(dtype_node, ast.Attribute):
            owner = dtype_node.value
            if isinstance(owner, ast.Name):
                # Handle array.dtype
                if owner.id in self.symbol_table and dtype_node.attr == "dtype":
                    sym_type = self.symbol_table[owner.id]
                    if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                        return sym_type.pointee_type

                if owner.id in ("numpy", "np"):
                    numpy_types = {
                        "float64": PrimitiveType.Double,
                        "float32": PrimitiveType.Float,
                        "int64": PrimitiveType.Int64,
                        "int32": PrimitiveType.Int32,
                        "bool_": PrimitiveType.Bool,
                    }
                    if dtype_node.attr in numpy_types:
                        return Scalar(numpy_types[dtype_node.attr])

        # Default / fallback
        return Scalar(PrimitiveType.Double)
283

284
    def _is_int(self, operand):
4✔
285
        try:
4✔
286
            if operand.lstrip("-").isdigit():
4✔
287
                return True
4✔
288
        except ValueError:
×
289
            pass
×
290

291
        name = operand
4✔
292
        if "(" in operand and operand.endswith(")"):
4✔
293
            name = operand.split("(")[0]
4✔
294

295
        if name in self.symbol_table:
4✔
296
            t = self.symbol_table[name]
4✔
297

298
            def is_int_ptype(pt):
4✔
299
                return pt in [
4✔
300
                    PrimitiveType.Int64,
301
                    PrimitiveType.Int32,
302
                    PrimitiveType.Int8,
303
                    PrimitiveType.Int16,
304
                    PrimitiveType.UInt64,
305
                    PrimitiveType.UInt32,
306
                    PrimitiveType.UInt8,
307
                    PrimitiveType.UInt16,
308
                ]
309

310
            if isinstance(t, Scalar):
4✔
311
                return is_int_ptype(t.primitive_type)
4✔
312

313
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
314
                et = t.element_type
×
315
                if callable(et):
×
316
                    et = et()
×
317
                if isinstance(et, Scalar):
×
318
                    return is_int_ptype(et.primitive_type)
×
319

320
            if type(t).__name__ == "Pointer":
4✔
321
                if hasattr(t, "pointee_type"):
4✔
322
                    et = t.pointee_type
4✔
323
                    if callable(et):
4✔
324
                        et = et()
×
325
                    if isinstance(et, Scalar):
4✔
326
                        return is_int_ptype(et.primitive_type)
4✔
327
                # Fallback: check if it has element_type (maybe alias?)
328
                if hasattr(t, "element_type"):
×
329
                    et = t.element_type
×
330
                    if callable(et):
×
331
                        et = et()
×
332
                    if isinstance(et, Scalar):
×
333
                        return is_int_ptype(et.primitive_type)
×
334

335
        return False
4✔
336

337
    def _add_read(self, block, expr_str, debug_info=None):
4✔
338
        # Try to reuse access node
339
        try:
4✔
340
            if (block, expr_str) in self._access_cache:
4✔
341
                return self._access_cache[(block, expr_str)]
4✔
342
        except TypeError:
×
343
            # block might not be hashable
344
            pass
×
345

346
        if debug_info is None:
4✔
347
            debug_info = DebugInfo()
4✔
348

349
        if "(" in expr_str and expr_str.endswith(")"):
4✔
350
            name = expr_str.split("(")[0]
4✔
351
            subset = expr_str[expr_str.find("(") + 1 : -1]
4✔
352
            access = self.builder.add_access(block, name, debug_info)
4✔
353
            try:
4✔
354
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
355
            except TypeError:
×
356
                pass
×
357
            return access, subset
4✔
358

359
        if self.builder.exists(expr_str):
4✔
360
            access = self.builder.add_access(block, expr_str, debug_info)
4✔
361
            # For pointer types representing 0-D arrays, dereference with "0"
362
            subset = ""
4✔
363
            if expr_str in self.symbol_table:
4✔
364
                sym_type = self.symbol_table[expr_str]
4✔
365
                if isinstance(sym_type, Pointer):
4✔
366
                    # Check if it's a 0-D array (scalar wrapped in pointer)
367
                    if expr_str in self.array_info:
×
368
                        ndim = self.array_info[expr_str].get("ndim", 0)
×
369
                        if ndim == 0:
×
370
                            subset = "0"
×
371
                    else:
372
                        # Pointer without array_info is treated as 0-D
373
                        subset = "0"
×
374
            try:
4✔
375
                self._access_cache[(block, expr_str)] = (access, subset)
4✔
376
            except TypeError:
×
377
                pass
×
378
            return access, subset
4✔
379

380
        dtype = Scalar(PrimitiveType.Double)
4✔
381
        if self._is_int(expr_str):
4✔
382
            dtype = Scalar(PrimitiveType.Int64)
4✔
383
        elif expr_str == "true" or expr_str == "false":
4✔
384
            dtype = Scalar(PrimitiveType.Bool)
×
385

386
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
4✔
387
        try:
4✔
388
            self._access_cache[(block, expr_str)] = (const_node, "")
4✔
389
        except TypeError:
×
390
            pass
×
391
        return const_node, ""
4✔
392

393
    def _handle_min_max(self, node, func_name):
4✔
394
        args = [self.visit(arg) for arg in node.args]
4✔
395
        if len(args) != 2:
4✔
396
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")
×
397

398
        # Check types
399
        is_float = False
4✔
400
        arg_types = []
4✔
401

402
        for arg in args:
4✔
403
            name = arg
4✔
404
            if "(" in arg and arg.endswith(")"):
4✔
405
                name = arg.split("(")[0]
×
406

407
            if name in self.symbol_table:
4✔
408
                t = self.symbol_table[name]
4✔
409
                if isinstance(t, Pointer):
4✔
410
                    t = t.base_type
×
411

412
                if t.primitive_type == PrimitiveType.Double:
4✔
413
                    is_float = True
4✔
414
                    arg_types.append(PrimitiveType.Double)
4✔
415
                else:
416
                    arg_types.append(PrimitiveType.Int64)
4✔
417
            elif self._is_int(arg):
×
418
                arg_types.append(PrimitiveType.Int64)
×
419
            else:
420
                # Assume float constant
421
                is_float = True
×
422
                arg_types.append(PrimitiveType.Double)
×
423

424
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)
4✔
425

426
        tmp_name = self._get_temp_name("_tmp_")
4✔
427
        self.builder.add_container(tmp_name, dtype, False)
4✔
428
        self.symbol_table[tmp_name] = dtype
4✔
429

430
        if is_float:
4✔
431
            # Cast args if necessary
432
            casted_args = []
4✔
433
            for i, arg in enumerate(args):
4✔
434
                if arg_types[i] != PrimitiveType.Double:
4✔
435
                    # Create temp double
436
                    tmp_cast = self._get_temp_name("_cast_")
4✔
437
                    self.builder.add_container(
4✔
438
                        tmp_cast, Scalar(PrimitiveType.Double), False
439
                    )
440
                    self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
4✔
441

442
                    # Assign int to double (implicit cast)
443
                    self.builder.add_assignment(tmp_cast, arg)
4✔
444
                    casted_args.append(tmp_cast)
4✔
445
                else:
446
                    casted_args.append(arg)
4✔
447

448
            block = self.builder.add_block()
4✔
449
            t_out = self.builder.add_access(block, tmp_name)
4✔
450

451
            intrinsic_name = (
4✔
452
                CMathFunction.fmax if func_name == "max" else CMathFunction.fmin
453
            )
454
            t_task = self.builder.add_cmath(block, intrinsic_name)
4✔
455

456
            for i, arg in enumerate(casted_args):
4✔
457
                t_arg, arg_sub = self._add_read(block, arg)
4✔
458
                self.builder.add_memlet(
4✔
459
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
460
                )
461
        else:
462
            block = self.builder.add_block()
4✔
463
            t_out = self.builder.add_access(block, tmp_name)
4✔
464

465
            # Use int_smax/int_smin tasklet
466
            opcode = None
4✔
467
            if func_name == "max":
4✔
468
                opcode = TaskletCode.int_smax
4✔
469
            else:
470
                opcode = TaskletCode.int_smin
4✔
471
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])
4✔
472

473
            for i, arg in enumerate(args):
4✔
474
                t_arg, arg_sub = self._add_read(block, arg)
4✔
475
                self.builder.add_memlet(
4✔
476
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
477
                )
478

479
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
4✔
480
        return tmp_name
4✔
481

482
    def _handle_python_cast(self, node, func_name):
4✔
483
        """Handle Python type casts: int(), float(), bool()"""
484
        if len(node.args) != 1:
4✔
485
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")
×
486

487
        arg = self.visit(node.args[0])
4✔
488

489
        # Determine target type based on cast function
490
        if func_name == "int":
4✔
491
            target_dtype = Scalar(PrimitiveType.Int64)
4✔
492
        elif func_name == "float":
4✔
493
            target_dtype = Scalar(PrimitiveType.Double)
4✔
494
        elif func_name == "bool":
4✔
495
            target_dtype = Scalar(PrimitiveType.Bool)
4✔
496
        else:
497
            raise NotImplementedError(f"Cast to {func_name} not supported")
×
498

499
        # Determine source type
500
        source_dtype = None
4✔
501
        name = arg
4✔
502
        if "(" in arg and arg.endswith(")"):
4✔
503
            name = arg.split("(")[0]
×
504

505
        if name in self.symbol_table:
4✔
506
            source_dtype = self.symbol_table[name]
4✔
507
            if isinstance(source_dtype, Pointer):
4✔
508
                source_dtype = source_dtype.base_type
×
509
        elif self._is_int(arg):
×
510
            source_dtype = Scalar(PrimitiveType.Int64)
×
511
        elif arg == "true" or arg == "false":
×
512
            source_dtype = Scalar(PrimitiveType.Bool)
×
513
        else:
514
            # Assume float constant
515
            source_dtype = Scalar(PrimitiveType.Double)
×
516

517
        # Create temporary variable for result
518
        tmp_name = self._get_temp_name("_tmp_")
4✔
519
        self.builder.add_container(tmp_name, target_dtype, False)
4✔
520
        self.symbol_table[tmp_name] = target_dtype
4✔
521

522
        # Use tasklet assign opcode for casting (as specified in problem statement)
523
        block = self.builder.add_block()
4✔
524
        t_src, src_sub = self._add_read(block, arg)
4✔
525
        t_dst = self.builder.add_access(block, tmp_name)
4✔
526
        t_task = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
4✔
527
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
4✔
528
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
4✔
529

530
        return tmp_name
4✔
531

532
    def visit_Call(self, node):
4✔
533
        func_name = ""
4✔
534
        module_name = ""
4✔
535
        if isinstance(node.func, ast.Attribute):
4✔
536
            if isinstance(node.func.value, ast.Name):
4✔
537
                if node.func.value.id == "math":
4✔
538
                    module_name = "math"
4✔
539
                    func_name = node.func.attr
4✔
540
                elif node.func.value.id in ["numpy", "np"]:
4✔
541
                    module_name = "numpy"
4✔
542
                    func_name = node.func.attr
4✔
543
                else:
544
                    # Check if it's a method call on an array (e.g., arr.astype(...))
545
                    array_name = node.func.value.id
4✔
546
                    method_name = node.func.attr
4✔
547
                    if array_name in self.array_info and method_name == "astype":
4✔
548
                        return self._handle_numpy_astype(node, array_name)
4✔
549
            elif isinstance(node.func.value, ast.Attribute):
4✔
550
                if (
4✔
551
                    isinstance(node.func.value.value, ast.Name)
552
                    and node.func.value.value.id == "scipy"
553
                    and node.func.value.attr == "special"
554
                ):
555
                    if node.func.attr == "softmax":
4✔
556
                        return self._handle_scipy_softmax(node, "softmax")
4✔
557
                # Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.
558
                elif (
4✔
559
                    isinstance(node.func.value.value, ast.Name)
560
                    and node.func.value.value.id in ["numpy", "np"]
561
                    and node.func.attr == "outer"
562
                ):
563
                    ufunc_name = node.func.value.attr  # "add", "subtract", etc.
4✔
564
                    return self._handle_ufunc_outer(node, ufunc_name)
4✔
565

566
        elif isinstance(node.func, ast.Name):
4✔
567
            func_name = node.func.id
4✔
568

569
        if module_name == "numpy":
4✔
570
            if func_name in self.numpy_handlers:
4✔
571
                return self.numpy_handlers[func_name](node, func_name)
4✔
572

573
        if func_name in ["max", "min"]:
4✔
574
            return self._handle_min_max(node, func_name)
4✔
575

576
        # Handle Python type casts (int, float, bool)
577
        if func_name in ["int", "float", "bool"]:
4✔
578
            return self._handle_python_cast(node, func_name)
4✔
579

580
        math_funcs = {
4✔
581
            "sin": CMathFunction.sin,
582
            "cos": CMathFunction.cos,
583
            "tan": CMathFunction.tan,
584
            "exp": CMathFunction.exp,
585
            "log": CMathFunction.log,
586
            "sqrt": CMathFunction.sqrt,
587
            "pow": CMathFunction.pow,
588
            "abs": CMathFunction.fabs,
589
            "ceil": CMathFunction.ceil,
590
            "floor": CMathFunction.floor,
591
            "asin": CMathFunction.asin,
592
            "acos": CMathFunction.acos,
593
            "atan": CMathFunction.atan,
594
            "sinh": CMathFunction.sinh,
595
            "cosh": CMathFunction.cosh,
596
            "tanh": CMathFunction.tanh,
597
        }
598

599
        if func_name in math_funcs:
4✔
600
            args = [self.visit(arg) for arg in node.args]
4✔
601

602
            tmp_name = self._get_temp_name("_tmp_")
4✔
603
            dtype = Scalar(PrimitiveType.Double)
4✔
604
            self.builder.add_container(tmp_name, dtype, False)
4✔
605
            self.symbol_table[tmp_name] = dtype
4✔
606

607
            block = self.builder.add_block()
4✔
608
            t_out = self.builder.add_access(block, tmp_name)
4✔
609

610
            t_task = self.builder.add_cmath(block, math_funcs[func_name])
4✔
611

612
            for i, arg in enumerate(args):
4✔
613
                t_arg, arg_sub = self._add_read(block, arg)
4✔
614
                self.builder.add_memlet(
4✔
615
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
616
                )
617

618
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
4✔
619
            return tmp_name
4✔
620

621
        if func_name in self.globals_dict:
4✔
622
            obj = self.globals_dict[func_name]
4✔
623
            if inspect.isfunction(obj):
4✔
624
                return self._handle_inline_call(node, obj)
4✔
625

626
        raise NotImplementedError(f"Function call {func_name} not supported")
×
627

628
    def _handle_inline_call(self, node, func_obj):
4✔
629
        # 1. Parse function source
630
        try:
4✔
631
            source_lines, start_line = inspect.getsourcelines(func_obj)
4✔
632
            source = textwrap.dedent("".join(source_lines))
4✔
633
            tree = ast.parse(source)
4✔
634
            func_def = tree.body[0]
4✔
635
        except Exception as e:
×
636
            raise NotImplementedError(
×
637
                f"Could not parse function {func_obj.__name__}: {e}"
638
            )
639

640
        # 2. Evaluate arguments
641
        arg_vars = [self.visit(arg) for arg in node.args]
4✔
642

643
        if len(arg_vars) != len(func_def.args.args):
4✔
644
            raise NotImplementedError(
×
645
                f"Argument count mismatch for {func_obj.__name__}"
646
            )
647

648
        # 3. Generate unique suffix
649
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
4✔
650
        res_name = f"_res{suffix}"
4✔
651

652
        # Assume Int64 for now as match returns 0/1
653
        dtype = Scalar(PrimitiveType.Int64)
4✔
654
        self.builder.add_container(res_name, dtype, False)
4✔
655
        self.symbol_table[res_name] = dtype
4✔
656

657
        # 4. Rename variables
658
        class VariableRenamer(ast.NodeTransformer):
4✔
659
            def __init__(self, suffix, globals_dict):
4✔
660
                self.suffix = suffix
4✔
661
                self.globals_dict = globals_dict
4✔
662

663
            def visit_Name(self, node):
4✔
664
                if node.id in self.globals_dict:
4✔
665
                    return node
4✔
666
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
4✔
667

668
            def visit_Return(self, node):
4✔
669
                if node.value:
4✔
670
                    val = self.visit(node.value)
4✔
671
                    return ast.Assign(
4✔
672
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
673
                        value=val,
674
                    )
675
                return node
×
676

677
        renamer = VariableRenamer(suffix, self.globals_dict)
4✔
678
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
4✔
679

680
        # 5. Assign arguments to parameters
681
        param_assignments = []
4✔
682
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
4✔
683
            param_name = f"{arg_def.arg}{suffix}"
4✔
684

685
            # Infer type and create container
686
            if arg_val in self.symbol_table:
4✔
687
                self.symbol_table[param_name] = self.symbol_table[arg_val]
4✔
688
                self.builder.add_container(
4✔
689
                    param_name, self.symbol_table[arg_val], False
690
                )
691
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
4✔
692
            elif self._is_int(arg_val):
×
693
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
694
                self.builder.add_container(
×
695
                    param_name, Scalar(PrimitiveType.Int64), False
696
                )
697
                val_node = ast.Constant(value=int(arg_val))
×
698
            else:
699
                # Assume float constant
700
                try:
×
701
                    val = float(arg_val)
×
702
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
703
                    self.builder.add_container(
×
704
                        param_name, Scalar(PrimitiveType.Double), False
705
                    )
706
                    val_node = ast.Constant(value=val)
×
707
                except ValueError:
×
708
                    # Fallback to Name, might fail later if not in symbol table
709
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
710

711
            assign = ast.Assign(
4✔
712
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
713
            )
714
            param_assignments.append(assign)
4✔
715

716
        final_body = param_assignments + new_body
4✔
717

718
        # 6. Visit new body using ASTParser
719
        from .ast_parser import ASTParser
4✔
720

721
        parser = ASTParser(
4✔
722
            self.builder,
723
            self.array_info,
724
            self.symbol_table,
725
            globals_dict=self.globals_dict,
726
            unique_counter_ref=self._unique_counter_ref,
727
        )
728

729
        for stmt in final_body:
4✔
730
            parser.visit(stmt)
4✔
731

732
        return res_name
4✔
733

734
    def visit_BinOp(self, node):
        """Lower a binary operation and return the name that holds its result.

        ``@`` is delegated to ``_handle_numpy_matmul_op``. When either visited
        operand is a key of ``self.array_info`` the operation is delegated to
        ``_handle_array_binary_op``. Otherwise a scalar temporary is created:
        Int64 when both operands are integers (according to ``self._is_int``)
        and the operator is neither ``/`` nor ``**``, Double in every other
        case. For a Double result, integer operands are first copied into
        Double temporaries.

        Args:
            node: The ``ast.BinOp`` node to lower.

        Returns:
            The name of the container holding the result.

        Raises:
            NotImplementedError: If the operator is not supported for the
                operand kind (array, integer or float).
        """
        if isinstance(node.op, ast.MatMult):
            return self._handle_numpy_matmul_op(node.left, node.right)

        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)

        # Array operands are handled elementwise by a dedicated helper.
        left_is_array = left in self.array_info
        right_is_array = right in self.array_info
        if left_is_array or right_is_array:
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
            if op not in op_map:
                raise NotImplementedError(f"Array operation {op} not supported")
            return self._handle_array_binary_op(op_map[op], left, right)

        tmp_name = f"_tmp_{self._get_unique_id()}"

        left_is_int = self._is_int(left)
        right_is_int = self._is_int(right)

        # True division and power always produce a Double result.
        is_int_result = bool(left_is_int and right_is_int and op not in ("/", "**"))
        if is_int_result:
            dtype = Scalar(PrimitiveType.Int64)
        else:
            dtype = Scalar(PrimitiveType.Double)

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        real_left = left
        real_right = right
        if not is_int_result:
            # Mixed int/float arithmetic: bring integer operands to Double.
            if left_is_int:
                real_left = self._binop_promote_to_double(left)
            if right_is_int:
                real_right = self._binop_promote_to_double(right)

        # "**" and "%" are lowered specially below; every other operator maps
        # to exactly one tasklet code. The lookup happens before any block is
        # emitted so that an unsupported operator fails with a clear error
        # instead of handing ``None`` to ``add_tasklet``.
        tasklet_code = None
        if op not in ("**", "%"):
            if is_int_result:
                # "/" never reaches this table: it forces a Double result.
                # NOTE(review): "//" uses int_sdiv; presumably this truncates
                # toward zero like C signed division, whereas Python floors -
                # confirm for negative operands.
                codes = {
                    "+": TaskletCode.int_add,
                    "-": TaskletCode.int_sub,
                    "*": TaskletCode.int_mul,
                    "//": TaskletCode.int_sdiv,
                    "|": TaskletCode.int_or,
                    "^": TaskletCode.int_xor,
                }
                kind = "integers"
            else:
                # NOTE(review): float "//" is lowered to a plain fp_div, so no
                # flooring is applied here - confirm this is intended.
                codes = {
                    "+": TaskletCode.fp_add,
                    "-": TaskletCode.fp_sub,
                    "*": TaskletCode.fp_mul,
                    "/": TaskletCode.fp_div,
                    "//": TaskletCode.fp_div,
                }
                kind = "floats"
            if op not in codes:
                raise NotImplementedError(f"Operation {op} not supported for {kind}")
            tasklet_code = codes[op]

        block = self.builder.add_block()
        t_left, left_sub = self._add_read(block, real_left)
        t_right, right_sub = self._add_read(block, real_right)
        t_out = self.builder.add_access(block, tmp_name)

        if op == "**":
            t_task = self.builder.add_cmath(block, CMathFunction.pow)
        elif op == "%":
            if is_int_result:
                # Implement ((a % b) + b) % b to match Python's modulo behavior

                # 1. rem1 = a % b
                t_rem1 = self.builder.add_tasklet(
                    block, TaskletCode.int_srem, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
                self.builder.add_memlet(
                    block, t_right, "void", t_rem1, "_in2", right_sub
                )

                rem1_name = f"_tmp_{self._get_unique_id()}"
                self.builder.add_container(rem1_name, dtype, False)
                t_rem1_out = self.builder.add_access(block, rem1_name)
                self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")

                # 2. add = rem1 + b
                t_add = self.builder.add_tasklet(
                    block, TaskletCode.int_add, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
                self.builder.add_memlet(
                    block, t_right, "void", t_add, "_in2", right_sub
                )

                add_name = f"_tmp_{self._get_unique_id()}"
                self.builder.add_container(add_name, dtype, False)
                t_add_out = self.builder.add_access(block, add_name)
                self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")

                # 3. res = add % b
                t_rem2 = self.builder.add_tasklet(
                    block, TaskletCode.int_srem, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
                self.builder.add_memlet(
                    block, t_right, "void", t_rem2, "_in2", right_sub
                )
                self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")

                return tmp_name

            # NOTE(review): the float case uses fmod directly, without the
            # ((a % b) + b) % b adjustment applied to integers above.
            t_task = self.builder.add_cmath(block, CMathFunction.fmod)
        else:
            t_task = self.builder.add_tasklet(
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
            )

        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")

        return tmp_name

    def _binop_promote_to_double(self, operand):
        """Copy ``operand`` into a fresh Double temporary and return its name.

        A new ``_tmp_`` container of type Double is registered with the builder
        and the symbol table, and a block with a single ``assign`` tasklet
        copies ``operand`` into it.
        """
        cast_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(cast_name, Scalar(PrimitiveType.Double), False)
        self.symbol_table[cast_name] = Scalar(PrimitiveType.Double)

        c_block = self.builder.add_block()
        t_src, src_sub = self._add_read(c_block, operand)
        t_dst = self.builder.add_access(c_block, cast_name)
        t_task = self.builder.add_tasklet(
            c_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
        self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")

        return cast_name
4✔
920

921
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit a block that stores the constant ``value_str`` in ``target_name``.

        The block contains a constant node of type ``dtype``, an access node
        for ``target_name`` and an ``assign`` tasklet connecting the two.
        """
        builder = self.builder
        blk = builder.add_block()
        const_node = builder.add_constant(blk, value_str, dtype)
        target_node = builder.add_access(blk, target_name)
        assign = builder.add_tasklet(blk, TaskletCode.assign, ["_in"], ["_out"])
        # constant -> tasklet -> target
        builder.add_memlet(blk, const_node, "void", assign, "_in", "")
        builder.add_memlet(blk, assign, "_out", target_node, "void", "")
4✔
928

929
    def visit_BoolOp(self, node):
        """Lower a boolean operation into a Bool temporary and return its name.

        Each operand is visited and wrapped as ``(<operand> != 0)``; the
        wrapped operands are joined with the visited operator to form the
        condition of an if/else that assigns ``true`` or ``false`` to the
        temporary.
        """
        joiner = f" {self.visit(node.op)} "
        wrapped = []
        for value in node.values:
            wrapped.append(f"({self.visit(value)} != 0)")
        condition = joiner.join(wrapped)

        result_name = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result_name, bool_type, False)

        # Materialise the truth value through control flow.
        self.builder.begin_if(condition)
        self._add_assign_constant(result_name, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result_name, "false", bool_type)
        self.builder.end_if()

        self.symbol_table[result_name] = bool_type
        return result_name
4✔
947

948
    def visit_Compare(self, node):
        """Lower a single comparison and return the name holding its result.

        When either visited operand is a key of ``self.array_info`` the
        comparison is delegated to ``_handle_array_compare``. Otherwise a Bool
        temporary is created and set through an if/else on the expression
        ``"<left> <op> <right>"``.

        Args:
            node: The ``ast.Compare`` node to lower.

        Returns:
            The name of the container holding the result.

        Raises:
            NotImplementedError: If the comparison is chained (``a < b < c``).
        """
        # Reject chained comparisons before visiting any operand, so nothing
        # is emitted into the builder for an expression that cannot be lowered.
        if len(node.ops) > 1:
            raise NotImplementedError("Chained comparisons not supported yet")

        left = self.visit(node.left)
        op = self.visit(node.ops[0])
        right = self.visit(node.comparators[0])

        # Check if this is an array comparison
        left_is_array = left in self.array_info
        right_is_array = right in self.array_info

        if left_is_array or right_is_array:
            # Handle array comparison - return boolean array
            return self._handle_array_compare(
                left, op, right, left_is_array, right_is_array
            )

        # Scalar comparison
        expr_str = f"{left} {op} {right}"

        tmp_name = f"_tmp_{self._get_unique_id()}"
        dtype = Scalar(PrimitiveType.Bool)
        self.builder.add_container(tmp_name, dtype, False)

        # Use control flow to assign boolean value
        self.builder.begin_if(expr_str)
        self.builder.add_transition(tmp_name, "true")
        self.builder.begin_else()
        self.builder.add_transition(tmp_name, "false")
        self.builder.end_if()

        self.symbol_table[tmp_name] = dtype
        return tmp_name
4✔
982

983
    def visit_UnaryOp(self, node):
        """Lower a unary operation and return the name (or literal) of its result.

        A negated numeric literal is returned directly as the string
        ``"-<value>"``. Negation of an operand that is a key of
        ``self.array_info`` is delegated to ``_handle_array_negate``. Otherwise
        a scalar temporary is created: ``not`` is lowered as an xor with the
        constant ``true``, ``-`` as ``0 - x`` for Int64 or ``fp_neg`` for other
        types, and any remaining operator as a plain copy of the operand.

        Args:
            node: The ``ast.UnaryOp`` node to lower.

        Returns:
            The name of the result container, or a literal string for a
            negated numeric constant.

        Raises:
            NotImplementedError: For bitwise invert (``~``), which has no
                lowering here.
        """
        # Without this guard ``~x`` would fall into the plain-copy branch at
        # the end and silently evaluate to ``x``.
        if isinstance(node.op, ast.Invert):
            raise NotImplementedError("Unary operator ~ not supported")

        if (
            isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float))
        ):
            return f"-{node.operand.value}"

        op = self.visit(node.op)
        operand = self.visit(node.operand)

        # Check if operand is an array - handle as array operation
        if operand in self.array_info and op == "-":
            return self._handle_array_negate(operand)

        tmp_name = f"_tmp_{self._get_unique_id()}"
        # The result type follows the operand when it is known; Double is the
        # fallback.
        dtype = Scalar(PrimitiveType.Double)
        if operand in self.symbol_table:
            dtype = self.symbol_table[operand]
            # If it's a pointer (array), get the element type
            if isinstance(dtype, Pointer) and dtype.has_pointee_type():
                dtype = dtype.pointee_type
        elif self._is_int(operand):
            dtype = Scalar(PrimitiveType.Int64)
        elif isinstance(node.op, ast.Not):
            dtype = Scalar(PrimitiveType.Bool)

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)

        if isinstance(node.op, ast.Not):
            # not x  ==>  x xor true
            t_const = self.builder.add_constant(
                block, "true", Scalar(PrimitiveType.Bool)
            )
            t_task = self.builder.add_tasklet(
                block, TaskletCode.int_xor, ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif op == "-":
            if dtype.primitive_type == PrimitiveType.Int64:
                # -x  ==>  0 - x
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.int_sub, ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
            else:
                t_task = self.builder.add_tasklet(
                    block, TaskletCode.fp_neg, ["_in"], ["_out"]
                )
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        else:
            # Remaining operators (e.g. unary plus) copy the operand unchanged.
            t_task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
4✔
1051

1052
    def _handle_array_negate(self, operand):
        """Negate the array ``operand`` elementwise and return the result name.

        The negation is expressed as ``0 - operand``: a scalar temporary of the
        array's element type is initialised to zero and then passed, together
        with ``operand``, to the builder's elementwise ``"sub"`` operation.
        """
        shape = self.array_info[operand]["shapes"]
        elem_type = self._get_dtype(operand)

        # Destination array with the same shape and element type.
        result_name = self._create_array_temp(shape, elem_type)

        builder = self.builder

        # Scalar holding the zero that the array is subtracted from.
        zero_name = f"_tmp_{self._get_unique_id()}"
        builder.add_container(zero_name, elem_type, False)
        self.symbol_table[zero_name] = elem_type

        if elem_type.primitive_type == PrimitiveType.Double:
            zero_literal = "0.0"
        else:
            zero_literal = "0"

        init_block = builder.add_block()
        const_node = builder.add_constant(init_block, zero_literal, elem_type)
        zero_node = builder.add_access(init_block, zero_name)
        assign = builder.add_tasklet(
            init_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        builder.add_memlet(init_block, const_node, "void", assign, "_in", "")
        builder.add_memlet(init_block, assign, "_out", zero_node, "void", "")

        # result = 0 - operand
        builder.add_elementwise_op("sub", zero_name, operand, result_name, shape)

        return result_name
4✔
1083

1084
    def _handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Handle elementwise comparison of arrays, returning a boolean array.

        Supports: arr > 0, arr < scalar, arr1 > arr2, etc.

        The output shape is taken from ``left`` when it is an array, otherwise
        from ``right``. Integer or floating-point comparison tasklets are
        chosen from the element type of that same array only. For a
        floating-point comparison against an integer scalar, the scalar is
        first converted into a Double temporary.

        Args:
            left: Name or literal of the left operand.
            op: Comparison operator as a string (``">"``, ``"<="``, ...).
            right: Name or literal of the right operand.
            left_is_array: Whether ``left`` is a key of ``self.array_info``.
            right_is_array: Whether ``right`` is a key of ``self.array_info``.

        Returns:
            The name of the temporary boolean array.

        Raises:
            NotImplementedError: If ``op`` is not one of the six supported
                comparison operators.
        """
        # Determine shape from the array operand
        if left_is_array:
            shape = self.array_info[left]["shapes"]
            arr_name = left
        else:
            shape = self.array_info[right]["shapes"]
            arr_name = right

        # Determine if we need integer or floating point comparison
        # based on the array element type
        arr_dtype = self._get_dtype(arr_name)
        use_int_cmp = arr_dtype.primitive_type in (
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )

        # Create output boolean array
        dtype = Scalar(PrimitiveType.Bool)
        tmp_name = self._create_array_temp(shape, dtype)

        # Map comparison operators to tasklet codes
        if use_int_cmp:
            cmp_ops = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            # Floating point ordered comparisons
            cmp_ops = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in cmp_ops:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )

        tasklet_code = cmp_ops[op]

        # Identify the scalar operand, if there is one.
        scalar_name = None
        if not left_is_array:
            scalar_name = left
        elif not right_is_array:
            scalar_name = right

        # A floating-point comparison needs a floating-point scalar.
        if (
            scalar_name is not None
            and not use_int_cmp
            and self._is_int(scalar_name)
        ):
            float_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(
                float_name, Scalar(PrimitiveType.Double), False
            )
            self.symbol_table[float_name] = Scalar(PrimitiveType.Double)

            block_conv = self.builder.add_block()
            if scalar_name in self.symbol_table:
                # An integer *variable*: appending ".0" to its name would not
                # be a valid constant, so copy it through an assign tasklet
                # into the Double temporary instead.
                t_src, src_sub = self._add_read(block_conv, scalar_name)
            else:
                # An integer literal: spell it as a floating-point constant.
                t_src = self.builder.add_constant(
                    block_conv, f"{scalar_name}.0", Scalar(PrimitiveType.Double)
                )
                src_sub = ""
            t_float = self.builder.add_access(block_conv, float_name)
            t_assign = self.builder.add_tasklet(
                block_conv, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_conv, t_src, "void", t_assign, "_in", src_sub
            )
            self.builder.add_memlet(
                block_conv, t_assign, "_out", t_float, "void", ""
            )

            # Replace the scalar name with the converted float
            if not left_is_array:
                left = float_name
            else:
                right = float_name

        # Generate nested loops
        loop_vars = []
        for i, dim in enumerate(shape):
            loop_var = f"_cmp_i{i}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(dim), "1")

        # Compute linear index for array access
        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))

        # Create comparison block
        block = self.builder.add_block()

        # Read left operand
        if left_is_array:
            t_left = self.builder.add_access(block, left)
            left_sub = linear_idx
        else:
            t_left, left_sub = self._add_read(block, left)

        # Read right operand
        if right_is_array:
            t_right = self.builder.add_access(block, right)
            right_sub = linear_idx
        else:
            t_right, right_sub = self._add_read(block, right)

        # Output access
        t_out = self.builder.add_access(block, tmp_name)

        # Create tasklet for comparison
        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        # Connect memlets
        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", linear_idx)

        # Close loops
        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
4✔
1223

1224
    def _parse_array_arg(self, node, simple_visitor):
4✔
1225
        if isinstance(node, ast.Name):
×
1226
            if node.id in self.array_info:
×
1227
                return node.id, [], self.array_info[node.id]["shapes"]
×
1228
        elif isinstance(node, ast.Subscript):
×
1229
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
×
1230
                name = node.value.id
×
1231
                ndim = self.array_info[name]["ndim"]
×
1232

1233
                indices = []
×
1234
                if isinstance(node.slice, ast.Tuple):
×
1235
                    indices = list(node.slice.elts)
×
1236
                else:
1237
                    indices = [node.slice]
×
1238

1239
                while len(indices) < ndim:
×
1240
                    indices.append(ast.Slice(lower=None, upper=None, step=None))
×
1241

1242
                start_indices = []
×
1243
                slice_shape = []
×
1244

1245
                for i, idx in enumerate(indices):
×
1246
                    if isinstance(idx, ast.Slice):
×
1247
                        start = "0"
×
1248
                        if idx.lower:
×
1249
                            start = simple_visitor.visit(idx.lower)
×
1250
                        start_indices.append(start)
×
1251

1252
                        shapes = self.array_info[name]["shapes"]
×
1253
                        dim_size = (
×
1254
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
1255
                        )
1256
                        stop = dim_size
×
1257
                        if idx.upper:
×
1258
                            stop = simple_visitor.visit(idx.upper)
×
1259

1260
                        size = f"({stop} - {start})"
×
1261
                        slice_shape.append(size)
×
1262
                    else:
1263
                        val = simple_visitor.visit(idx)
×
1264
                        start_indices.append(val)
×
1265

1266
                shapes = self.array_info[name]["shapes"]
×
1267
                linear_index = ""
×
1268
                for i in range(ndim):
×
1269
                    term = start_indices[i]
×
1270
                    for j in range(i + 1, ndim):
×
1271
                        shape_val = shapes[j] if j < len(shapes) else None
×
1272
                        shape_sym = (
×
1273
                            shape_val if shape_val is not None else f"_{name}_shape_{j}"
1274
                        )
1275
                        term = f"({term} * {shape_sym})"
×
1276

1277
                    if i == 0:
×
1278
                        linear_index = term
×
1279
                    else:
1280
                        linear_index = f"({linear_index} + {term})"
×
1281

1282
                return name, [linear_index], slice_shape
×
1283

1284
        return None, None, None
×
1285

1286
    def visit_Attribute(self, node):
        """Lower an attribute access and return the name standing for its value.

        Three forms are recognised, in this order:

        * ``arr.shape`` for a name in ``self.array_info`` returns the proxy
          name ``_shape_proxy_<arr>``.
        * ``math.pi`` / ``math.e`` are stored as ``M_PI`` / ``M_E`` in a new
          Double temporary.
        * ``obj.member`` where ``obj`` is a pointer to a ``Structure`` reads
          the member into a new temporary of the member's type.

        Raises:
            RuntimeError: If the structure member is not registered in
                ``self.structure_member_info``.
            NotImplementedError: For any other attribute access.
        """
        target = node.value
        target_is_name = isinstance(target, ast.Name)

        if node.attr == "shape" and target_is_name and target.id in self.array_info:
            return f"_shape_proxy_{target.id}"

        if target_is_name and target.id == "math":
            val = {"pi": "M_PI", "e": "M_E"}.get(node.attr, "")
            if val:
                tmp_name = f"_tmp_{self._get_unique_id()}"
                dtype = Scalar(PrimitiveType.Double)
                self.builder.add_container(tmp_name, dtype, False)
                self.symbol_table[tmp_name] = dtype
                self._add_assign_constant(tmp_name, val, dtype)
                return tmp_name

        # Class member access (e.g., obj.x, obj.y) on a pointer-to-Structure.
        if target_is_name and target.id in self.symbol_table:
            obj_name = target.id
            attr_name = node.attr
            obj_type = self.symbol_table[obj_name]

            if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
                pointee_type = obj_type.pointee_type
                if isinstance(pointee_type, Structure):
                    struct_name = pointee_type.name

                    # Member index and type come from the registered structure info.
                    if (
                        struct_name not in self.structure_member_info
                        or attr_name not in self.structure_member_info[struct_name]
                    ):
                        # This should not happen if structure was registered properly
                        raise RuntimeError(
                            f"Member '{attr_name}' not found in structure '{struct_name}'. "
                            f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
                        )
                    member_index, member_type = self.structure_member_info[
                        struct_name
                    ][attr_name]

                    tmp_name = f"_tmp_{self._get_unique_id()}"
                    self.builder.add_container(tmp_name, member_type, False)
                    self.symbol_table[tmp_name] = member_type

                    # A pass-through tasklet copies the member; the member
                    # itself is selected by the subset of the input memlet.
                    block = self.builder.add_block()
                    obj_access = self.builder.add_access(block, obj_name)
                    tmp_access = self.builder.add_access(block, tmp_name)
                    tasklet = self.builder.add_tasklet(
                        block, TaskletCode.assign, ["_in"], ["_out"]
                    )

                    subset = "0," + str(member_index)
                    self.builder.add_memlet(
                        block, obj_access, "", tasklet, "_in", subset
                    )
                    self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")

                    return tmp_name

        raise NotImplementedError(f"Attribute access {node.attr} not supported")
×
1361

1362
    def _handle_expression_slicing(self, node, value_str, indices_nodes, shapes, ndim):
4✔
1363
        """Handle slicing in expressions (e.g., arr[1:, :, k+1]).
1364

1365
        Creates a temporary array, generates loops to copy sliced data,
1366
        and returns the temporary array name.
1367
        """
1368
        if not self.builder:
4✔
1369
            raise ValueError("Builder required for expression slicing")
×
1370

1371
        # Determine element type from source array
1372
        dtype = Scalar(PrimitiveType.Double)
4✔
1373
        if value_str in self.symbol_table:
4✔
1374
            t = self.symbol_table[value_str]
4✔
1375
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1376
                dtype = t.pointee_type
4✔
1377

1378
        # Analyze each dimension: is it a slice or an index?
1379
        # For slices, compute the resulting shape dimension
1380
        # For indices, that dimension is collapsed
1381
        result_shapes = []  # Shape of the resulting array (for SDFG)
4✔
1382
        result_shapes_runtime = []  # Shape expressions for runtime evaluation
4✔
1383
        slice_info = []  # List of (dim_idx, start_str, stop_str, step_str) for slices
4✔
1384
        index_info = []  # List of (dim_idx, index_str) for point indices
4✔
1385

1386
        for i, idx in enumerate(indices_nodes):
4✔
1387
            shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
4✔
1388

1389
            if isinstance(idx, ast.Slice):
4✔
1390
                # Parse slice bounds - check for indirect access patterns
1391
                start_str = "0"
4✔
1392
                start_str_runtime = "0"  # For runtime shape evaluation
4✔
1393
                if idx.lower is not None:
4✔
1394
                    # Check if lower bound contains indirect array access
1395
                    if self._contains_indirect_access(idx.lower):
4✔
1396
                        start_str, start_str_runtime = (
4✔
1397
                            self._materialize_indirect_access(
1398
                                idx.lower, return_original_expr=True
1399
                            )
1400
                        )
1401
                    else:
1402
                        start_str = self.visit(idx.lower)
4✔
1403
                        start_str_runtime = start_str
4✔
1404
                    # Handle negative indices
1405
                    if isinstance(start_str, str) and (
4✔
1406
                        start_str.startswith("-") or start_str.startswith("(-")
1407
                    ):
UNCOV
1408
                        start_str = f"({shape_val} + {start_str})"
×
NEW
1409
                        start_str_runtime = f"({shape_val} + {start_str_runtime})"
×
1410

1411
                stop_str = str(shape_val)
4✔
1412
                stop_str_runtime = str(shape_val)
4✔
1413
                if idx.upper is not None:
4✔
1414
                    # Check if upper bound contains indirect array access
1415
                    if self._contains_indirect_access(idx.upper):
4✔
1416
                        stop_str, stop_str_runtime = self._materialize_indirect_access(
4✔
1417
                            idx.upper, return_original_expr=True
1418
                        )
1419
                    else:
1420
                        stop_str = self.visit(idx.upper)
4✔
1421
                        stop_str_runtime = stop_str
4✔
1422
                    # Handle negative indices
1423
                    if isinstance(stop_str, str) and (
4✔
1424
                        stop_str.startswith("-") or stop_str.startswith("(-")
1425
                    ):
1426
                        stop_str = f"({shape_val} + {stop_str})"
4✔
1427
                        stop_str_runtime = f"({shape_val} + {stop_str_runtime})"
4✔
1428

1429
                step_str = "1"
4✔
1430
                if idx.step is not None:
4✔
1431
                    step_str = self.visit(idx.step)
×
1432

1433
                # Compute the size of this dimension in the result
1434
                dim_size = f"({stop_str} - {start_str})"
4✔
1435
                dim_size_runtime = f"({stop_str_runtime} - {start_str_runtime})"
4✔
1436
                result_shapes.append(dim_size)
4✔
1437
                result_shapes_runtime.append(dim_size_runtime)
4✔
1438
                slice_info.append((i, start_str, stop_str, step_str))
4✔
1439
            else:
1440
                # Point index - dimension is collapsed
1441
                # Check for indirect array access in the index
1442
                if self._contains_indirect_access(idx):
4✔
NEW
1443
                    index_str = self._materialize_indirect_access(idx)
×
1444
                else:
1445
                    index_str = self.visit(idx)
4✔
1446
                # Handle negative indices
1447
                if isinstance(index_str, str) and (
4✔
1448
                    index_str.startswith("-") or index_str.startswith("(-")
1449
                ):
1450
                    index_str = f"({shape_val} + {index_str})"
×
1451
                index_info.append((i, index_str))
4✔
1452

1453
        # Create temporary array for the result
1454
        tmp_name = self._get_temp_name("_slice_tmp_")
4✔
1455
        result_ndim = len(result_shapes)
4✔
1456

1457
        if result_ndim == 0:
4✔
1458
            # All dimensions indexed - result is a scalar
1459
            self.builder.add_container(tmp_name, dtype, False)
×
1460
            self.symbol_table[tmp_name] = dtype
×
1461
        else:
1462
            # Result is an array - use _create_array_temp to handle allocation
1463
            # Calculate size for malloc - use SDFG symbolic shapes
1464
            size_str = "1"
4✔
1465
            for dim in result_shapes:
4✔
1466
                size_str = f"({size_str} * {dim})"
4✔
1467

1468
            element_size = self.builder.get_sizeof(dtype)
4✔
1469
            total_size = f"({size_str} * {element_size})"
4✔
1470

1471
            # Create pointer
1472
            ptr_type = Pointer(dtype)
4✔
1473
            self.builder.add_container(tmp_name, ptr_type, False)
4✔
1474
            self.symbol_table[tmp_name] = ptr_type
4✔
1475
            # Store both SDFG shapes (for compilation) and runtime shapes (for evaluation)
1476
            # The "shapes" field uses SDFG symbolic variables for malloc sizing
1477
            # The "shapes_runtime" field uses original expressions for Python runtime evaluation
1478
            self.array_info[tmp_name] = {
4✔
1479
                "ndim": result_ndim,
1480
                "shapes": result_shapes,  # Uses materialized variables for SDFG
1481
                "shapes_runtime": result_shapes_runtime,  # Uses original expressions for runtime
1482
            }
1483

1484
            # Malloc for the temporary array
1485
            debug_info = DebugInfo()
4✔
1486
            block_alloc = self.builder.add_block(debug_info)
4✔
1487
            t_malloc = self.builder.add_malloc(block_alloc, total_size)
4✔
1488
            t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
4✔
1489
            self.builder.add_memlet(
4✔
1490
                block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
1491
            )
1492

1493
        # Generate loops to copy the sliced data
1494
        loop_vars = []
4✔
1495
        debug_info = DebugInfo()
4✔
1496

1497
        for dim_idx, (orig_dim, start_str, stop_str, step_str) in enumerate(slice_info):
4✔
1498
            loop_var = f"_slice_loop_{dim_idx}_{self._get_unique_id()}"
4✔
1499
            loop_vars.append((loop_var, orig_dim, start_str, step_str))
4✔
1500

1501
            if not self.builder.exists(loop_var):
4✔
1502
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1503
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1504

1505
            # Loop from 0 to (stop - start)
1506
            count_str = f"({stop_str} - {start_str})"
4✔
1507
            self.builder.begin_for(loop_var, "0", count_str, "1", debug_info)
4✔
1508

1509
        # Build source and destination indices
1510
        src_indices = [""] * ndim
4✔
1511
        dst_indices = []
4✔
1512

1513
        # Fill in point indices for source
1514
        for orig_dim, index_str in index_info:
4✔
1515
            src_indices[orig_dim] = index_str
4✔
1516

1517
        # Fill in slice indices for source and build destination indices
1518
        for loop_var, orig_dim, start_str, step_str in loop_vars:
4✔
1519
            if step_str == "1":
4✔
1520
                src_indices[orig_dim] = f"({start_str} + {loop_var})"
4✔
1521
            else:
1522
                src_indices[orig_dim] = f"({start_str} + {loop_var} * {step_str})"
×
1523
            dst_indices.append(loop_var)
4✔
1524

1525
        # Compute linear indices
1526
        src_linear = self._compute_linear_index(src_indices, shapes, value_str, ndim)
4✔
1527
        if result_ndim > 0:
4✔
1528
            dst_linear = self._compute_linear_index(
4✔
1529
                dst_indices, result_shapes, tmp_name, result_ndim
1530
            )
1531
        else:
1532
            dst_linear = "0"
×
1533

1534
        # Create the copy block
1535
        block = self.builder.add_block(debug_info)
4✔
1536
        t_src = self.builder.add_access(block, value_str, debug_info)
4✔
1537
        t_dst = self.builder.add_access(block, tmp_name, debug_info)
4✔
1538
        t_task = self.builder.add_tasklet(
4✔
1539
            block, TaskletCode.assign, ["_in"], ["_out"], debug_info
1540
        )
1541

1542
        self.builder.add_memlet(
4✔
1543
            block, t_src, "void", t_task, "_in", src_linear, None, debug_info
1544
        )
1545
        self.builder.add_memlet(
4✔
1546
            block, t_task, "_out", t_dst, "void", dst_linear, None, debug_info
1547
        )
1548

1549
        # Close all loops
1550
        for _ in loop_vars:
4✔
1551
            self.builder.end_for()
4✔
1552

1553
        return tmp_name
4✔
1554

1555
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
1556
        """Compute linear index from multi-dimensional indices."""
1557
        if ndim == 0:
4✔
1558
            return "0"
×
1559

1560
        linear_index = ""
4✔
1561
        for i in range(ndim):
4✔
1562
            term = str(indices[i])
4✔
1563
            for j in range(i + 1, ndim):
4✔
1564
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
4✔
1565
                term = f"(({term}) * {shape_val})"
4✔
1566

1567
            if i == 0:
4✔
1568
                linear_index = term
4✔
1569
            else:
1570
                linear_index = f"({linear_index} + {term})"
4✔
1571

1572
        return linear_index
4✔
1573

1574
    def _is_array_index(self, node):
4✔
1575
        """Check if a node represents an array that could be used as an index (gather).
1576

1577
        Returns True if the node is a Name referring to an array in array_info.
1578
        """
1579
        if isinstance(node, ast.Name):
4✔
1580
            return node.id in self.array_info
4✔
1581
        return False
4✔
1582

1583
    def _handle_gather(self, value_str, index_node, debug_info=None):
4✔
1584
        """Handle gather operation: x[indices] where indices is an array.
1585

1586
        Creates a temporary array and generates a loop to gather elements
1587
        from the source array using the index array.
1588

1589
        This is the canonical SDFG pattern for gather operations:
1590
        - Create a loop over the index array
1591
        - Load the index value using a tasklet+memlets
1592
        - Use that index in the memlet subset for the source array
1593
        """
1594
        if debug_info is None:
4✔
1595
            debug_info = DebugInfo()
4✔
1596

1597
        # Get the index array name
1598
        if isinstance(index_node, ast.Name):
4✔
1599
            idx_array_name = index_node.id
4✔
1600
        else:
1601
            # Visit the index to get its name (handles slices like cols)
NEW
1602
            idx_array_name = self.visit(index_node)
×
1603

1604
        if idx_array_name not in self.array_info:
4✔
NEW
1605
            raise ValueError(f"Gather index must be an array, got {idx_array_name}")
×
1606

1607
        # Get shapes
1608
        idx_shapes = self.array_info[idx_array_name].get("shapes", [])
4✔
1609
        src_ndim = self.array_info[value_str]["ndim"]
4✔
1610
        idx_ndim = self.array_info[idx_array_name]["ndim"]
4✔
1611

1612
        if idx_ndim != 1:
4✔
NEW
1613
            raise NotImplementedError("Only 1D index arrays supported for gather")
×
1614

1615
        # Result array has same shape as index array
1616
        result_shape = idx_shapes[0] if idx_shapes else f"_{idx_array_name}_shape_0"
4✔
1617

1618
        # Determine element type from source array
1619
        dtype = Scalar(PrimitiveType.Double)
4✔
1620
        if value_str in self.symbol_table:
4✔
1621
            t = self.symbol_table[value_str]
4✔
1622
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1623
                dtype = t.pointee_type
4✔
1624

1625
        # Determine index type from index array
1626
        idx_dtype = Scalar(PrimitiveType.Int64)
4✔
1627
        if idx_array_name in self.symbol_table:
4✔
1628
            t = self.symbol_table[idx_array_name]
4✔
1629
            if isinstance(t, Pointer) and t.has_pointee_type():
4✔
1630
                idx_dtype = t.pointee_type
4✔
1631

1632
        # Create result array
1633
        tmp_name = self._get_temp_name("_gather_")
4✔
1634

1635
        # Calculate size for malloc
1636
        element_size = self.builder.get_sizeof(dtype)
4✔
1637
        total_size = f"({result_shape} * {element_size})"
4✔
1638

1639
        # Create pointer for result
1640
        ptr_type = Pointer(dtype)
4✔
1641
        self.builder.add_container(tmp_name, ptr_type, False)
4✔
1642
        self.symbol_table[tmp_name] = ptr_type
4✔
1643
        self.array_info[tmp_name] = {"ndim": 1, "shapes": [result_shape]}
4✔
1644

1645
        # Malloc for the result array
1646
        block_alloc = self.builder.add_block(debug_info)
4✔
1647
        t_malloc = self.builder.add_malloc(block_alloc, total_size)
4✔
1648
        t_ptr = self.builder.add_access(block_alloc, tmp_name, debug_info)
4✔
1649
        self.builder.add_memlet(
4✔
1650
            block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type, debug_info
1651
        )
1652

1653
        # Create loop variable
1654
        loop_var = f"_gather_i_{self._get_unique_id()}"
4✔
1655
        self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1656
        self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1657

1658
        # Create variable to hold the loaded index
1659
        idx_var = f"_gather_idx_{self._get_unique_id()}"
4✔
1660
        self.builder.add_container(idx_var, idx_dtype, False)
4✔
1661
        self.symbol_table[idx_var] = idx_dtype
4✔
1662

1663
        # Begin loop
1664
        self.builder.begin_for(loop_var, "0", str(result_shape), "1", debug_info)
4✔
1665

1666
        # Block 1: Load the index from index array using tasklet+memlets
1667
        block_load_idx = self.builder.add_block(debug_info)
4✔
1668
        idx_arr_access = self.builder.add_access(
4✔
1669
            block_load_idx, idx_array_name, debug_info
1670
        )
1671
        idx_var_access = self.builder.add_access(block_load_idx, idx_var, debug_info)
4✔
1672
        tasklet_load = self.builder.add_tasklet(
4✔
1673
            block_load_idx, TaskletCode.assign, ["_in"], ["_out"], debug_info
1674
        )
1675
        self.builder.add_memlet(
4✔
1676
            block_load_idx,
1677
            idx_arr_access,
1678
            "void",
1679
            tasklet_load,
1680
            "_in",
1681
            loop_var,
1682
            None,
1683
            debug_info,
1684
        )
1685
        self.builder.add_memlet(
4✔
1686
            block_load_idx,
1687
            tasklet_load,
1688
            "_out",
1689
            idx_var_access,
1690
            "void",
1691
            "",
1692
            None,
1693
            debug_info,
1694
        )
1695

1696
        # Block 2: Use the loaded index to gather from source array
1697
        block_gather = self.builder.add_block(debug_info)
4✔
1698
        src_access = self.builder.add_access(block_gather, value_str, debug_info)
4✔
1699
        dst_access = self.builder.add_access(block_gather, tmp_name, debug_info)
4✔
1700
        tasklet_gather = self.builder.add_tasklet(
4✔
1701
            block_gather, TaskletCode.assign, ["_in"], ["_out"], debug_info
1702
        )
1703

1704
        # Use the symbolic variable name (idx_var) in the memlet subset - this is key!
1705
        self.builder.add_memlet(
4✔
1706
            block_gather,
1707
            src_access,
1708
            "void",
1709
            tasklet_gather,
1710
            "_in",
1711
            idx_var,
1712
            None,
1713
            debug_info,
1714
        )
1715
        self.builder.add_memlet(
4✔
1716
            block_gather,
1717
            tasklet_gather,
1718
            "_out",
1719
            dst_access,
1720
            "void",
1721
            loop_var,
1722
            None,
1723
            debug_info,
1724
        )
1725

1726
        # End loop
1727
        self.builder.end_for()
4✔
1728

1729
        return tmp_name
4✔
1730

1731
    def visit_Subscript(self, node):
4✔
1732
        value_str = self.visit(node.value)
4✔
1733

1734
        if value_str.startswith("_shape_proxy_"):
4✔
1735
            array_name = value_str[len("_shape_proxy_") :]
4✔
1736
            if isinstance(node.slice, ast.Constant):
4✔
1737
                idx = node.slice.value
4✔
1738
            elif isinstance(node.slice, ast.Index):
×
1739
                idx = node.slice.value.value
×
1740
            else:
1741
                try:
×
1742
                    idx = int(self.visit(node.slice))
×
1743
                except:
×
1744
                    raise NotImplementedError(
×
1745
                        "Dynamic shape indexing not fully supported yet"
1746
                    )
1747

1748
            if (
4✔
1749
                array_name in self.array_info
1750
                and "shapes" in self.array_info[array_name]
1751
            ):
1752
                return self.array_info[array_name]["shapes"][idx]
4✔
1753

1754
            return f"_{array_name}_shape_{idx}"
×
1755

1756
        if value_str in self.array_info:
4✔
1757
            ndim = self.array_info[value_str]["ndim"]
4✔
1758
            shapes = self.array_info[value_str].get("shapes", [])
4✔
1759

1760
            indices = []
4✔
1761
            if isinstance(node.slice, ast.Tuple):
4✔
1762
                indices_nodes = node.slice.elts
4✔
1763
            else:
1764
                indices_nodes = [node.slice]
4✔
1765

1766
            # Check if all indices are full slices (e.g., path[:] or path[:, :])
1767
            # In this case, return just the array name since it's the full array
1768
            all_full_slices = True
4✔
1769
            for idx in indices_nodes:
4✔
1770
                if isinstance(idx, ast.Slice):
4✔
1771
                    # A full slice has no lower, upper bounds or only None
1772
                    if idx.lower is not None or idx.upper is not None:
4✔
1773
                        all_full_slices = False
4✔
1774
                        break
4✔
1775
                else:
1776
                    all_full_slices = False
4✔
1777
                    break
4✔
1778

1779
            # path[:] on an nD array returns the full array
1780
            # So if we have a single full slice, it covers all dimensions
1781
            if all_full_slices:
4✔
1782
                # This is path[:] or path[:,:] - return the array name
1783
                return value_str
4✔
1784

1785
            # Check if there are any slices in the indices
1786
            has_slices = any(isinstance(idx, ast.Slice) for idx in indices_nodes)
4✔
1787
            if has_slices:
4✔
1788
                # Handle mixed slicing (e.g., arr[1:, :, k] or arr[:-1, :, k+1])
1789
                return self._handle_expression_slicing(
4✔
1790
                    node, value_str, indices_nodes, shapes, ndim
1791
                )
1792

1793
            # Check for gather operation: x[indices_array] where indices_array is an array
1794
            # This happens when we have a 1D source array and a 1D index array
1795
            if len(indices_nodes) == 1 and self._is_array_index(indices_nodes[0]):
4✔
1796
                if self.builder:
4✔
1797
                    return self._handle_gather(value_str, indices_nodes[0])
4✔
1798

1799
            if isinstance(node.slice, ast.Tuple):
4✔
1800
                indices = [self.visit(elt) for elt in node.slice.elts]
4✔
1801
            else:
1802
                indices = [self.visit(node.slice)]
4✔
1803

1804
            if len(indices) != ndim:
4✔
1805
                raise ValueError(
×
1806
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
1807
                )
1808

1809
            # Normalize negative indices
1810
            normalized_indices = []
4✔
1811
            for i, idx_str in enumerate(indices):
4✔
1812
                shape_val = shapes[i] if i < len(shapes) else f"_{value_str}_shape_{i}"
4✔
1813
                # Check if index is negative (starts with "-" or "(-")
1814
                if isinstance(idx_str, str) and (
4✔
1815
                    idx_str.startswith("-") or idx_str.startswith("(-")
1816
                ):
1817
                    # Normalize: size + negative_index
1818
                    normalized_indices.append(f"({shape_val} + {idx_str})")
×
1819
                else:
1820
                    normalized_indices.append(idx_str)
4✔
1821

1822
            linear_index = ""
4✔
1823
            for i in range(ndim):
4✔
1824
                term = normalized_indices[i]
4✔
1825
                for j in range(i + 1, ndim):
4✔
1826
                    shape_val = shapes[j] if j < len(shapes) else None
4✔
1827
                    shape_sym = (
4✔
1828
                        shape_val
1829
                        if shape_val is not None
1830
                        else f"_{value_str}_shape_{j}"
1831
                    )
1832
                    term = f"(({term}) * {shape_sym})"
4✔
1833

1834
                if i == 0:
4✔
1835
                    linear_index = term
4✔
1836
                else:
1837
                    linear_index = f"({linear_index} + {term})"
4✔
1838

1839
            access_str = f"{value_str}({linear_index})"
4✔
1840

1841
            if self.builder and isinstance(node.ctx, ast.Load):
4✔
1842
                dtype = Scalar(PrimitiveType.Double)
4✔
1843
                if value_str in self.symbol_table:
4✔
1844
                    t = self.symbol_table[value_str]
4✔
1845
                    if type(t).__name__ == "Array" and hasattr(t, "element_type"):
4✔
1846
                        et = t.element_type
×
1847
                        if callable(et):
×
1848
                            et = et()
×
1849
                        dtype = et
×
1850
                    elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
4✔
1851
                        et = t.pointee_type
4✔
1852
                        if callable(et):
4✔
1853
                            et = et()
×
1854
                        dtype = et
4✔
1855

1856
                tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1857
                self.builder.add_container(tmp_name, dtype, False)
4✔
1858

1859
                block = self.builder.add_block()
4✔
1860
                t_src = self.builder.add_access(block, value_str)
4✔
1861
                t_dst = self.builder.add_access(block, tmp_name)
4✔
1862
                t_task = self.builder.add_tasklet(
4✔
1863
                    block, TaskletCode.assign, ["_in"], ["_out"]
1864
                )
1865

1866
                self.builder.add_memlet(
4✔
1867
                    block, t_src, "void", t_task, "_in", linear_index
1868
                )
1869
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
4✔
1870

1871
                self.symbol_table[tmp_name] = dtype
4✔
1872
                return tmp_name
4✔
1873

1874
            return access_str
4✔
1875

1876
        slice_val = self.visit(node.slice)
×
1877
        access_str = f"{value_str}({slice_val})"
×
1878

1879
        if (
×
1880
            self.builder
1881
            and isinstance(node.ctx, ast.Load)
1882
            and value_str in self.array_info
1883
        ):
1884
            tmp_name = f"_tmp_{self._get_unique_id()}"
×
1885
            self.builder.add_container(tmp_name, Scalar(PrimitiveType.Double), False)
×
1886
            self.builder.add_assignment(tmp_name, access_str)
×
1887
            self.symbol_table[tmp_name] = Scalar(PrimitiveType.Double)
×
1888
            return tmp_name
×
1889

1890
        return access_str
×
1891

1892
    def visit_Add(self, node):
4✔
1893
        return "+"
4✔
1894

1895
    def visit_Sub(self, node):
4✔
1896
        return "-"
4✔
1897

1898
    def visit_Mult(self, node):
4✔
1899
        return "*"
4✔
1900

1901
    def visit_Div(self, node):
4✔
1902
        return "/"
4✔
1903

1904
    def visit_FloorDiv(self, node):
4✔
1905
        return "//"
4✔
1906

1907
    def visit_Mod(self, node):
4✔
1908
        return "%"
4✔
1909

1910
    def visit_Pow(self, node):
4✔
1911
        return "**"
4✔
1912

1913
    def visit_Eq(self, node):
4✔
1914
        return "=="
×
1915

1916
    def visit_NotEq(self, node):
4✔
1917
        return "!="
×
1918

1919
    def visit_Lt(self, node):
4✔
1920
        return "<"
4✔
1921

1922
    def visit_LtE(self, node):
4✔
1923
        return "<="
×
1924

1925
    def visit_Gt(self, node):
4✔
1926
        return ">"
4✔
1927

1928
    def visit_GtE(self, node):
4✔
1929
        return ">="
×
1930

1931
    def visit_And(self, node):
4✔
1932
        return "&"
4✔
1933

1934
    def visit_Or(self, node):
4✔
1935
        return "|"
4✔
1936

1937
    def visit_BitAnd(self, node):
4✔
1938
        return "&"
×
1939

1940
    def visit_BitOr(self, node):
4✔
1941
        return "|"
4✔
1942

1943
    def visit_BitXor(self, node):
4✔
1944
        return "^"
4✔
1945

1946
    def visit_Not(self, node):
4✔
1947
        return "!"
4✔
1948

1949
    def visit_USub(self, node):
4✔
1950
        return "-"
4✔
1951

1952
    def visit_UAdd(self, node):
4✔
1953
        return "+"
×
1954

1955
    def visit_Invert(self, node):
4✔
1956
        return "~"
×
1957

1958
    def _get_dtype(self, name):
        """Resolve the scalar type associated with ``name``.

        A symbol-table entry that is a ``Scalar`` is returned as is. Otherwise
        its ``pointee_type`` and then its ``element_type`` attribute (called
        first when callable) is returned if it is a ``Scalar``. When nothing
        matches, the result is ``Int64`` if ``self._is_int(name)`` holds and
        ``Double`` otherwise.
        """
        if name in self.symbol_table:
            entry = self.symbol_table[name]
            if isinstance(entry, Scalar):
                return entry

            # Pointer-like entries are probed before array-like ones
            for attr in ("pointee_type", "element_type"):
                if not hasattr(entry, attr):
                    continue
                inner = getattr(entry, attr)
                if callable(inner):
                    inner = inner()
                if isinstance(inner, Scalar):
                    return inner

        fallback = (
            PrimitiveType.Int64 if self._is_int(name) else PrimitiveType.Double
        )
        return Scalar(fallback)
4✔
1982

1983
    def _promote_dtypes(self, dtype_left, dtype_right):
        """Pick the higher-ranked of two dtypes.

        Ranking is Double > Float > Int64 > Int32; any other primitive type
        ranks lowest. On a tie the left operand's dtype is returned.
        """
        rank_of = {
            PrimitiveType.Double: 4,
            PrimitiveType.Float: 3,
            PrimitiveType.Int64: 2,
            PrimitiveType.Int32: 1,
        }
        rank_left = rank_of.get(dtype_left.primitive_type, 0)
        rank_right = rank_of.get(dtype_right.primitive_type, 0)
        return dtype_right if rank_right > rank_left else dtype_left
4✔
1998

1999
    def _create_array_temp(
        self, shape, dtype, zero_init=False, ones_init=False, shapes_runtime=None
    ):
        """Create a temporary container of the given shape and element type.

        An empty ``shape`` yields a scalar container (registered in
        ``array_info`` with ``ndim`` 0). Otherwise a pointer container is
        created and a malloc block sized ``prod(shape) * sizeof(dtype)`` is
        emitted. ``zero_init`` takes precedence over ``ones_init``: zeroing
        uses an assignment (scalar) or a memset block (array); ones use an
        assignment (scalar) or a loop over all elements (array).
        ``shapes_runtime``, when given, is stored alongside the array shapes.

        Returns:
            Name of the created container.
        """
        tmp_name = f"_tmp_{self._get_unique_id()}"

        if not shape:
            # 0-D array is just a scalar
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}

            if zero_init or ones_init:
                float_literal, int_literal = (
                    ("0.0", "0") if zero_init else ("1.0", "1")
                )
                is_double = dtype.primitive_type == PrimitiveType.Double
                self.builder.add_assignment(
                    tmp_name, float_literal if is_double else int_literal
                )

            return tmp_name

        # Number of elements as a symbolic product of all extents
        size_str = "1"
        for extent in shape:
            size_str = f"({size_str} * {extent})"

        total_size = f"({size_str} * {self.builder.get_sizeof(dtype)})"

        # Register the pointer container and its shape metadata
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.symbol_table[tmp_name] = ptr_type
        info = {"ndim": len(shape), "shapes": shape}
        if shapes_runtime is not None:
            info["shapes_runtime"] = shapes_runtime
        self.array_info[tmp_name] = info

        # Malloc
        alloc_block = self.builder.add_block()
        malloc_node = self.builder.add_malloc(alloc_block, total_size)
        alloc_access = self.builder.add_access(alloc_block, tmp_name)
        self.builder.add_memlet(
            alloc_block, malloc_node, "_ret", alloc_access, "void", "", ptr_type
        )

        if zero_init:
            memset_block = self.builder.add_block()
            memset_node = self.builder.add_memset(memset_block, "0", total_size)
            memset_access = self.builder.add_access(memset_block, tmp_name)
            self.builder.add_memlet(
                memset_block, memset_node, "_ptr", memset_access, "void", "", ptr_type
            )
            return tmp_name

        if ones_init:
            # Initialize array with ones using a loop over all elements
            loop_var = f"_i_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(loop_var, "0", size_str, "1")

            # Integer element types get an integer literal
            integer_types = (
                PrimitiveType.Int64,
                PrimitiveType.Int32,
                PrimitiveType.Int8,
                PrimitiveType.Int16,
                PrimitiveType.UInt64,
                PrimitiveType.UInt32,
                PrimitiveType.UInt8,
                PrimitiveType.UInt16,
            )
            one_literal = "1" if dtype.primitive_type in integer_types else "1.0"

            fill_block = self.builder.add_block()
            const_node = self.builder.add_constant(fill_block, one_literal, dtype)
            array_access = self.builder.add_access(fill_block, tmp_name)

            fill_tasklet = self.builder.add_tasklet(
                fill_block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                fill_block, const_node, "void", fill_tasklet, "_in", "", dtype
            )
            self.builder.add_memlet(
                fill_block, fill_tasklet, "_out", array_access, "void", loop_var
            )

            self.builder.end_for()

        return tmp_name
4✔
2094

2095
    def _handle_array_unary_op(self, op_type, operand):
        """Apply a unary operation to ``operand`` and return the result's name.

        The result has the operand's shape (empty when the operand is not in
        ``array_info``) and the dtype reported by ``_get_dtype``. Operands with
        an empty shape are lowered to a CMath node
        (sqrt/abs/absolute/exp/tanh); all others go through
        ``builder.add_elementwise_unary_op``.
        """
        if operand in self.array_info:
            out_shape = self.array_info[operand]["shapes"]
        else:
            out_shape = []

        elem_type = self._get_dtype(operand)
        result_name = self._create_array_temp(out_shape, elem_type)

        if out_shape:
            # Array operand: element-wise operation over the whole shape
            self.builder.add_elementwise_unary_op(
                op_type, operand, result_name, out_shape
            )
            return result_name

        # 0-D operand: use an intrinsic (CMath node)
        cmath_by_op = {
            "sqrt": CMathFunction.sqrt,
            "abs": CMathFunction.fabs,
            "absolute": CMathFunction.fabs,
            "exp": CMathFunction.exp,
            "tanh": CMathFunction.tanh,
        }

        scalar_block = self.builder.add_block()
        operand_access = self.builder.add_access(scalar_block, operand)
        result_access = self.builder.add_access(scalar_block, result_name)
        math_node = self.builder.add_cmath(scalar_block, cmath_by_op[op_type])

        # The CMath node is wired through its _in1 input and _out output
        self.builder.add_memlet(
            scalar_block, operand_access, "void", math_node, "_in1", "", elem_type
        )
        self.builder.add_memlet(
            scalar_block, math_node, "_out", result_access, "void", "", elem_type
        )

        return result_name
4✔
2134

2135
    def _handle_array_binary_op(self, op_type, left, right):
        """Emit an elementwise binary operation and return the result name.

        The output shape is the shape of whichever operand has more
        dimensions (the left one on a tie). The output dtype comes from
        ``self._promote_dtypes``. An operand that is not registered in
        ``self.array_info`` is treated as a scalar; if its primitive type
        differs from the promoted one, it is first copied into a new
        container of the promoted dtype and that copy is used instead.
        """

        def shape_of(name):
            # Operands unknown to array_info have no shape.
            if name in self.array_info:
                return self.array_info[name]["shapes"]
            return []

        lhs_shape = shape_of(left)
        rhs_shape = shape_of(right)
        out_shape = rhs_shape if len(lhs_shape) < len(rhs_shape) else lhs_shape

        lhs_dtype = self._get_dtype(left)
        rhs_dtype = self._get_dtype(right)
        out_dtype = self._promote_dtypes(lhs_dtype, rhs_dtype)

        # Decide scalar-ness up front, before any new containers are created.
        lhs_is_scalar = left not in self.array_info
        rhs_is_scalar = right not in self.array_info

        def copy_into_promoted(source):
            # Assign `source` into a new container declared with out_dtype.
            cast_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(cast_name, out_dtype, False)
            self.symbol_table[cast_name] = out_dtype

            cast_block = self.builder.add_block()
            src_node, src_subset = self._add_read(cast_block, source)
            dst_node = self.builder.add_access(cast_block, cast_name)
            assign = self.builder.add_tasklet(
                cast_block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                cast_block, src_node, "void", assign, "_in", src_subset
            )
            self.builder.add_memlet(cast_block, assign, "_out", dst_node, "void", "")
            return cast_name

        lhs_name = left
        if lhs_is_scalar and lhs_dtype.primitive_type != out_dtype.primitive_type:
            lhs_name = copy_into_promoted(left)

        rhs_name = right
        if rhs_is_scalar and rhs_dtype.primitive_type != out_dtype.primitive_type:
            rhs_name = copy_into_promoted(right)

        result_name = self._create_array_temp(out_shape, out_dtype)
        self.builder.add_elementwise_op(
            op_type, lhs_name, rhs_name, result_name, out_shape
        )
        return result_name
4✔
2201

2202
    def _shape_to_runtime_expr(self, shape_node):
4✔
2203
        """Convert a shape expression AST node to a runtime-evaluable string.
2204

2205
        This converts the AST to a string expression that can be evaluated
2206
        at runtime using only input arrays and shape symbols (_s0, _s1, etc.).
2207
        It does NOT visit the node (which would create SDFG variables).
2208
        """
2209
        if isinstance(shape_node, ast.Constant):
4✔
2210
            return str(shape_node.value)
4✔
2211
        elif isinstance(shape_node, ast.Name):
4✔
2212
            return shape_node.id
4✔
2213
        elif isinstance(shape_node, ast.BinOp):
4✔
2214
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2215
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2216
            op = self.visit(shape_node.op)
4✔
2217
            return f"({left} {op} {right})"
4✔
2218
        elif isinstance(shape_node, ast.UnaryOp):
4✔
NEW
2219
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
NEW
2220
            if isinstance(shape_node.op, ast.USub):
×
NEW
2221
                return f"(-{operand})"
×
NEW
2222
            elif isinstance(shape_node.op, ast.UAdd):
×
NEW
2223
                return operand
×
2224
            else:
2225
                # Fall back to visit for other unary ops
NEW
2226
                return self.visit(shape_node)
×
2227
        elif isinstance(shape_node, ast.Subscript):
4✔
2228
            # Handle arr.shape[0] -> arr.shape[0] for runtime eval
2229
            # or _shape_proxy_arr[0] -> _s<idx>
2230
            val = shape_node.value
4✔
2231
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2232
                # arr.shape[0] -> use the shape symbol
2233
                if isinstance(val.value, ast.Name):
4✔
2234
                    arr_name = val.value.id
4✔
2235
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2236
                        idx = shape_node.slice.value
4✔
2237
                        # Get the shape symbol for this array dimension
2238
                        if arr_name in self.array_info:
4✔
2239
                            shapes = self.array_info[arr_name].get("shapes", [])
4✔
2240
                            if idx < len(shapes):
4✔
2241
                                return shapes[idx]
4✔
NEW
2242
                        return f"{arr_name}.shape[{idx}]"
×
2243
            # Fall back to visit
NEW
2244
            return self.visit(shape_node)
×
NEW
2245
        elif isinstance(shape_node, ast.Tuple):
×
NEW
2246
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
NEW
2247
        elif isinstance(shape_node, ast.List):
×
NEW
2248
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2249
        else:
2250
            # Fall back to visit for complex expressions
NEW
2251
            return self.visit(shape_node)
×
2252

2253
    def _handle_numpy_alloc(self, node, func_name):
        """Handle an allocation call whose first argument is the shape.

        The shape may be a tuple/list literal, a single expression, or a
        ``_shape_proxy_<array>`` name (in which case the recorded shape of
        that array is reused). The dtype is taken from the ``dtype`` keyword
        if present, otherwise from the second positional argument. The
        temporary is zero-initialised when ``func_name == "zeros"`` and
        one-initialised when ``func_name == "ones"``.
        """
        shape_arg = node.args[0]
        proxy_prefix = "_shape_proxy_"

        if isinstance(shape_arg, (ast.Tuple, ast.List)):
            dims = [self.visit(elt) for elt in shape_arg.elts]
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
        else:
            val = self.visit(shape_arg)
            runtime_val = self._shape_to_runtime_expr(shape_arg)
            dims = [val]
            dims_runtime = [runtime_val]
            if val.startswith(proxy_prefix):
                array_name = val[len(proxy_prefix):]
                if array_name in self.array_info:
                    # Reuse the shape recorded for the proxied array.
                    proxied = self.array_info[array_name]
                    dims = proxied["shapes"]
                    dims_runtime = proxied.get("shapes_runtime", dims)

        # A dtype keyword takes precedence over the positional dtype.
        positional_dtype = node.args[1] if len(node.args) > 1 else None
        dtype_arg = next(
            (kw.value for kw in node.keywords if kw.arg == "dtype"),
            positional_dtype,
        )
        element_type = self._map_numpy_dtype(dtype_arg)

        wants_zeros = func_name == "zeros"
        wants_ones = func_name == "ones"
        return self._create_array_temp(
            dims,
            element_type,
            zero_init=wants_zeros,
            ones_init=wants_ones,
            shapes_runtime=dims_runtime,
        )
2300

2301
    def _handle_numpy_empty_like(self, node, func_name):
        """Create an uninitialised temporary shaped like the prototype array.

        The shape is copied from the prototype's ``array_info`` entry (empty
        if unknown). The element type comes from an explicit dtype (keyword
        first, then second positional argument), else from the prototype's
        pointee type in the symbol table, else defaults to double.
        """
        prototype_name = self.visit(node.args[0])

        if prototype_name in self.array_info:
            dims = self.array_info[prototype_name]["shapes"]
        else:
            dims = []

        # A dtype keyword takes precedence over the positional dtype.
        positional_dtype = node.args[1] if len(node.args) > 1 else None
        dtype_arg = next(
            (kw.value for kw in node.keywords if kw.arg == "dtype"),
            positional_dtype,
        )

        element_type = None
        if dtype_arg:
            element_type = self._map_numpy_dtype(dtype_arg)
        elif prototype_name in self.symbol_table:
            proto_type = self.symbol_table[prototype_name]
            if isinstance(proto_type, Pointer) and proto_type.has_pointee_type():
                element_type = proto_type.pointee_type

        if element_type is None:
            element_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            dims, element_type, zero_init=False, ones_init=False
        )
2338

2339
    def _handle_numpy_zeros_like(self, node, func_name):
        """Create a zero-initialised temporary shaped like the prototype array.

        Shape and element type are resolved as follows: the shape is read
        from the prototype's ``array_info`` entry; the element type is the
        explicit dtype if one is given, otherwise the prototype's pointee
        type, otherwise double.
        """
        source_name = self.visit(node.args[0])

        shape = []
        if source_name in self.array_info:
            shape = self.array_info[source_name]["shapes"]

        # Positional dtype first; a dtype keyword overrides it.
        requested_dtype = None
        if len(node.args) > 1:
            requested_dtype = node.args[1]
        for keyword in node.keywords:
            if keyword.arg != "dtype":
                continue
            requested_dtype = keyword.value
            break

        if requested_dtype:
            elem = self._map_numpy_dtype(requested_dtype)
        else:
            elem = None
            if source_name in self.symbol_table:
                declared = self.symbol_table[source_name]
                if isinstance(declared, Pointer) and declared.has_pointee_type():
                    elem = declared.pointee_type

        if elem is None:
            # No usable type information: fall back to double.
            elem = Scalar(PrimitiveType.Double)

        return self._create_array_temp(shape, elem, zero_init=True, ones_init=False)
2376

2377
    def _handle_numpy_eye(self, node, func_name):
        """Lower ``np.eye(N, M=None, k=0, dtype=...)`` and return the array name.

        A zero-initialised ``[N, M]`` temporary is created, then a loop over
        ``i`` in ``[0, N)`` writes a one at flat index ``i * M + i + k``
        whenever ``0 <= i + k < M``.

        ``M``, ``k`` and ``dtype`` are accepted positionally (2nd, 3rd and 4th
        argument) or as keywords; keywords override positional values. An
        ``M`` of ``None`` (positional or keyword) means ``M = N``.

        Args:
            node: The ``ast.Call`` node of the ``eye`` call.
            func_name: Unused; kept for the common handler signature.

        Returns:
            Name of the temporary array holding the result.
        """
        N_str = self.visit(node.args[0])

        # Positional M / k / dtype, in numpy's documented argument order.
        M_str = N_str
        if len(node.args) > 1:
            M_str = self.visit(node.args[1])

        k_str = "0"
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])

        dtype_arg = node.args[3] if len(node.args) > 3 else None

        # Keywords override positional values.
        for kw in node.keywords:
            if kw.arg == "M":
                M_str = self.visit(kw.value)
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        # M=None means a square matrix, however M was supplied.
        if M_str == "None":
            M_str = N_str

        element_type = self._map_numpy_dtype(dtype_arg)

        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)

        # Loop over rows to set the (shifted) diagonal.
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Only write when the column index i + k lies inside [0, M).
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        # Integer element types get the literal "1", everything else "1.0".
        integer_types = (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        )
        val = "1" if element_type.primitive_type in integer_types else "1.0"

        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        # Flattened index of element [i, i + k] in an array with M columns.
        subset = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"

        t_task = self.builder.add_tasklet(
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", subset)

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
4✔
2452

2453
    def _handle_numpy_binary_op(self, node, func_name):
4✔
2454
        args = [self.visit(arg) for arg in node.args]
4✔
2455
        if len(args) != 2:
4✔
2456
            raise NotImplementedError(
×
2457
                f"Numpy function {func_name} requires 2 arguments"
2458
            )
2459

2460
        op_map = {
4✔
2461
            "add": "add",
2462
            "subtract": "sub",
2463
            "multiply": "mul",
2464
            "divide": "div",
2465
            "power": "pow",
2466
            "minimum": "min",
2467
            "maximum": "max",
2468
        }
2469
        return self._handle_array_binary_op(op_map[func_name], args[0], args[1])
4✔
2470

2471
    def _handle_numpy_where(self, node, func_name):
4✔
2472
        """Handle np.where(condition, x, y) - elementwise ternary selection.
2473

2474
        Returns an array where elements are taken from x where condition is True,
2475
        and from y where condition is False.
2476
        """
2477
        if len(node.args) != 3:
4✔
2478
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")
×
2479

2480
        # Visit all arguments
2481
        cond_name = self.visit(node.args[0])
4✔
2482
        x_name = self.visit(node.args[1])
4✔
2483
        y_name = self.visit(node.args[2])
4✔
2484

2485
        # Determine output shape from the array arguments
2486
        # Priority: condition > y > x (since x might be scalar 0)
2487
        shape = []
4✔
2488
        dtype = Scalar(PrimitiveType.Double)
4✔
2489

2490
        # Check condition shape
2491
        if cond_name in self.array_info:
4✔
2492
            shape = self.array_info[cond_name]["shapes"]
4✔
2493

2494
        # If condition is scalar, check y
2495
        if not shape and y_name in self.array_info:
4✔
2496
            shape = self.array_info[y_name]["shapes"]
×
2497

2498
        # If y is scalar, check x
2499
        if not shape and x_name in self.array_info:
4✔
2500
            shape = self.array_info[x_name]["shapes"]
×
2501

2502
        if not shape:
4✔
2503
            raise NotImplementedError("np.where requires at least one array argument")
×
2504

2505
        # Determine dtype from y (since x might be scalar 0)
2506
        if y_name in self.symbol_table:
4✔
2507
            y_type = self.symbol_table[y_name]
4✔
2508
            if isinstance(y_type, Pointer) and y_type.has_pointee_type():
4✔
2509
                dtype = y_type.pointee_type
4✔
2510
            elif isinstance(y_type, Scalar):
×
2511
                dtype = y_type
×
2512

2513
        # Create output array
2514
        tmp_name = self._create_array_temp(shape, dtype)
4✔
2515

2516
        # Generate nested loops for the shape
2517
        loop_vars = []
4✔
2518
        for i, dim in enumerate(shape):
4✔
2519
            loop_var = f"_where_i{i}_{self._get_unique_id()}"
4✔
2520
            if not self.builder.exists(loop_var):
4✔
2521
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
2522
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
2523
            loop_vars.append(loop_var)
4✔
2524
            self.builder.begin_for(loop_var, "0", str(dim), "1")
4✔
2525

2526
        # Compute linear index
2527
        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))
4✔
2528

2529
        # Read condition value
2530
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
4✔
2531
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
4✔
2532
        self.symbol_table[cond_tmp] = Scalar(PrimitiveType.Bool)
4✔
2533

2534
        block_cond = self.builder.add_block()
4✔
2535
        if cond_name in self.array_info:
4✔
2536
            t_cond_arr = self.builder.add_access(block_cond, cond_name)
4✔
2537
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
4✔
2538
            t_cond_task = self.builder.add_tasklet(
4✔
2539
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
2540
            )
2541
            self.builder.add_memlet(
4✔
2542
                block_cond, t_cond_arr, "void", t_cond_task, "_in", linear_idx
2543
            )
2544
            self.builder.add_memlet(
4✔
2545
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
2546
            )
2547
        else:
2548
            # Scalar condition - just use it directly
2549
            t_cond_src, cond_sub = self._add_read(block_cond, cond_name)
×
2550
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
×
2551
            t_cond_task = self.builder.add_tasklet(
×
2552
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
2553
            )
2554
            self.builder.add_memlet(
×
2555
                block_cond, t_cond_src, "void", t_cond_task, "_in", cond_sub
2556
            )
2557
            self.builder.add_memlet(
×
2558
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
2559
            )
2560

2561
        # If-else based on condition
2562
        self.builder.begin_if(f"{cond_tmp} == true")
4✔
2563

2564
        # True branch: assign x
2565
        block_true = self.builder.add_block()
4✔
2566
        t_out_true = self.builder.add_access(block_true, tmp_name)
4✔
2567
        if x_name in self.array_info:
4✔
2568
            # x is an array
2569
            t_x = self.builder.add_access(block_true, x_name)
4✔
2570
            t_task_true = self.builder.add_tasklet(
4✔
2571
                block_true, TaskletCode.assign, ["_in"], ["_out"]
2572
            )
2573
            self.builder.add_memlet(
4✔
2574
                block_true, t_x, "void", t_task_true, "_in", linear_idx
2575
            )
2576
        else:
2577
            # x is a scalar
2578
            t_x, x_sub = self._add_read(block_true, x_name)
4✔
2579
            t_task_true = self.builder.add_tasklet(
4✔
2580
                block_true, TaskletCode.assign, ["_in"], ["_out"]
2581
            )
2582
            self.builder.add_memlet(block_true, t_x, "void", t_task_true, "_in", x_sub)
4✔
2583
        self.builder.add_memlet(
4✔
2584
            block_true, t_task_true, "_out", t_out_true, "void", linear_idx
2585
        )
2586

2587
        self.builder.begin_else()
4✔
2588

2589
        # False branch: assign y
2590
        block_false = self.builder.add_block()
4✔
2591
        t_out_false = self.builder.add_access(block_false, tmp_name)
4✔
2592
        if y_name in self.array_info:
4✔
2593
            # y is an array
2594
            t_y = self.builder.add_access(block_false, y_name)
4✔
2595
            t_task_false = self.builder.add_tasklet(
4✔
2596
                block_false, TaskletCode.assign, ["_in"], ["_out"]
2597
            )
2598
            self.builder.add_memlet(
4✔
2599
                block_false, t_y, "void", t_task_false, "_in", linear_idx
2600
            )
2601
        else:
2602
            # y is a scalar
2603
            t_y, y_sub = self._add_read(block_false, y_name)
4✔
2604
            t_task_false = self.builder.add_tasklet(
4✔
2605
                block_false, TaskletCode.assign, ["_in"], ["_out"]
2606
            )
2607
            self.builder.add_memlet(
4✔
2608
                block_false, t_y, "void", t_task_false, "_in", y_sub
2609
            )
2610
        self.builder.add_memlet(
4✔
2611
            block_false, t_task_false, "_out", t_out_false, "void", linear_idx
2612
        )
2613

2614
        self.builder.end_if()
4✔
2615

2616
        # Close all loops
2617
        for _ in loop_vars:
4✔
2618
            self.builder.end_for()
4✔
2619

2620
        return tmp_name
4✔
2621

2622
    def _handle_numpy_matmul_op(self, left_node, right_node):
        """Forward both operand nodes to ``self._handle_matmul_helper``."""
        result = self._handle_matmul_helper(left_node, right_node)
        return result
4✔
2624

2625
    def _handle_numpy_matmul(self, node, func_name):
4✔
2626
        if len(node.args) != 2:
4✔
2627
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
2628
        return self._handle_matmul_helper(node.args[0], node.args[1])
4✔
2629

2630
    def _handle_numpy_outer(self, node, func_name):
4✔
2631
        if len(node.args) != 2:
4✔
2632
            raise NotImplementedError("outer requires 2 arguments")
×
2633

2634
        arg0 = node.args[0]
4✔
2635
        arg1 = node.args[1]
4✔
2636

2637
        if not self.la_handler:
4✔
2638
            raise RuntimeError("LinearAlgebraHandler not initialized")
×
2639

2640
        res_a = self.la_handler.parse_arg(arg0)
4✔
2641
        res_b = self.la_handler.parse_arg(arg1)
4✔
2642

2643
        # Resolve standard names if parse_arg failed (likely complex expression)
2644
        if not res_a[0]:
4✔
2645
            left_name = self.visit(arg0)
×
2646
            arg0 = ast.Name(id=left_name)
×
2647
            res_a = self.la_handler.parse_arg(arg0)
×
2648

2649
        if not res_b[0]:
4✔
2650
            right_name = self.visit(arg1)
×
2651
            arg1 = ast.Name(id=right_name)
×
2652
            res_b = self.la_handler.parse_arg(arg1)
×
2653

2654
        name_a, subset_a, shape_a, indices_a = res_a
4✔
2655
        name_b, subset_b, shape_b, indices_b = res_b
4✔
2656

2657
        if not name_a or not name_b:
4✔
2658
            raise NotImplementedError("Could not resolve outer operands")
×
2659

2660
        def get_flattened_size_expr(name, indices, shapes):
4✔
2661
            # Simplified: if slice, we use parse_arg's returned `shapes` (which are dim sizes of the slice)
2662
            # And multiply them.
2663
            size_expr = "1"
4✔
2664
            for s in shapes:
4✔
2665
                if size_expr == "1":
4✔
2666
                    size_expr = str(s)
4✔
2667
                else:
2668
                    size_expr = f"({size_expr} * {str(s)})"
×
2669
            return size_expr
4✔
2670

2671
        m_expr = get_flattened_size_expr(name_a, indices_a, shape_a)
4✔
2672
        n_expr = get_flattened_size_expr(name_b, indices_b, shape_b)
4✔
2673

2674
        # Create temporary container
2675
        # Since outer usually promotes types or uses standard types, we default to double for now.
2676
        dtype = Scalar(PrimitiveType.Double)
4✔
2677

2678
        # Use helper to create array temp which handles symbol table and array info
2679
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)
4✔
2680

2681
        new_call_node = ast.Call(
4✔
2682
            func=node.func, args=[arg0, arg1], keywords=node.keywords
2683
        )
2684

2685
        self.la_handler.handle_outer(tmp_name, new_call_node)
4✔
2686

2687
        return tmp_name
4✔
2688

2689
    def _handle_ufunc_outer(self, node, ufunc_name):
4✔
2690
        """Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.
2691

2692
        These compute the outer operation for the given ufunc:
2693
        - np.add.outer(a, b) -> a[:, np.newaxis] + b (outer sum)
2694
        - np.subtract.outer(a, b) -> a[:, np.newaxis] - b (outer difference)
2695
        - np.multiply.outer(a, b) -> a[:, np.newaxis] * b (same as np.outer)
2696
        """
2697
        if len(node.args) != 2:
4✔
2698
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")
×
2699

2700
        # For np.multiply.outer, use the existing GEMM-based outer handler
2701
        if ufunc_name == "multiply":
4✔
2702
            return self._handle_numpy_outer(node, "outer")
4✔
2703

2704
        # Map ufunc names to operation names and tasklet opcodes
2705
        op_map = {
4✔
2706
            "add": ("add", TaskletCode.fp_add, TaskletCode.int_add),
2707
            "subtract": ("sub", TaskletCode.fp_sub, TaskletCode.int_sub),
2708
            "divide": ("div", TaskletCode.fp_div, TaskletCode.int_sdiv),
2709
            "minimum": ("min", CMathFunction.fmin, TaskletCode.int_smin),
2710
            "maximum": ("max", CMathFunction.fmax, TaskletCode.int_smax),
2711
        }
2712

2713
        if ufunc_name not in op_map:
4✔
2714
            raise NotImplementedError(f"{ufunc_name}.outer not supported")
×
2715

2716
        op_name, fp_opcode, int_opcode = op_map[ufunc_name]
4✔
2717

2718
        # Use la_handler.parse_arg to properly handle sliced arrays
2719
        if not self.la_handler:
4✔
2720
            raise RuntimeError("LinearAlgebraHandler not initialized")
×
2721

2722
        arg0 = node.args[0]
4✔
2723
        arg1 = node.args[1]
4✔
2724

2725
        res_a = self.la_handler.parse_arg(arg0)
4✔
2726
        res_b = self.la_handler.parse_arg(arg1)
4✔
2727

2728
        # If parse_arg fails for complex expressions, try visiting and re-parsing
2729
        if not res_a[0]:
4✔
2730
            left_name = self.visit(arg0)
×
2731
            arg0 = ast.Name(id=left_name)
×
2732
            res_a = self.la_handler.parse_arg(arg0)
×
2733

2734
        if not res_b[0]:
4✔
2735
            right_name = self.visit(arg1)
×
2736
            arg1 = ast.Name(id=right_name)
×
2737
            res_b = self.la_handler.parse_arg(arg1)
×
2738

2739
        name_a, subset_a, shape_a, indices_a = res_a
4✔
2740
        name_b, subset_b, shape_b, indices_b = res_b
4✔
2741

2742
        if not name_a or not name_b:
4✔
2743
            raise NotImplementedError("Could not resolve ufunc outer operands")
×
2744

2745
        # Compute flattened sizes - outer treats inputs as 1D
2746
        def get_flattened_size_expr(shapes):
4✔
2747
            if not shapes:
4✔
2748
                return "1"
×
2749
            size_expr = str(shapes[0])
4✔
2750
            for s in shapes[1:]:
4✔
2751
                size_expr = f"({size_expr} * {str(s)})"
×
2752
            return size_expr
4✔
2753

2754
        m_expr = get_flattened_size_expr(shape_a)
4✔
2755
        n_expr = get_flattened_size_expr(shape_b)
4✔
2756

2757
        # Determine output dtype - infer from inputs or default to double
2758
        dtype_left = self._get_dtype(name_a)
4✔
2759
        dtype_right = self._get_dtype(name_b)
4✔
2760
        dtype = self._promote_dtypes(dtype_left, dtype_right)
4✔
2761

2762
        # Determine if we're working with integers
2763
        is_int = dtype.primitive_type in [
4✔
2764
            PrimitiveType.Int64,
2765
            PrimitiveType.Int32,
2766
            PrimitiveType.Int8,
2767
            PrimitiveType.Int16,
2768
            PrimitiveType.UInt64,
2769
            PrimitiveType.UInt32,
2770
            PrimitiveType.UInt8,
2771
            PrimitiveType.UInt16,
2772
        ]
2773

2774
        # Create output array with shape (M, N)
2775
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)
4✔
2776

2777
        # Generate unique loop variable names
2778
        i_var = self._get_temp_name("_outer_i_")
4✔
2779
        j_var = self._get_temp_name("_outer_j_")
4✔
2780

2781
        # Ensure loop variables exist
2782
        if not self.builder.exists(i_var):
4✔
2783
            self.builder.add_container(i_var, Scalar(PrimitiveType.Int64), False)
4✔
2784
            self.symbol_table[i_var] = Scalar(PrimitiveType.Int64)
4✔
2785
        if not self.builder.exists(j_var):
4✔
2786
            self.builder.add_container(j_var, Scalar(PrimitiveType.Int64), False)
4✔
2787
            self.symbol_table[j_var] = Scalar(PrimitiveType.Int64)
4✔
2788

2789
        # Helper function to compute the linear index for a sliced array access
2790
        def compute_linear_index(name, subset, indices, loop_var):
            """
            Build the flat (row-major) index expression for one element of a
            sliced array operand.

            Examples for an array A with shape (N, M):
            - A[:, k] -> (0 + loop_var) * M + k
            - A[k, :] -> k * M + (0 + loop_var)
            - A (no subscript, ``indices`` empty) -> loop_var

            ``indices`` holds the AST node of every subscript position; a
            position that is an ``ast.Slice`` is driven by ``loop_var``, any
            other position is treated as a fixed index. ``subset`` supplies
            the start expression of each position ("0" when missing).
            Shape information is read from ``self.array_info[name]``.
            """
            # No subscript at all: the operand is indexed by the loop variable.
            if not indices:
                return loop_var

            meta = self.array_info.get(name, {})
            dims = meta.get("shapes", [])
            rank = meta.get("ndim", len(dims))
            if rank == 0:
                return loop_var

            # Row-major strides as expression strings, collected from the
            # innermost dimension outwards and reversed afterwards.
            strides = []
            running = "1"
            for axis in reversed(range(rank)):
                strides.append(running)
                if axis == 0:
                    continue
                # Fall back to a symbolic extent name when the shape list is
                # shorter than the declared rank.
                extent = dims[axis] if axis < len(dims) else f"_{name}_shape_{axis}"
                if running == "1":
                    running = str(extent)
                else:
                    running = f"({running} * {extent})"
            strides.reverse()

            # One additive term per subscript position.
            pieces = []
            for pos, idx_node in enumerate(indices):
                stride = strides[pos] if pos < len(strides) else "1"
                base = subset[pos] if pos < len(subset) else "0"

                sliced = isinstance(idx_node, ast.Slice)
                # A sliced position advances with the loop variable; a fixed
                # position contributes only its start expression.
                offset = f"({base} + {loop_var})" if sliced else base

                if stride == "1":
                    pieces.append(offset)
                else:
                    pieces.append(f"({offset if sliced else base} * {stride})")

            if not pieces:
                return loop_var

            # Left-nested sum of all terms.
            combined, *remaining = pieces
            for piece in remaining:
                combined = f"({combined} + {piece})"
            return combined
4✔
2858

2859
        # Create nested for loops: for i in range(M): for j in range(N): C[i,j] = A[i] op B[j]
2860
        self.builder.begin_for(i_var, "0", m_expr, "1")
4✔
2861
        self.builder.begin_for(j_var, "0", n_expr, "1")
4✔
2862

2863
        # Create the assignment block: C[i, j] = A[i] op B[j]
2864
        block = self.builder.add_block()
4✔
2865

2866
        # Add access nodes
2867
        t_a = self.builder.add_access(block, name_a)
4✔
2868
        t_b = self.builder.add_access(block, name_b)
4✔
2869
        t_c = self.builder.add_access(block, tmp_name)
4✔
2870

2871
        # Determine tasklet type based on operation
2872
        if ufunc_name in ["minimum", "maximum"]:
4✔
2873
            # Use intrinsic for min/max
2874
            if is_int:
4✔
2875
                t_task = self.builder.add_tasklet(
4✔
2876
                    block, int_opcode, ["_in1", "_in2"], ["_out"]
2877
                )
2878
            else:
2879
                t_task = self.builder.add_cmath(block, fp_opcode)
4✔
2880
        else:
2881
            # Use regular tasklet for arithmetic ops
2882
            tasklet_code = int_opcode if is_int else fp_opcode
4✔
2883
            t_task = self.builder.add_tasklet(
4✔
2884
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
2885
            )
2886

2887
        # Compute the linear index for A[i]
2888
        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)
4✔
2889

2890
        # Compute the linear index for B[j]
2891
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)
4✔
2892

2893
        # Connect A[i + offset_a] -> tasklet
2894
        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)
4✔
2895

2896
        # Connect B[j + offset_b] -> tasklet
2897
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)
4✔
2898

2899
        # Connect tasklet -> C[i * N + j] (linear index for 2D output)
2900
        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
4✔
2901
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)
4✔
2902

2903
        self.builder.end_for()  # end j loop
4✔
2904
        self.builder.end_for()  # end i loop
4✔
2905

2906
        return tmp_name
4✔
2907

2908
    def _op_symbol(self, op_name):
4✔
2909
        """Convert operation name to symbol."""
2910
        symbols = {
×
2911
            "add": "+",
2912
            "sub": "-",
2913
            "mul": "*",
2914
            "div": "/",
2915
            "min": "min",  # Will need special handling
2916
            "max": "max",  # Will need special handling
2917
        }
2918
        return symbols.get(op_name, op_name)
×
2919

2920
    def _handle_matmul_helper(self, left_node, right_node):
4✔
2921
        if not self.la_handler:
4✔
2922
            raise RuntimeError("LinearAlgebraHandler not initialized")
×
2923

2924
        res_a = self.la_handler.parse_arg(left_node)
4✔
2925
        res_b = self.la_handler.parse_arg(right_node)
4✔
2926

2927
        if not res_a[0]:
4✔
2928
            left_name = self.visit(left_node)
×
2929
            left_node = ast.Name(id=left_name)
×
2930
            res_a = self.la_handler.parse_arg(left_node)
×
2931

2932
        if not res_b[0]:
4✔
2933
            right_name = self.visit(right_node)
4✔
2934
            right_node = ast.Name(id=right_name)
4✔
2935
            res_b = self.la_handler.parse_arg(right_node)
4✔
2936

2937
        name_a, subset_a, shape_a, indices_a = res_a
4✔
2938
        name_b, subset_b, shape_b, indices_b = res_b
4✔
2939

2940
        if not name_a or not name_b:
4✔
2941
            raise NotImplementedError("Could not resolve matmul operands")
×
2942

2943
        real_shape_a = shape_a
4✔
2944
        real_shape_b = shape_b
4✔
2945

2946
        ndim_a = len(real_shape_a)
4✔
2947
        ndim_b = len(real_shape_b)
4✔
2948

2949
        output_shape = []
4✔
2950
        is_scalar = False
4✔
2951

2952
        if ndim_a == 1 and ndim_b == 1:
4✔
2953
            is_scalar = True
4✔
2954
            output_shape = []
4✔
2955
        elif ndim_a == 2 and ndim_b == 2:
4✔
2956
            output_shape = [real_shape_a[0], real_shape_b[1]]
4✔
2957
        elif ndim_a == 2 and ndim_b == 1:
4✔
2958
            output_shape = [real_shape_a[0]]
4✔
2959
        elif ndim_a == 1 and ndim_b == 2:
4✔
2960
            output_shape = [real_shape_b[1]]
×
2961
        elif ndim_a > 2 or ndim_b > 2:
4✔
2962
            if ndim_a == ndim_b:
4✔
2963
                output_shape = list(real_shape_a[:-2]) + [
4✔
2964
                    real_shape_a[-2],
2965
                    real_shape_b[-1],
2966
                ]
2967
            else:
2968
                raise NotImplementedError(
×
2969
                    "Broadcasting with different ranks not fully supported yet"
2970
                )
2971
        else:
2972
            raise NotImplementedError(
×
2973
                f"Matmul with ranks {ndim_a} and {ndim_b} not supported"
2974
            )
2975

2976
        dtype = Scalar(PrimitiveType.Double)
4✔
2977

2978
        if is_scalar:
4✔
2979
            tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
2980
            self.builder.add_container(tmp_name, dtype, False)
4✔
2981
            self.symbol_table[tmp_name] = dtype
4✔
2982
        else:
2983
            tmp_name = self._create_array_temp(output_shape, dtype)
4✔
2984

2985
        if ndim_a > 2 or ndim_b > 2:
4✔
2986
            # Generate loops for broadcasting
2987
            batch_dims = ndim_a - 2
4✔
2988
            loop_vars = []
4✔
2989

2990
            for i in range(batch_dims):
4✔
2991
                loop_var = f"_i{self._get_unique_id()}"
4✔
2992
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
2993
                loop_vars.append(loop_var)
4✔
2994
                dim_size = real_shape_a[i]
4✔
2995
                self.builder.begin_for(loop_var, "0", str(dim_size), "1")
4✔
2996

2997
            def make_slice(name, indices):
4✔
2998
                elts = []
4✔
2999
                for idx in indices:
4✔
3000
                    if idx == ":":
4✔
3001
                        elts.append(ast.Slice())
4✔
3002
                    else:
3003
                        elts.append(ast.Name(id=idx))
4✔
3004

3005
                return ast.Subscript(
4✔
3006
                    value=ast.Name(id=name), slice=ast.Tuple(elts=elts), ctx=ast.Load()
3007
                )
3008

3009
            indices = loop_vars + [":", ":"]
4✔
3010
            slice_a = make_slice(name_a, indices)
4✔
3011
            slice_b = make_slice(name_b, indices)
4✔
3012
            slice_c = make_slice(tmp_name, indices)
4✔
3013

3014
            self.la_handler.handle_gemm(
4✔
3015
                slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
3016
            )
3017

3018
            for _ in range(batch_dims):
4✔
3019
                self.builder.end_for()
4✔
3020
        else:
3021
            if is_scalar:
4✔
3022
                self.la_handler.handle_dot(
4✔
3023
                    tmp_name,
3024
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
3025
                )
3026
            else:
3027
                self.la_handler.handle_gemm(
4✔
3028
                    tmp_name,
3029
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
3030
                )
3031

3032
        return tmp_name
4✔
3033

3034
    def _handle_numpy_unary_op(self, node, func_name):
4✔
3035
        args = [self.visit(arg) for arg in node.args]
4✔
3036
        if len(args) != 1:
4✔
3037
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
3038

3039
        op_name = func_name
4✔
3040
        if op_name == "absolute":
4✔
3041
            op_name = "abs"
×
3042

3043
        return self._handle_array_unary_op(op_name, args[0])
4✔
3044

3045
    def _handle_numpy_reduce(self, node, func_name):
4✔
3046
        args = node.args
4✔
3047
        keywords = {kw.arg: kw.value for kw in node.keywords}
4✔
3048

3049
        array_node = args[0]
4✔
3050
        array_name = self.visit(array_node)
4✔
3051

3052
        if array_name not in self.array_info:
4✔
3053
            raise ValueError(f"Reduction input must be an array, got {array_name}")
×
3054

3055
        input_shape = self.array_info[array_name]["shapes"]
4✔
3056
        ndim = len(input_shape)
4✔
3057

3058
        axis = None
4✔
3059
        if len(args) > 1:
4✔
3060
            axis = args[1]
×
3061
        elif "axis" in keywords:
4✔
3062
            axis = keywords["axis"]
4✔
3063

3064
        keepdims = False
4✔
3065
        if "keepdims" in keywords:
4✔
3066
            keepdims_node = keywords["keepdims"]
4✔
3067
            if isinstance(keepdims_node, ast.Constant):
4✔
3068
                keepdims = bool(keepdims_node.value)
4✔
3069

3070
        axes = []
4✔
3071
        if axis is None:
4✔
3072
            axes = list(range(ndim))
4✔
3073
        elif isinstance(axis, ast.Constant):  # Single axis
4✔
3074
            val = axis.value
4✔
3075
            if val < 0:
4✔
3076
                val += ndim
×
3077
            axes = [val]
4✔
3078
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
3079
            for elt in axis.elts:
×
3080
                if isinstance(elt, ast.Constant):
×
3081
                    val = elt.value
×
3082
                    if val < 0:
×
3083
                        val += ndim
×
3084
                    axes.append(val)
×
3085
        elif (
×
3086
            isinstance(axis, ast.UnaryOp)
3087
            and isinstance(axis.op, ast.USub)
3088
            and isinstance(axis.operand, ast.Constant)
3089
        ):
3090
            val = -axis.operand.value
×
3091
            if val < 0:
×
3092
                val += ndim
×
3093
            axes = [val]
×
3094
        else:
3095
            # Try to evaluate simple expression
3096
            try:
×
3097
                val = int(self.visit(axis))
×
3098
                if val < 0:
×
3099
                    val += ndim
×
3100
                axes = [val]
×
3101
            except:
×
3102
                raise NotImplementedError("Dynamic axis not supported")
×
3103

3104
        # Calculate output shape
3105
        output_shape = []
4✔
3106
        for i in range(ndim):
4✔
3107
            if i in axes:
4✔
3108
                if keepdims:
4✔
3109
                    output_shape.append("1")
4✔
3110
            else:
3111
                output_shape.append(input_shape[i])
4✔
3112

3113
        dtype = self._get_dtype(array_name)
4✔
3114

3115
        if not output_shape:
4✔
3116
            tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
3117
            self.builder.add_container(tmp_name, dtype, False)
4✔
3118
            self.symbol_table[tmp_name] = dtype
4✔
3119
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
4✔
3120
        else:
3121
            tmp_name = self._create_array_temp(output_shape, dtype)
4✔
3122

3123
        self.builder.add_reduce_op(
4✔
3124
            func_name, array_name, tmp_name, input_shape, axes, keepdims
3125
        )
3126

3127
        return tmp_name
4✔
3128

3129
    def _handle_numpy_astype(self, node, array_name):
4✔
3130
        """Handle numpy array.astype(dtype) method calls."""
3131
        if len(node.args) < 1:
4✔
3132
            raise ValueError("astype requires at least one argument (dtype)")
×
3133

3134
        dtype_arg = node.args[0]
4✔
3135
        target_dtype = self._map_numpy_dtype(dtype_arg)
4✔
3136

3137
        # Get input array shape
3138
        if array_name not in self.array_info:
4✔
3139
            raise ValueError(f"Array {array_name} not found in array_info")
×
3140

3141
        input_shape = self.array_info[array_name]["shapes"]
4✔
3142

3143
        # Create output array with target dtype
3144
        tmp_name = self._create_array_temp(input_shape, target_dtype)
4✔
3145

3146
        # Add cast operation
3147
        self.builder.add_cast_op(
4✔
3148
            array_name, tmp_name, input_shape, target_dtype.primitive_type
3149
        )
3150

3151
        return tmp_name
4✔
3152

3153
    def _handle_scipy_softmax(self, node, func_name):
4✔
3154
        args = node.args
4✔
3155
        keywords = {kw.arg: kw.value for kw in node.keywords}
4✔
3156

3157
        array_node = args[0]
4✔
3158
        array_name = self.visit(array_node)
4✔
3159

3160
        if array_name not in self.array_info:
4✔
3161
            raise ValueError(f"Softmax input must be an array, got {array_name}")
×
3162

3163
        input_shape = self.array_info[array_name]["shapes"]
4✔
3164
        ndim = len(input_shape)
4✔
3165

3166
        axis = None
4✔
3167
        if len(args) > 1:
4✔
3168
            axis = args[1]
×
3169
        elif "axis" in keywords:
4✔
3170
            axis = keywords["axis"]
4✔
3171

3172
        axes = []
4✔
3173
        if axis is None:
4✔
3174
            axes = list(range(ndim))
4✔
3175
        elif isinstance(axis, ast.Constant):  # Single axis
4✔
3176
            val = axis.value
4✔
3177
            if val < 0:
4✔
3178
                val += ndim
×
3179
            axes = [val]
4✔
3180
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
3181
            for elt in axis.elts:
×
3182
                if isinstance(elt, ast.Constant):
×
3183
                    val = elt.value
×
3184
                    if val < 0:
×
3185
                        val += ndim
×
3186
                    axes.append(val)
×
3187
        elif (
×
3188
            isinstance(axis, ast.UnaryOp)
3189
            and isinstance(axis.op, ast.USub)
3190
            and isinstance(axis.operand, ast.Constant)
3191
        ):
3192
            val = -axis.operand.value
×
3193
            if val < 0:
×
3194
                val += ndim
×
3195
            axes = [val]
×
3196
        else:
3197
            # Try to evaluate simple expression
3198
            try:
×
3199
                val = int(self.visit(axis))
×
3200
                if val < 0:
×
3201
                    val += ndim
×
3202
                axes = [val]
×
3203
            except:
×
3204
                raise NotImplementedError("Dynamic axis not supported")
×
3205

3206
        # Create output array
3207
        # Assume double for now, or infer from input
3208
        dtype = Scalar(PrimitiveType.Double)  # TODO: infer
4✔
3209

3210
        tmp_name = self._create_array_temp(input_shape, dtype)
4✔
3211

3212
        self.builder.add_reduce_op(
4✔
3213
            func_name, array_name, tmp_name, input_shape, axes, False
3214
        )
3215

3216
        return tmp_name
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc