• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 28685979841

03 Jul 2026 10:26PM UTC coverage: 62.417% (+0.3%) from 62.147%
28685979841

Pull #832

github

web-flow
Merge 76677d737 into 3726be1d9
Pull Request #832: activates numpy tests

99 of 112 new or added lines in 2 files covered. (88.39%)

22 existing lines in 2 files now uncovered.

39691 of 63590 relevant lines covered (62.42%)

978.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.79
/python/docc/python/functions/numpy.py
1
import ast
4✔
2
from docc.sdfg import (
4✔
3
    Scalar,
4
    PrimitiveType,
5
    Pointer,
6
    DebugInfo,
7
    TaskletCode,
8
    CMathFunction,
9
    Tensor,
10
)
11
from docc.python.types import (
4✔
12
    element_type_from_ast_node,
13
    promote_element_types,
14
    numpy_promote_types,
15
)
16
from docc.python.ast_utils import get_debug_info
4✔
17
from docc.python.memory import ManagedMemoryHandler
4✔
18

19

20
class NumPyHandler:
4✔
21
    """
22
    Unified handler for NumPy operations including:
23
    - Array creation (empty, zeros, ones, eye, etc.)
24
    - Elementwise operations (add, subtract, multiply, etc.)
25
    - Linear algebra (matmul, dot, outer, gemm)
26
    - Array manipulation (transpose)
27
    - Reductions (sum, max, min, mean, std)
28
    """
29

30
    def __init__(self, expression_visitor):
        """Bind the handler to an expression visitor and build the dispatch table.

        Args:
            expression_visitor: Object this handler delegates to; it is kept
                as ``self._ev``.
        """
        self._ev = expression_visitor
        self._unique_counter = 0

        # (names, handler) pairs. Consecutive names sharing a handler are
        # grouped; the sequence is laid out so the resulting dict has the
        # same key order as a flat, hand-written table.
        dispatch = (
            (("empty",), self._handle_numpy_alloc),
            (("empty_like",), self._handle_numpy_empty_like),
            (("zeros",), self._handle_numpy_alloc),
            (("zeros_like",), self._handle_numpy_zeros_like),
            (("ones", "ndarray"), self._handle_numpy_alloc),
            (("eye",), self._handle_numpy_eye),
            (
                ("add", "subtract", "multiply", "divide", "power"),
                self._handle_numpy_binary_op,
            ),
            (
                ("exp", "abs", "absolute", "sqrt", "tanh"),
                self._handle_numpy_unary_op,
            ),
            (("sum", "max", "min", "mean", "std"), self._handle_numpy_reduce),
            (("matmul", "dot", "matvec"), self._handle_numpy_matmul),
            (("outer",), self._handle_numpy_outer),
            (("minimum", "maximum"), self._handle_numpy_binary_op),
            (("where",), self._handle_numpy_where),
            (("clip",), self._handle_numpy_clip),
            (("transpose",), self._handle_numpy_transpose),
            (("flip",), self._handle_numpy_flip),
            (("fliplr",), self._handle_numpy_fliplr),
            (("flipud",), self._handle_numpy_flipud),
            (("reshape",), self._handle_numpy_reshape),
            (("einsum",), self._handle_numpy_einsum),
        )

        handlers = {}
        for function_names, handler in dispatch:
            for function_name in function_names:
                handlers[function_name] = handler
        self.function_handlers = handlers
71

72
    # Expose parent properties for convenience
73
    @property
4✔
74
    def tensor_table(self):
4✔
75
        return self._ev.tensor_table
4✔
76

77
    @property
4✔
78
    def builder(self):
4✔
79
        return self._ev.builder
4✔
80

81
    @property
4✔
82
    def container_table(self):
4✔
83
        return self._ev.container_table
4✔
84

85
    @property
4✔
86
    def globals_dict(self):
4✔
87
        return self._ev.globals_dict
×
88

89
    @property
4✔
90
    def shapes_runtime_info(self):
4✔
91
        return self._ev.shapes_runtime_info
4✔
92

93
    @property
4✔
94
    def memory_handler(self):
4✔
95
        """Access the memory handler owned by the parser."""
96
        return self._ev.memory_handler
4✔
97

98
    def _get_unique_id(self):
4✔
99
        return self._ev._get_unique_id()
4✔
100

101
    def _add_read(self, block, expr_str, debug_info=None):
4✔
102
        return self._ev._add_read(block, expr_str, debug_info)
4✔
103

104
    def _is_int(self, operand):
4✔
105
        return self._ev._is_int(operand)
4✔
106

107
    def visit(self, node):
4✔
108
        return self._ev.visit(node)
4✔
109

110
    # ========== Linear Algebra Helper Methods (from LinearAlgebraHandler) ==========
111

112
    def parse_arg(self, node):
4✔
113
        """Parse an array argument, returning (name, start_indices, slice_shape, indices).
114

115
        Returns None for 0-d arrays since they are scalars, not valid array operands
116
        for linear algebra operations.
117
        """
118
        if isinstance(node, ast.Name):
4✔
119
            if node.id in self.tensor_table:
4✔
120
                shape = self.tensor_table[node.id].shape
4✔
121
                # Reject 0-d arrays (scalars) - not valid for linalg ops
122
                if len(shape) == 0:
4✔
123
                    return None, None, None, None
×
124
                return node.id, [], shape, []
4✔
125
        elif isinstance(node, ast.Subscript):
4✔
126
            if isinstance(node.value, ast.Name) and node.value.id in self.tensor_table:
4✔
127
                name = node.value.id
4✔
128
                indices = []
4✔
129
                if isinstance(node.slice, ast.Tuple):
4✔
130
                    indices = node.slice.elts
4✔
131
                else:
132
                    indices = [node.slice]
4✔
133

134
                start_indices = []
4✔
135
                slice_shape = []
4✔
136

137
                for i, idx in enumerate(indices):
4✔
138
                    if isinstance(idx, ast.Slice):
4✔
139
                        start = "0"
4✔
140
                        if idx.lower:
4✔
141
                            start = self._ev.visit(idx.lower)
4✔
142
                        start_indices.append(start)
4✔
143

144
                        shapes = self.tensor_table[name].shape
4✔
145
                        dim_size = (
4✔
146
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
147
                        )
148
                        stop = dim_size
4✔
149
                        if idx.upper:
4✔
150
                            stop = self._ev.visit(idx.upper)
4✔
151

152
                        size = f"({stop} - {start})"
4✔
153
                        slice_shape.append(size)
4✔
154
                    else:
155
                        if isinstance(idx, ast.Name) and idx.id in self.tensor_table:
4✔
156
                            # This is an array index (gather operation)
157
                            return None, None, None, None
4✔
158
                        val = self._ev.visit(idx)
4✔
159
                        start_indices.append(val)
4✔
160

161
                return name, start_indices, slice_shape, indices
4✔
162

163
        return None, None, None, None
4✔
164

165
    def flatten_subset(self, name, start_indices):
4✔
166
        """Convert multi-dimensional start indices to a flattened linear offset."""
167
        if not start_indices:
4✔
168
            return []
4✔
169
        info = self.tensor_table[name]
4✔
170
        shapes = info.shape
4✔
171
        ndim = len(info.shape)
4✔
172

173
        if len(start_indices) != ndim:
4✔
174
            return start_indices
4✔
175

176
        strides = []
4✔
177
        current_stride = "1"
4✔
178
        strides.append(current_stride)
4✔
179
        for i in range(ndim - 1, 0, -1):
4✔
180
            dim_size = shapes[i]
4✔
181
            if current_stride == "1":
4✔
182
                current_stride = str(dim_size)
4✔
183
            else:
184
                current_stride = f"({current_stride} * {dim_size})"
4✔
185
            strides.append(current_stride)
4✔
186
        strides = list(reversed(strides))
4✔
187

188
        offset = "0"
4✔
189
        for i in range(ndim):
4✔
190
            idx = start_indices[i]
4✔
191
            stride = strides[i]
4✔
192
            term = f"({idx} * {stride})" if stride != "1" else idx
4✔
193
            if offset == "0":
4✔
194
                offset = term
4✔
195
            else:
196
                offset = f"({offset} + {term})"
4✔
197

198
        return [offset]
4✔
199

200
    def is_gemm(self, node):
4✔
201
        """Check if a node represents a GEMM operation (matrix multiplication)."""
202
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
203
            return True
4✔
204
        if isinstance(node, ast.Call):
4✔
205
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
4✔
206
                return True
×
207
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
4✔
208
                return True
×
209
            if isinstance(node.func, ast.Attribute) and node.func.attr == "matmul":
4✔
210
                return True
×
211
            if isinstance(node.func, ast.Name) and node.func.id == "matmul":
4✔
212
                return True
×
213
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
214
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
215
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
4✔
216
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
217
        return False
4✔
218

219
    def _is_stride_1(self, name, indices):
4✔
220
        """Check if the sliced dimension has stride 1 (contiguous access)."""
221
        if name not in self.tensor_table:
4✔
222
            return True
×
223
        info = self.tensor_table[name]
4✔
224
        ndim = len(info.shape)
4✔
225

226
        if not indices:
4✔
227
            return True
4✔
228

229
        sliced_dim = -1
×
230
        for i, idx in enumerate(indices):
×
231
            if isinstance(idx, ast.Slice):
×
232
                sliced_dim = i
×
233
                break
×
234

235
        if sliced_dim == -1:
×
236
            if len(indices) < ndim:
×
237
                sliced_dim = ndim - 1
×
238
            else:
239
                return True
×
240

241
        return sliced_dim == ndim - 1
×
242

243
    def _is_target(self, node, target_name):
4✔
244
        """Check if node refers to the target."""
245
        if isinstance(target_name, ast.AST):
4✔
246
            return self._ev.visit(node) == self._ev.visit(target_name)
4✔
247

248
        if isinstance(node, ast.Name) and node.id == target_name:
4✔
249
            return True
×
250
        if isinstance(node, ast.Subscript):
4✔
251
            if isinstance(node.value, ast.Name) and node.value.id == target_name:
4✔
252
                return True
4✔
253
        return False
4✔
254

255
    def _is_dot_call(self, node):
4✔
256
        """Check if node is a dot product call."""
257
        if isinstance(node, ast.Call):
4✔
258
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
×
259
                return True
×
260
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
×
261
                return True
×
262
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
263
            return True
4✔
264
        return False
4✔
265

266
    def handle_gemm(self, target, value_node):
4✔
267
        """Handle GEMM (General Matrix Multiply) operations: C = alpha * A @ B + beta * C."""
268
        target_name = None
4✔
269
        target_subset = []
4✔
270

271
        if isinstance(target, str):
4✔
272
            target_name = target
4✔
273
        elif isinstance(target, ast.Name):
4✔
274
            target_name = target.id
4✔
275
        elif isinstance(target, ast.Subscript):
4✔
276
            if isinstance(target.value, ast.Name):
4✔
277
                res = self.parse_arg(target)
4✔
278
                if res[0]:
4✔
279
                    target_name = res[0]
4✔
280
                    target_subset = self.flatten_subset(target_name, res[1])
4✔
281
                else:
282
                    target_name = target.value.id
×
283

284
        if not target_name or target_name not in self.tensor_table:
4✔
285
            return False
4✔
286

287
        alpha = "1.0"
4✔
288
        beta = "0.0"
4✔
289
        A = None
4✔
290
        B = None
4✔
291

292
        def extract_factor(node):
4✔
293
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
4✔
294
                if self.is_gemm(node.left):
×
295
                    return node.left, self._ev.visit(node.right)
×
296
                if self.is_gemm(node.right):
×
297
                    return node.right, self._ev.visit(node.left)
×
298

299
                res = self.parse_arg(node.left)
×
300
                if res[0]:
×
301
                    return node.left, self._ev.visit(node.right)
×
302
                res = self.parse_arg(node.right)
×
303
                if res[0]:
×
304
                    return node.right, self._ev.visit(node.left)
×
305
            return node, "1.0"
4✔
306

307
        def parse_term(node):
4✔
308
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
309
                l, l_f = extract_factor(node.left)
4✔
310
                r, r_f = extract_factor(node.right)
4✔
311
                f = "1.0"
4✔
312
                if l_f != "1.0":
4✔
313
                    f = l_f
×
314
                if r_f != "1.0":
4✔
315
                    if f == "1.0":
×
316
                        f = r_f
×
317
                    else:
318
                        f = f"({f} * {r_f})"
×
319
                return l, r, f
4✔
320

321
            if isinstance(node, ast.Call):
×
322
                is_gemm_call = False
×
323
                if isinstance(node.func, ast.Attribute) and node.func.attr in [
×
324
                    "dot",
325
                    "matmul",
326
                ]:
327
                    is_gemm_call = True
×
328
                if isinstance(node.func, ast.Name) and node.func.id in [
×
329
                    "dot",
330
                    "matmul",
331
                ]:
332
                    is_gemm_call = True
×
333

334
                if is_gemm_call and len(node.args) == 2:
×
335
                    return node.args[0], node.args[1], "1.0"
×
336

337
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
×
338
                l, r, a = parse_term(node.left)
×
339
                if l:
×
340
                    return l, r, self._ev.visit(node.right)
×
341
                l, r, a = parse_term(node.right)
×
342
                if l:
×
343
                    return l, r, self._ev.visit(node.left)
×
344

345
            return None, None, None
×
346

347
        if isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
4✔
348
            l, r, a = parse_term(value_node.left)
×
349
            if l:
×
350
                A = l
×
351
                B = r
×
352
                alpha = a
×
353
                if isinstance(value_node.right, ast.BinOp) and isinstance(
×
354
                    value_node.right.op, ast.Mult
355
                ):
356
                    if self._is_target(value_node.right.left, target_name):
×
357
                        beta = self._ev.visit(value_node.right.right)
×
358
                    elif self._is_target(value_node.right.right, target_name):
×
359
                        beta = self._ev.visit(value_node.right.left)
×
360
                elif self._is_target(value_node.right, target_name):
×
361
                    beta = "1.0"
×
362
            else:
363
                l, r, a = parse_term(value_node.right)
×
364
                if l:
×
365
                    A = l
×
366
                    B = r
×
367
                    alpha = a
×
368
                    if isinstance(value_node.left, ast.BinOp) and isinstance(
×
369
                        value_node.left.op, ast.Mult
370
                    ):
371
                        if self._is_target(value_node.left.left, target_name):
×
372
                            beta = self._ev.visit(value_node.left.right)
×
373
                        elif self._is_target(value_node.left.right, target_name):
×
374
                            beta = self._ev.visit(value_node.left.left)
×
375
                    elif self._is_target(value_node.left, target_name):
×
376
                        beta = "1.0"
×
377
        else:
378
            l, r, a = parse_term(value_node)
4✔
379
            if l:
4✔
380
                A = l
4✔
381
                B = r
4✔
382
                alpha = a
4✔
383

384
        if A is None or B is None:
4✔
385
            return False
×
386

387
        def get_name_and_trans(node):
4✔
388
            if isinstance(node, ast.Attribute) and node.attr == "T":
4✔
389
                return node.value, True
×
390
            return node, False
4✔
391

392
        A_node, trans_a = get_name_and_trans(A)
4✔
393
        B_node, trans_b = get_name_and_trans(B)
4✔
394

395
        if self.is_gemm(A_node):
4✔
396
            tmp_name = self._ev.visit(A_node)
×
397
            A_node = ast.Name(id=tmp_name)
×
398

399
        if self.is_gemm(B_node):
4✔
400
            tmp_name = self._ev.visit(B_node)
×
401
            B_node = ast.Name(id=tmp_name)
×
402

403
        res_a = self.parse_arg(A_node)
4✔
404
        res_b = self.parse_arg(B_node)
4✔
405

406
        if not res_a[0] or not res_b[0]:
4✔
407
            return False
4✔
408

409
        A_name, subset_a, shape_a, indices_a = res_a
4✔
410
        B_name, subset_b, shape_b, indices_b = res_b
4✔
411

412
        flat_subset_a = self.flatten_subset(A_name, subset_a)
4✔
413
        flat_subset_b = self.flatten_subset(B_name, subset_b)
4✔
414

415
        def get_ndim(name):
4✔
416
            if name not in self.tensor_table:
4✔
417
                return 1
×
418
            return len(self.tensor_table[name].shape)
4✔
419

420
        if len(shape_a) == 2:
4✔
421
            if not trans_a:
4✔
422
                m = shape_a[0]
4✔
423
                k = shape_a[1]
4✔
424
            else:
425
                m = shape_a[1]
×
426
                k = shape_a[0]
×
427
        else:
428
            m = "1"
×
429
            k = shape_a[0]
×
430
            if self._is_stride_1(A_name, indices_a):
×
431
                if get_ndim(A_name) == 1:
×
432
                    trans_a = True
×
433
                else:
434
                    trans_a = False
×
435
            else:
436
                trans_a = True
×
437

438
        if len(shape_b) == 2:
4✔
439
            if not trans_b:
4✔
440
                n = shape_b[1]
4✔
441
            else:
442
                n = shape_b[0]
×
443
        else:
444
            n = "1"
4✔
445
            if self._is_stride_1(B_name, indices_b):
4✔
446
                if get_ndim(B_name) == 1:
4✔
447
                    trans_b = False
4✔
448
                else:
449
                    trans_b = True
×
450
            else:
451
                trans_b = False
×
452

453
        def get_ld(name):
4✔
454
            if name not in self.tensor_table:
4✔
455
                return ""
×
456
            shapes = self.tensor_table[name].shape
4✔
457
            if len(shapes) >= 2:
4✔
458
                return str(shapes[1])
4✔
459
            return "1"
4✔
460

461
        lda = get_ld(A_name)
4✔
462
        ldb = get_ld(B_name)
4✔
463

464
        ldc = ""
4✔
465
        if target_name:
4✔
466
            if get_ndim(target_name) == 1 and m == "1":
4✔
467
                ldc = n
×
468
            else:
469
                ldc = get_ld(target_name)
4✔
470

471
        self.builder.add_gemm(
4✔
472
            A_name,
473
            B_name,
474
            target_name,
475
            alpha,
476
            beta,
477
            m,
478
            n,
479
            k,
480
            trans_a,
481
            trans_b,
482
            flat_subset_a,
483
            flat_subset_b,
484
            target_subset,
485
            lda,
486
            ldb,
487
            ldc,
488
        )
489
        return True
4✔
490

491
    def handle_dot(self, target, value_node):
4✔
492
        """Handle dot product operations for 1D vectors."""
493
        dot_node = None
4✔
494
        is_accumulate = False
4✔
495

496
        if self._is_dot_call(value_node):
4✔
497
            dot_node = value_node
4✔
498
        elif isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
4✔
499
            if self._is_dot_call(value_node.left):
4✔
500
                dot_node = value_node.left
4✔
501
                if self._is_target(value_node.right, target):
4✔
502
                    is_accumulate = True
×
503
            elif self._is_dot_call(value_node.right):
×
504
                dot_node = value_node.right
×
505
                if self._is_target(value_node.left, target):
×
506
                    is_accumulate = True
×
507

508
        if not dot_node:
4✔
509
            return False
×
510

511
        arg0 = None
4✔
512
        arg1 = None
4✔
513

514
        if isinstance(dot_node, ast.Call):
4✔
515
            args = dot_node.args
×
516
            if len(args) != 2:
×
517
                return False
×
518
            arg0 = args[0]
×
519
            arg1 = args[1]
×
520
        elif isinstance(dot_node, ast.BinOp) and isinstance(dot_node.op, ast.MatMult):
4✔
521
            arg0 = dot_node.left
4✔
522
            arg1 = dot_node.right
4✔
523

524
        res_a = self.parse_arg(arg0)
4✔
525
        res_b = self.parse_arg(arg1)
4✔
526

527
        if not res_a[0] or not res_b[0]:
4✔
528
            return False
4✔
529

530
        name_a, subset_a, shape_a, indices_a = res_a
4✔
531
        name_b, subset_b, shape_b, indices_b = res_b
4✔
532

533
        if len(shape_a) != 1 or len(shape_b) != 1:
4✔
534
            return False
4✔
535

536
        n = shape_a[0]
4✔
537

538
        def get_stride(name, indices):
4✔
539
            if not indices:
4✔
540
                return "1"
4✔
541
            info = self.tensor_table[name]
4✔
542
            shapes = info.shape
4✔
543
            ndim = len(info.shape)
4✔
544

545
            sliced_dim = -1
4✔
546
            for i, idx in enumerate(indices):
4✔
547
                if isinstance(idx, ast.Slice):
4✔
548
                    sliced_dim = i
4✔
549
                    break
4✔
550

551
            if sliced_dim == -1:
4✔
552
                return "1"
×
553

554
            stride = "1"
4✔
555
            for i in range(sliced_dim + 1, ndim):
4✔
NEW
556
                dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
×
NEW
557
                if stride == "1":
×
NEW
558
                    stride = str(dim_size)
×
559
                else:
560
                    stride = f"({stride} * {dim_size})"
×
561
            return stride
4✔
562

563
        incx = get_stride(name_a, indices_a)
4✔
564
        incy = get_stride(name_b, indices_b)
4✔
565

566
        flat_subset_a = self.flatten_subset(name_a, subset_a)
4✔
567
        flat_subset_b = self.flatten_subset(name_b, subset_b)
4✔
568

569
        tmp_res = f"_dot_res_{self._get_unique_id()}"
4✔
570
        self.builder.add_container(tmp_res, Scalar(PrimitiveType.Double), False)
4✔
571
        block = self.builder.add_block()
4✔
572
        constant = self.builder.add_constant(block, "0.0", Scalar(PrimitiveType.Double))
4✔
573
        tasklet = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
4✔
574
        self.builder.add_memlet(
4✔
575
            block, constant, "", tasklet, "_in", "", Scalar(PrimitiveType.Double)
576
        )
577
        access = self.builder.add_access(block, tmp_res)
4✔
578
        self.builder.add_memlet(
4✔
579
            block, tasklet, "_out", access, "", "", Scalar(PrimitiveType.Double)
580
        )
581

582
        self.container_table[tmp_res] = Scalar(PrimitiveType.Double)
4✔
583

584
        self.builder.add_dot(
4✔
585
            name_a, name_b, tmp_res, n, incx, incy, flat_subset_a, flat_subset_b
586
        )
587

588
        target_str = target if isinstance(target, str) else self._ev.visit(target)
4✔
589

590
        if not self.builder.exists(target_str):
4✔
591
            self.builder.add_container(target_str, Scalar(PrimitiveType.Double), False)
×
592
            self.container_table[target_str] = Scalar(PrimitiveType.Double)
×
593

594
        if is_accumulate:
4✔
595
            self.builder.add_assignment(target_str, f"{target_str} + {tmp_res}")
×
596
        else:
597
            self.builder.add_assignment(target_str, tmp_res)
4✔
598

599
        return True
4✔
600

601
    def is_outer(self, node):
4✔
602
        """Check if a node represents an outer *product* operation.
603

604
        Only ``np.outer(...)`` and ``np.multiply.outer(...)`` are genuine outer
605
        products that can be lowered to a GEMM. Other ufunc outers such as
606
        ``np.add.outer`` or ``np.subtract.outer`` are element-wise outer
607
        operations and must not take this (multiplication) path.
608
        """
609
        if isinstance(node, ast.Call):
4✔
610
            if isinstance(node.func, ast.Attribute) and node.func.attr == "outer":
4✔
611
                # np.<ufunc>.outer(...): func.value is the Attribute naming the
612
                # ufunc (e.g. "add", "multiply"). Only multiplication is a true
613
                # outer product.
614
                if isinstance(node.func.value, ast.Attribute):
4✔
615
                    return node.func.value.attr == "multiply"
4✔
616
                # np.outer(...): func.value is the module name (Name).
617
                return True
4✔
618
            if isinstance(node.func, ast.Name) and node.func.id == "outer":
4✔
NEW
619
                return True
×
620
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
621
            return self.is_outer(node.left) or self.is_outer(node.right)
4✔
622
        return False
4✔
623

624
    def handle_outer(self, target, value_node):
4✔
625
        """Handle outer product operations."""
626
        target_name = None
4✔
627
        target_subset = []
4✔
628

629
        if isinstance(target, str):
4✔
630
            target_name = target
4✔
631
        elif isinstance(target, ast.Name):
4✔
NEW
UNCOV
632
            target_name = target.id
×
633
        elif isinstance(target, ast.Subscript):
4✔
634
            res = self.parse_arg(target)
4✔
635
            if res[0]:
4✔
636
                target_name = res[0]
4✔
637
                target_subset = self.flatten_subset(target_name, res[1])
4✔
638
            else:
639
                if isinstance(target.value, ast.Name):
×
640
                    target_name = target.value.id
×
641

642
        if not target_name:
4✔
643
            return False
×
644

645
        outer_calls = []
4✔
646
        target_found = False
4✔
647
        terms = []
4✔
648

649
        def collect_terms(node):
4✔
650
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
651
                collect_terms(node.left)
4✔
652
                collect_terms(node.right)
4✔
653
            else:
654
                terms.append(node)
4✔
655

656
        collect_terms(value_node)
4✔
657

658
        for term in terms:
4✔
659
            if self._is_target(term, target_name):
4✔
660
                target_found = True
4✔
661
            elif self.is_outer(term):
4✔
662
                if len(term.args) != 2:
4✔
663
                    return False
×
664
                outer_calls.append(term)
4✔
665
            else:
666
                return False
×
667

668
        if not outer_calls:
4✔
669
            return False
×
670

671
        parsed_outers = []
4✔
672
        for outer_node in outer_calls:
4✔
673
            arg0 = outer_node.args[0]
4✔
674
            arg1 = outer_node.args[1]
4✔
675

676
            res_a = self.parse_arg(arg0)
4✔
677
            res_b = self.parse_arg(arg1)
4✔
678

679
            if not res_a[0] or not res_b[0]:
4✔
680
                return False
×
681

682
            parsed_outers.append((res_a, res_b))
4✔
683

684
        alpha = "1.0"
4✔
685
        beta = "1.0" if target_found else "0.0"
4✔
686

687
        def get_flattened_size(name, indices, shapes):
4✔
688
            size_expr = "1"
4✔
689
            for s in shapes:
4✔
690
                if size_expr == "1":
4✔
691
                    size_expr = str(s)
4✔
692
                else:
693
                    size_expr = f"({size_expr} * {str(s)})"
×
694
            return size_expr
4✔
695

696
        def get_ld_2d(name):
4✔
697
            if name in self.tensor_table:
4✔
698
                shapes = self.tensor_table[name].shape
4✔
699
                if len(shapes) >= 2:
4✔
700
                    return str(shapes[1])
4✔
UNCOV
701
            return "1"
×
702

703
        ldc = get_ld_2d(target_name)
4✔
704

705
        for res_a, res_b in parsed_outers:
4✔
706
            name_a, subset_a, shape_a, indices_a = res_a
4✔
707
            name_b, subset_b, shape_b, indices_b = res_b
4✔
708

709
            m = get_flattened_size(name_a, indices_a, shape_a)
4✔
710
            n = get_flattened_size(name_b, indices_b, shape_b)
4✔
711
            k = "1"
4✔
712

713
            trans_a = False
4✔
714
            trans_b = True
4✔
715

716
            flat_subset_a = self.flatten_subset(name_a, subset_a)
4✔
717
            flat_subset_b = self.flatten_subset(name_b, subset_b)
4✔
718

719
            lda = "1"
4✔
720
            ldb = "1"
4✔
721

722
            self.builder.add_gemm(
4✔
723
                name_a,
724
                name_b,
725
                target_name,
726
                alpha,
727
                beta,
728
                m,
729
                n,
730
                k,
731
                trans_a,
732
                trans_b,
733
                flat_subset_a,
734
                flat_subset_b,
735
                target_subset,
736
                lda,
737
                ldb,
738
                ldc,
739
            )
740
            beta = "1.0"
4✔
741

742
        return True
4✔
743

744
    # ========== Transpose Operations ==========
745

746
    def _parse_perm(self, node):
4✔
747
        """Parse a permutation list or tuple from an AST node."""
748
        if isinstance(node, (ast.List, ast.Tuple)):
4✔
749
            res = []
4✔
750
            for elt in node.elts:
4✔
751
                val = self._ev.visit(elt)
4✔
752
                res.append(int(val))
4✔
753
            return res
4✔
754
        return []
×
755

756
    def is_transpose(self, node):
4✔
757
        """Check if a node represents a transpose operation."""
758
        # Case 1: np.transpose(arr, ...)
759
        if isinstance(node, ast.Call):
4✔
760
            if isinstance(node.func, ast.Attribute) and node.func.attr == "transpose":
4✔
761
                return True
×
762
            if isinstance(node.func, ast.Name) and node.func.id == "transpose":
4✔
763
                return True
×
764

765
        # Case 2: arr.T
766
        if isinstance(node, ast.Attribute) and node.attr == "T":
4✔
767
            return True
4✔
768

769
        return False
4✔
770

771
    def handle_transpose(self, target, value_node):
        """Lower ``target = <transpose expression>`` to a zero-copy view.

        Recognised right-hand sides are ``a.T``, ``np.transpose(a[, axes])``,
        a bare ``transpose(a[, axes])`` and the method form
        ``a.transpose(...)`` (axes given as one list/tuple or as separate
        integer arguments). The target is registered as a pointer that
        aliases the input through a reference memlet; only the shape, strides
        and offset recorded in ``tensor_table`` describe the transposition.

        Args:
            target: ``ast.Name`` node or plain string naming the assignment
                target.
            value_node: AST node of the assigned expression.

        Returns:
            True if the assignment was lowered here. False if the expression
            is not a transpose, the target cannot be named, the call has no
            array argument, or the input is not a known tensor.
        """
        if not self.is_transpose(value_node):
            return False

        # Resolve the target first so an unsupported target (e.g. a
        # subscript) bails out instead of registering a container named "".
        if isinstance(target, ast.Name):
            target_name = target.id
        elif isinstance(target, str):
            target_name = target
        else:
            target_name = ""
        if not target_name:
            return False

        input_node = None
        perm = []  # Empty means "reverse all axes".

        if isinstance(value_node, ast.Attribute) and value_node.attr == "T":
            input_node = value_node.value

        elif isinstance(value_node, ast.Call):
            args = value_node.args
            func = value_node.func

            # ``np.transpose(a, ...)`` and a bare ``transpose(a, ...)`` take
            # the array as first argument; any other ``x.transpose(...)`` is
            # treated as the ndarray method with ``x`` as the array.
            if isinstance(func, ast.Name):
                is_numpy_func = True
            else:
                is_numpy_func = (
                    isinstance(func, ast.Attribute)
                    and isinstance(func.value, ast.Name)
                    and func.value.id in ("np", "numpy")
                )

            if is_numpy_func:
                if len(args) < 1:
                    return False
                input_node = args[0]
                if len(args) > 1:
                    perm = self._parse_perm(args[1])
            else:
                if not isinstance(func, ast.Attribute):
                    return False
                input_node = func.value
                if len(args) == 1:
                    perm = self._parse_perm(args[0])
                elif len(args) > 1:
                    # ``a.transpose(0, 2, 1)``: axes passed as separate
                    # integers; parse them as if they were one tuple.
                    perm = self._parse_perm(
                        ast.Tuple(elts=list(args), ctx=ast.Load())
                    )

            for kw in value_node.keywords:
                if kw.arg == "axes":
                    perm = self._parse_perm(kw.value)

        input_name = self._ev.visit(input_node)
        if input_name not in self.tensor_table:
            return False

        in_info = self.tensor_table[input_name]
        in_shape = in_info.shape
        in_strings = [str(s) for s in in_shape]

        if not perm:
            perm = list(range(len(in_shape)))[::-1]

        out_shape = [in_strings[p] for p in perm]

        in_strides = (
            in_info.strides if hasattr(in_info, "strides") and in_info.strides else None
        )
        if in_strides is None:
            in_strides = self._compute_strides(in_shape, "C")

        # The target aliases the source memory, so the only correct strides
        # are the permuted input strides, whether or not the input is
        # contiguous (same rule as _create_transpose_view).
        out_strides = [in_strides[p] for p in perm]

        # Keep the input's offset so that views of views (e.g. a flipped
        # array that is then transposed) still start at the right element.
        in_offset = getattr(in_info, "offset", "0") or "0"

        dtype = Scalar(PrimitiveType.Double)
        if input_name in self.container_table:
            input_type = self.container_table[input_name]
            if isinstance(input_type, Pointer):
                dtype = input_type.pointee_type
            else:
                dtype = input_type

        ptr_type = Pointer(dtype)

        # Create target container if it doesn't exist
        if not self.builder.exists(target_name):
            self.builder.add_container(target_name, ptr_type, False)
            self.container_table[target_name] = ptr_type
        self.tensor_table[target_name] = Tensor(
            dtype, out_shape, out_strides, in_offset
        )

        # Create reference memlet to alias the source array (view, not copy)
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, input_name)
        t_dst = self.builder.add_access(block, target_name)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return True
4✔
873

874
    def handle_transpose_expr(self, node):
        """Lower an ``arr.T`` expression to a temporary transposed view.

        Returns the name produced by ``_create_transpose_view``, or None when
        ``node`` is not a ``.T`` attribute access or its operand is not in
        ``tensor_table``.
        """
        is_dot_t = isinstance(node, ast.Attribute) and node.attr == "T"
        if not is_dot_t:
            return None

        source = self._ev.visit(node.value)
        if source not in self.tensor_table:
            return None

        rank = len(self.tensor_table[source].shape)
        # ``.T`` reverses the axis order.
        reversed_axes = list(range(rank - 1, -1, -1))
        return self._create_transpose_view(source, reversed_axes)
4✔
888

889
    def _handle_numpy_transpose(self, node, func_name):
4✔
890
        """Handle np.transpose(arr, axes=...) function call."""
891
        if len(node.args) < 1:
4✔
892
            raise ValueError("np.transpose requires at least one argument")
×
893

894
        input_node = node.args[0]
4✔
895
        input_name = self.visit(input_node)
4✔
896

897
        if input_name not in self.tensor_table:
4✔
898
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
899

900
        in_info = self.tensor_table[input_name]
4✔
901
        in_shape = in_info.shape
4✔
902

903
        perm = []
4✔
904
        if len(node.args) > 1:
4✔
905
            perm = self._parse_perm(node.args[1])
×
906
        for kw in node.keywords:
4✔
907
            if kw.arg == "axes":
4✔
908
                perm = self._parse_perm(kw.value)
4✔
909

910
        if not perm:
4✔
911
            perm = list(range(len(in_shape)))[::-1]
4✔
912

913
        return self._create_transpose_view(input_name, perm)
4✔
914

915
    def _create_transpose_view(self, input_name, perm):
        """Register a temporary pointer viewing ``input_name`` with permuted axes.

        The new tensor entry carries the input's shape and strides reordered
        by ``perm`` and the input's offset; a reference memlet makes the
        temporary alias the source container.

        Returns:
            Name of the newly created temporary container.
        """
        source = self.tensor_table[input_name]

        # Output shape: stringified input extents in permuted order.
        extents = [str(extent) for extent in source.shape]
        view_shape = [extents[axis] for axis in perm]

        # Fall back to C-order strides when the input records none.
        if hasattr(source, "strides") and source.strides:
            strides = source.strides
        else:
            strides = None
        if strides is None:
            strides = self._compute_strides(source.shape, "C")

        # Permuting strides is valid for contiguous inputs and views alike.
        view_strides = [strides[axis] for axis in perm]

        # Carry over the offset so chained views (flip -> transpose) work.
        offset = getattr(source, "offset", "0") or "0"

        view_name = f"_tmp_{self._get_unique_id()}"
        view_ptr = Pointer(source.element_type)
        self.builder.add_container(view_name, view_ptr, False)
        self.container_table[view_name] = view_ptr

        self.tensor_table[view_name] = Tensor(
            source.element_type, view_shape, view_strides, offset
        )

        # Alias the source array instead of copying it.
        block = self.builder.add_block()
        src_access = self.builder.add_access(block, input_name)
        dst_access = self.builder.add_access(block, view_name)
        self.builder.add_reference_memlet(block, src_access, dst_access, "0", view_ptr)

        return view_name
4✔
954

955
    def _handle_numpy_flip(self, node, func_name):
4✔
956
        """Handle np.flip(arr, axis=None) - flip array along specified axis.
957

958
        Uses negative strides and offset to create a view without copying.
959
        """
960
        if len(node.args) < 1:
4✔
961
            raise ValueError("np.flip requires at least one argument")
×
962

963
        input_name = self.visit(node.args[0])
4✔
964
        if input_name not in self.tensor_table:
4✔
965
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
966

967
        in_info = self.tensor_table[input_name]
4✔
968
        in_shape = in_info.shape
4✔
969
        ndim = len(in_shape)
4✔
970

971
        # Parse axis argument
972
        axis = None
4✔
973
        if len(node.args) > 1:
4✔
974
            axis_node = node.args[1]
×
975
            if isinstance(axis_node, ast.Constant):
×
976
                axis = axis_node.value
×
977
            elif isinstance(axis_node, ast.UnaryOp) and isinstance(
×
978
                axis_node.op, ast.USub
979
            ):
980
                if isinstance(axis_node.operand, ast.Constant):
×
981
                    axis = -axis_node.operand.value
×
982
        for kw in node.keywords:
4✔
983
            if kw.arg == "axis":
4✔
984
                if isinstance(kw.value, ast.Constant):
4✔
985
                    axis = kw.value.value
4✔
986
                elif isinstance(kw.value, ast.UnaryOp) and isinstance(
4✔
987
                    kw.value.op, ast.USub
988
                ):
989
                    if isinstance(kw.value.operand, ast.Constant):
4✔
990
                        axis = -kw.value.operand.value
4✔
991

992
        # Determine which axes to flip
993
        if axis is None:
4✔
994
            # Flip all axes
995
            axes_to_flip = list(range(ndim))
4✔
996
        else:
997
            if axis < 0:
4✔
998
                axis = ndim + axis
4✔
999
            axes_to_flip = [axis]
4✔
1000

1001
        return self._create_flip_view(input_name, axes_to_flip)
4✔
1002

1003
    def _handle_numpy_fliplr(self, node, func_name):
4✔
1004
        """Handle np.fliplr(arr) - flip array left-right (axis=1)."""
1005
        if len(node.args) < 1:
4✔
1006
            raise ValueError("np.fliplr requires one argument")
×
1007

1008
        input_name = self.visit(node.args[0])
4✔
1009
        if input_name not in self.tensor_table:
4✔
1010
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
1011

1012
        in_info = self.tensor_table[input_name]
4✔
1013
        if len(in_info.shape) < 2:
4✔
1014
            raise ValueError("np.fliplr requires array with ndim >= 2")
×
1015

1016
        return self._create_flip_view(input_name, [1])
4✔
1017

1018
    def _handle_numpy_flipud(self, node, func_name):
4✔
1019
        """Handle np.flipud(arr) - flip array up-down (axis=0)."""
1020
        if len(node.args) < 1:
4✔
1021
            raise ValueError("np.flipud requires one argument")
×
1022

1023
        input_name = self.visit(node.args[0])
4✔
1024
        if input_name not in self.tensor_table:
4✔
1025
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
1026

1027
        return self._create_flip_view(input_name, [0])
4✔
1028

1029
    def _create_flip_view(self, input_name, axes_to_flip):
        """Register a temporary pointer viewing ``input_name`` flipped along axes.

        The tensor description is obtained by calling ``flip(axis)`` on the
        input tensor once per entry of ``axes_to_flip``, in order. The
        temporary aliases the source through a reference memlet with subset
        "0"; the flipped tensor entry is what carries the view's layout.

        Returns:
            Name of the newly created temporary container.
        """
        source = self.tensor_table[input_name]

        view = source
        for flipped_axis in axes_to_flip:
            view = view.flip(flipped_axis)

        # New pointer container referring to the same data.
        view_name = f"_tmp_{self._get_unique_id()}"
        view_ptr = Pointer(source.element_type)
        self.builder.add_container(view_name, view_ptr, False)
        self.container_table[view_name] = view_ptr

        self.tensor_table[view_name] = view

        block = self.builder.add_block()
        src_access = self.builder.add_access(block, input_name)
        dst_access = self.builder.add_access(block, view_name)
        self.builder.add_reference_memlet(block, src_access, dst_access, "0", view_ptr)

        return view_name
4✔
1058

1059
    def _handle_numpy_reshape(self, node, func_name):
        """Handle ``np.reshape(arr, newshape)`` as a zero-copy view.

        Only contiguous inputs are supported: the result aliases the input
        through a reference memlet and gets fresh contiguous strides for the
        new shape. A single ``-1`` dimension is inferred when every other
        extent (of both the input and the new shape) is an integer literal.

        Raises:
            ValueError: if fewer than two arguments are given, the array is
                not in ``tensor_table``, more than one ``-1`` is used, or the
                sizes are statically known and do not match.
            NotImplementedError: if the input is not contiguous, or ``-1``
                is used together with extents that are not integer literals.
        """
        if len(node.args) < 2:
            raise ValueError("np.reshape requires array and new shape")

        input_name = self.visit(node.args[0])
        if input_name not in self.tensor_table:
            raise ValueError(f"Array {input_name} not found in tensor_table")

        in_info = self.tensor_table[input_name]
        in_shape = in_info.shape

        new_shape = self._parse_shape(node.args[1])

        # Resolve a "-1" placeholder; registering it verbatim would create a
        # tensor with a negative extent.
        unknown = [i for i, dim in enumerate(new_shape) if str(dim) == "-1"]
        if unknown:
            if len(unknown) > 1:
                raise ValueError("can only specify one unknown dimension")
            known = [str(dim) for i, dim in enumerate(new_shape) if i != unknown[0]]
            source_dims = [str(dim) for dim in in_shape]
            if not all(dim.isdigit() for dim in known + source_dims):
                raise NotImplementedError(
                    "np.reshape with -1 requires statically known dimensions"
                )
            total = 1
            for dim in source_dims:
                total *= int(dim)
            rest = 1
            for dim in known:
                rest *= int(dim)
            if rest == 0 or total % rest != 0:
                raise ValueError(
                    f"cannot reshape array of size {total} into shape "
                    f"({', '.join(str(dim) for dim in new_shape)})"
                )
            new_shape[unknown[0]] = str(total // rest)

        # Get input strides
        in_strides = (
            in_info.strides if hasattr(in_info, "strides") and in_info.strides else None
        )
        if in_strides is None:
            in_strides = self._compute_strides(in_shape, "C")

        # Check if input is contiguous (C or F order)
        c_contig = self._is_contiguous(in_shape, in_strides)
        f_contig = self._is_contiguous_f(in_shape, in_strides)

        if c_contig:
            out_strides = self._compute_strides(new_shape, "C")
        elif f_contig:
            # NOTE(review): NumPy's default order='C' reads an F-ordered
            # array in C index order, which generally differs from laying
            # F-order strides over the same memory - confirm this branch
            # matches the intended semantics.
            out_strides = self._compute_strides(new_shape, "F")
        else:
            # Non-contiguous array cannot be reshaped without copy
            raise NotImplementedError(
                "np.reshape on non-contiguous array not supported (would require copy)"
            )

        # Keep the input's offset so a reshaped view starts at the same
        # element as its source (same rule as _create_transpose_view).
        in_offset = getattr(in_info, "offset", "0") or "0"

        # Create new pointer container
        tmp_name = f"_tmp_{self._get_unique_id()}"
        ptr_type = Pointer(in_info.element_type)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type

        # Register tensor with new shape, computed strides and inherited offset
        self.tensor_table[tmp_name] = Tensor(
            in_info.element_type, new_shape, out_strides, in_offset
        )

        # Create reference memlet to alias the source array (view, no copy)
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, input_name)
        t_dst = self.builder.add_access(block, tmp_name)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return tmp_name
4✔
1116

1117
    def _parse_shape(self, shape_node):
4✔
1118
        """Parse a shape argument (tuple, list, or single int)."""
1119
        if isinstance(shape_node, ast.Tuple) or isinstance(shape_node, ast.List):
4✔
1120
            result = []
4✔
1121
            for elt in shape_node.elts:
4✔
1122
                if isinstance(elt, ast.Constant):
4✔
1123
                    result.append(str(elt.value))
4✔
1124
                elif isinstance(elt, ast.Name):
×
1125
                    result.append(elt.id)
×
1126
                elif isinstance(elt, ast.UnaryOp) and isinstance(elt.op, ast.USub):
×
1127
                    if isinstance(elt.operand, ast.Constant):
×
1128
                        result.append(str(-elt.operand.value))
×
1129
                else:
1130
                    result.append(self._shape_to_runtime_expr(elt))
×
1131
            return result
4✔
1132
        elif isinstance(shape_node, ast.Constant):
×
1133
            return [str(shape_node.value)]
×
1134
        elif isinstance(shape_node, ast.Name):
×
1135
            # Could be a variable holding a shape tuple - not supported yet
1136
            raise NotImplementedError("Shape variable not supported, use literal tuple")
×
1137
        else:
1138
            raise ValueError(f"Cannot parse shape: {ast.dump(shape_node)}")
×
1139

1140
    def _is_contiguous_f(self, shape, strides):
4✔
1141
        """Check if array is F-order contiguous."""
1142
        if not shape or not strides:
4✔
1143
            return True
×
1144
        f_strides = self._compute_strides(shape, "F")
4✔
1145
        return [str(s) for s in strides] == [str(s) for s in f_strides]
4✔
1146

1147
    def handle_numpy_call(self, node, func_name):
        """Dispatch a NumPy call to the handler registered for ``func_name``.

        Returns whatever the handler returns.

        Raises:
            NotImplementedError: if no handler is registered for the name.
        """
        if func_name not in self.function_handlers:
            raise NotImplementedError(f"NumPy function {func_name} not supported")
        handler = self.function_handlers[func_name]
        return handler(node, func_name)
×
1151

1152
    def has_handler(self, func_name):
        """Return True if a handler is registered for ``func_name``."""
        registered = self.function_handlers
        return func_name in registered
4✔
1154

1155
    def handle_array_unary_op(self, op_type, operand):
        """Emit a unary math operation on ``operand`` and return the result name.

        Operands with a non-empty shape are lowered through the builder's
        elementwise unary op into a fresh temporary array. Operands with an
        empty shape (or not present in ``tensor_table``) are lowered to a
        single cmath call writing into a scalar temporary.
        """
        dtype = self._ev._element_type(operand)
        if operand in self.tensor_table:
            source = self.tensor_table[operand]
        else:
            source = Tensor(dtype, [])

        if len(source.shape) != 0:
            # Array case: allocate a contiguous result and loop elementwise.
            result_strides = self._get_contiguous_output_strides(
                source.shape, source.strides
            )
            result_name = self._create_array_temp(
                source.shape, dtype, strides=result_strides
            )
            result = self.tensor_table[result_name]
            self.builder.add_elementwise_unary_op(
                op_type, operand, source, result_name, result
            )
            return result_name

        # Scalar case: one cmath node between the operand and a temporary.
        result_name = self._create_array_temp([], dtype)

        scalar_functions = {
            "sqrt": CMathFunction.sqrt,
            "abs": CMathFunction.fabs,
            "absolute": CMathFunction.fabs,
            "exp": CMathFunction.exp,
            "tanh": CMathFunction.tanh,
        }

        block = self.builder.add_block()
        src_access = self.builder.add_access(block, operand)
        dst_access = self.builder.add_access(block, result_name)
        math_node = self.builder.add_cmath(
            block, scalar_functions[op_type], dtype.primitive_type
        )

        self.builder.add_memlet(block, src_access, "void", math_node, "_in1", "", dtype)
        self.builder.add_memlet(block, math_node, "_out", dst_access, "void", "", dtype)

        return result_name
4✔
1195

1196
    def handle_array_binary_op(self, op_type, left, right):
        """Emit an elementwise binary operation with broadcasting.

        Operands present in ``tensor_table`` count as arrays (0-d arrays
        included) for type promotion; anything else is treated as a scalar.
        Both operands are cast to the promoted element type, operands that
        need broadcasting are described by a tensor with broadcast strides
        (no data is copied), and the result goes into a fresh temporary.

        Returns:
            Name of the temporary holding the result.
        """
        left_is_array = left in self.tensor_table
        right_is_array = right in self.tensor_table

        lhs_type = self._ev._element_type(left)
        rhs_type = self._ev._element_type(right)

        # NumPy promotion rules: scalars adapt to arrays.
        dtype = numpy_promote_types(lhs_type, left_is_array, rhs_type, right_is_array)

        lhs_name = self._cast_to_type(left, dtype)
        rhs_name = self._cast_to_type(right, dtype)

        # Tensor descriptions of the (possibly casted) operands; scalars get
        # an empty-shape tensor.
        lhs = (
            self.tensor_table[lhs_name]
            if lhs_name in self.tensor_table
            else Tensor(dtype, [])
        )
        rhs = (
            self.tensor_table[rhs_name]
            if rhs_name in self.tensor_table
            else Tensor(dtype, [])
        )

        lhs_shape = lhs.shape
        rhs_shape = rhs.shape
        out_shape = self._compute_broadcast_shape(lhs_shape, rhs_shape)

        # Empty-shape operands are never expanded here.
        lhs_expands = False
        if lhs_shape:
            lhs_expands = self._needs_broadcast(lhs_shape, out_shape)
        rhs_expands = False
        if rhs_shape:
            rhs_expands = self._needs_broadcast(rhs_shape, out_shape)

        def broadcast_view(tensor, shape):
            """Describe ``tensor`` expanded to ``out_shape`` via strides only."""
            strides = tensor.strides if tensor.strides else []
            expanded = self._compute_broadcast_strides(shape, strides, out_shape)
            # Keep the original offset (matters for views such as flip).
            offset = tensor.offset if tensor.offset else "0"
            return Tensor(dtype, out_shape, expanded, offset)

        lhs_view = broadcast_view(lhs, lhs_shape) if lhs_expands else lhs
        rhs_view = broadcast_view(rhs, rhs_shape) if rhs_expands else rhs

        if lhs_expands or rhs_expands:
            out_strides = self._compute_strides(out_shape, "C")
        else:
            # Without broadcasting the left operand's strides decide the
            # layout of the result.
            out_strides = self._get_contiguous_output_strides(out_shape, lhs.strides)

        result_name = self._create_array_temp(out_shape, dtype, strides=out_strides)
        result = self.tensor_table[result_name]

        self.builder.add_elementwise_op(
            op_type,
            lhs_name,
            lhs_view,
            rhs_name,
            rhs_view,
            result_name,
            result,
        )

        return result_name
4✔
1292

1293
    def handle_array_negate(self, operand):
        """Emit ``-operand`` for an array as ``0 - operand``.

        A scalar temporary is initialised to zero and the builder's
        elementwise "sub" op subtracts the operand from it into a fresh
        result array.

        Returns:
            Name of the temporary holding the negated array.
        """
        source = self.tensor_table[operand]
        dtype = self._ev._element_type(operand)

        result_strides = self._get_contiguous_output_strides(
            source.shape, source.strides
        )
        result_name = self._create_array_temp(
            source.shape, dtype, strides=result_strides
        )
        result = self.tensor_table[result_name]

        # Scalar container that will hold the constant zero.
        zero_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(zero_name, dtype, False)
        self.container_table[zero_name] = dtype
        self.tensor_table[zero_name] = Tensor(dtype, [])

        # Doubles get a floating literal; every other type gets "0".
        if dtype.primitive_type == PrimitiveType.Double:
            zero_literal = "0.0"
        else:
            zero_literal = "0"

        init_block = self.builder.add_block()
        const_node = self.builder.add_constant(init_block, zero_literal, dtype)
        zero_access = self.builder.add_access(init_block, zero_name)
        assign_node = self.builder.add_tasklet(
            init_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(init_block, const_node, "void", assign_node, "_in", "")
        self.builder.add_memlet(init_block, assign_node, "_out", zero_access, "void", "")

        zero = self.tensor_table[zero_name]
        self.builder.add_elementwise_op(
            "sub", zero_name, zero, operand, source, result_name, result
        )

        return result_name
4✔
1329

1330
    def handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Emit an elementwise comparison and return the boolean result array.

        The iteration space is the shape of the left operand when it is an
        array, otherwise of the right operand. Int32/Int64 arrays use the
        integer comparison tasklets, everything else the floating-point ones.
        For floating-point comparisons, a scalar operand for which
        ``_is_int`` is true is first materialised as a Double temporary.

        Raises:
            NotImplementedError: if ``op`` has no matching tasklet code.
        """
        array_operand = left if left_is_array else right
        shape = self.tensor_table[array_operand].shape

        operand_type = self._ev._element_type(array_operand)
        integer_compare = operand_type.primitive_type in (
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )

        result_name = self._create_array_temp(shape, Scalar(PrimitiveType.Bool))

        if integer_compare:
            opcode_by_symbol = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            opcode_by_symbol = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in opcode_by_symbol:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )
        opcode = opcode_by_symbol[op]

        # At most one operand is a scalar; the left one is checked first.
        if not left_is_array:
            scalar_operand = left
        elif not right_is_array:
            scalar_operand = right
        else:
            scalar_operand = None

        needs_float_copy = (
            scalar_operand is not None
            and not integer_compare
            and self._is_int(scalar_operand)
        )
        if needs_float_copy:
            # Store the integer scalar as a Double so both tasklet inputs
            # are floating point.
            double_type = Scalar(PrimitiveType.Double)
            promoted_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(promoted_name, double_type, False)
            self.container_table[promoted_name] = Scalar(PrimitiveType.Double)

            conv_block = self.builder.add_block()
            const_node = self.builder.add_constant(
                conv_block, f"{scalar_operand}.0", Scalar(PrimitiveType.Double)
            )
            promoted_access = self.builder.add_access(conv_block, promoted_name)
            assign_node = self.builder.add_tasklet(
                conv_block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                conv_block, const_node, "void", assign_node, "_in", ""
            )
            self.builder.add_memlet(
                conv_block, assign_node, "_out", promoted_access, "void", ""
            )

            if left_is_array:
                right = promoted_name
            else:
                left = promoted_name

        # Tensor descriptions of the array operands and of the result.
        left_tensor = self.tensor_table.get(left) if left_is_array else None
        right_tensor = self.tensor_table.get(right) if right_is_array else None
        result_tensor = self.tensor_table[result_name]

        # One nested for-loop per dimension of the iteration space.
        index_names = []
        for depth, extent in enumerate(shape):
            index_name = f"_cmp_i{depth}_{self._get_unique_id()}"
            if not self.builder.exists(index_name):
                self.builder.add_container(
                    index_name, Scalar(PrimitiveType.Int64), False
                )
                self.container_table[index_name] = Scalar(PrimitiveType.Int64)
            index_names.append(index_name)
            self.builder.begin_for(index_name, "0", str(extent), "1")

        # Multi-dimensional subset; the tensor passed with each memlet
        # supplies strides and offset.
        element_subset = ",".join(index_names)

        block = self.builder.add_block()

        if left_is_array:
            left_access = self.builder.add_access(block, left)
            left_subset = element_subset
        else:
            left_access, left_subset = self._add_read(block, left)

        if right_is_array:
            right_access = self.builder.add_access(block, right)
            right_subset = element_subset
        else:
            right_access, right_subset = self._add_read(block, right)

        out_access = self.builder.add_access(block, result_name)

        compare_node = self.builder.add_tasklet(
            block, opcode, ["_in1", "_in2"], ["_out"]
        )

        if left_is_array and left_tensor:
            self.builder.add_memlet(
                block, left_access, "void", compare_node, "_in1", left_subset, left_tensor
            )
        else:
            self.builder.add_memlet(
                block, left_access, "void", compare_node, "_in1", left_subset
            )

        if right_is_array and right_tensor:
            self.builder.add_memlet(
                block,
                right_access,
                "void",
                compare_node,
                "_in2",
                right_subset,
                right_tensor,
            )
        else:
            self.builder.add_memlet(
                block, right_access, "void", compare_node, "_in2", right_subset
            )

        self.builder.add_memlet(
            block, compare_node, "_out", out_access, "void", element_subset, result_tensor
        )

        # Close the loops opened above, one end_for per begin_for.
        for _ in index_names:
            self.builder.end_for()

        return result_name
4✔
1467

1468
    # ========== NumPy Function Handlers ==========
1469

1470
    def _handle_numpy_alloc(self, node, func_name):
        """Handle np.empty, np.zeros, np.ones, np.ndarray.

        The shape is the first positional argument or, if there is none, the
        ``shape`` keyword. It may be a tuple/list literal or a single
        expression; a single expression that evaluates to a shape proxy of a
        known array reuses that array's shape. ``dtype`` (second positional
        argument or keyword), ``order`` (constant keyword) and ``strides``
        (tuple/list keyword, in bytes) are honoured.

        Returns:
            Name of the temporary created by ``_create_array_temp``.

        Raises:
            ValueError: if no shape is given.
        """
        if node.args:
            shape_arg = node.args[0]
        else:
            # Accept ``np.zeros(shape=(n, m))``; indexing ``node.args[0]``
            # here used to fail with a bare IndexError.
            shape_arg = next(
                (kw.value for kw in node.keywords if kw.arg == "shape"), None
            )
        if shape_arg is None:
            raise ValueError(f"np.{func_name} requires a shape argument")

        dims = []
        dims_runtime = []
        if isinstance(shape_arg, (ast.Tuple, ast.List)):
            dims = [self.visit(elt) for elt in shape_arg.elts]
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
        else:
            val = self.visit(shape_arg)
            runtime_val = self._shape_to_runtime_expr(shape_arg)
            if val.startswith("_shape_proxy_"):
                # The shape is another array's ``.shape``: copy its extents.
                array_name = val[len("_shape_proxy_") :]
                if array_name in self.tensor_table:
                    info = self.tensor_table[array_name]
                    dims = info.shape
                    dims_runtime = self.shapes_runtime_info.get(array_name, dims)
                else:
                    dims = [val]
                    dims_runtime = [runtime_val]
            else:
                dims = [val]
                dims_runtime = [runtime_val]

        dtype_arg = None
        order = "C"  # Default to C-order (row-major)
        explicit_strides = None
        if len(node.args) > 1:
            dtype_arg = node.args[1]

        for kw in node.keywords:
            if kw.arg == "dtype":
                dtype_arg = kw.value
            elif kw.arg == "order":
                if isinstance(kw.value, ast.Constant):
                    order = kw.value.value
            elif kw.arg == "strides":
                # Parse explicit strides tuple/list
                if isinstance(kw.value, (ast.Tuple, ast.List)):
                    explicit_strides = [
                        self._shape_to_runtime_expr(elt) for elt in kw.value.elts
                    ]

        element_type = element_type_from_ast_node(dtype_arg, self.container_table)

        # Use explicit strides if provided, otherwise compute from order
        if explicit_strides is not None:
            # Convert byte strides to element strides by dividing by element size
            element_size = self.builder.get_sizeof(element_type)
            strides = [f"(({s}) / {element_size})" for s in explicit_strides]
        else:
            strides = self._compute_strides(dims, order)

        return self._create_array_temp(
            dims,
            element_type,
            zero_init=(func_name == "zeros"),
            ones_init=(func_name == "ones"),
            shapes_runtime=dims_runtime,
            strides=strides,
        )
1534

1535
    def _handle_numpy_empty_like(self, node, func_name):
        """Lower ``np.empty_like(prototype, ...)`` to an uninitialised temporary.

        The shape is copied from the prototype's ``tensor_table`` entry (no
        dimensions when the prototype is not listed there).  The element type
        comes from an explicit dtype (second positional argument or ``dtype=``
        keyword), otherwise from the pointee type recorded for the prototype in
        ``container_table``, and finally falls back to double.

        Returns:
            The value returned by ``self._create_array_temp``.
        """
        proto_name = self.visit(node.args[0])

        if proto_name in self.tensor_table:
            dims = self.tensor_table[proto_name].shape
        else:
            dims = []

        dtype_node = node.args[1] if len(node.args) > 1 else None
        layout = "C"  # C-order unless a literal order= keyword says otherwise
        for keyword in node.keywords:
            if keyword.arg == "dtype":
                dtype_node = keyword.value
            elif keyword.arg == "order" and isinstance(keyword.value, ast.Constant):
                layout = keyword.value.value

        elem_type = None
        if dtype_node:
            elem_type = element_type_from_ast_node(dtype_node, self.container_table)
        elif proto_name in self.container_table:
            # Inherit the prototype's element type when it is a typed pointer.
            proto_type = self.container_table[proto_name]
            if isinstance(proto_type, Pointer) and proto_type.has_pointee_type():
                elem_type = proto_type.pointee_type

        if elem_type is None:
            elem_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            dims,
            elem_type,
            zero_init=False,
            ones_init=False,
            strides=self._compute_strides(dims, layout),
        )
1572

1573
    def _handle_numpy_zeros_like(self, node, func_name):
        """Lower ``np.zeros_like(prototype, ...)`` to a zero-initialised temporary.

        Shape and default element type are taken from the prototype: the shape
        from ``tensor_table`` and the element type from the pointee type stored
        in ``container_table``.  An explicit dtype (second positional argument
        or ``dtype=``) overrides the prototype's type; with neither available
        the element type is double.

        Returns:
            The value returned by ``self._create_array_temp``.
        """
        source_name = self.visit(node.args[0])

        shape = []
        if source_name in self.tensor_table:
            shape = self.tensor_table[source_name].shape

        # Gather the optional dtype / order arguments.
        requested_dtype = None
        if len(node.args) > 1:
            requested_dtype = node.args[1]
        memory_order = "C"
        for option in node.keywords:
            name, value = option.arg, option.value
            if name == "dtype":
                requested_dtype = value
            elif name == "order" and isinstance(value, ast.Constant):
                memory_order = value.value

        # Resolve the element type: explicit dtype, then prototype, then double.
        resolved_type = None
        if requested_dtype:
            resolved_type = element_type_from_ast_node(
                requested_dtype, self.container_table
            )
        elif source_name in self.container_table:
            source_type = self.container_table[source_name]
            if isinstance(source_type, Pointer) and source_type.has_pointee_type():
                resolved_type = source_type.pointee_type
        if resolved_type is None:
            resolved_type = Scalar(PrimitiveType.Double)

        layout_strides = self._compute_strides(shape, memory_order)
        return self._create_array_temp(
            shape,
            resolved_type,
            zero_init=True,
            ones_init=False,
            strides=layout_strides,
        )
1610

1611
    def _handle_numpy_eye(self, node, func_name):
        """Handle np.eye(N, M=None, k=0, dtype=...).

        Allocates a zero-initialised ``N x M`` temporary and emits a loop over
        the rows ``i`` in ``[0, N)`` that writes a one at flat index
        ``i * M + i + k`` whenever ``0 <= i + k < M``.

        Args:
            node: ``ast.Call`` of the ``np.eye`` invocation.  ``M``, ``k`` and
                the dtype may be given positionally (2nd, 3rd and 4th
                argument) or by keyword; keywords win over positionals.
            func_name: Unused; present for the common handler signature.

        Returns:
            The name returned by ``self._create_array_temp`` for the matrix.
        """
        N_arg = node.args[0]
        N_str = self.visit(N_arg)
        N_runtime = self._shape_to_runtime_expr(N_arg)

        def resolve_columns(candidate):
            # M=None means "square matrix": fall back to N in that case,
            # regardless of whether M was passed positionally or by keyword.
            text = self.visit(candidate)
            if text == "None":
                return N_arg, N_str
            return candidate, text

        M_arg, M_str = N_arg, N_str  # Default M = N
        k_str = "0"
        dtype_arg = None

        if len(node.args) > 1:
            M_arg, M_str = resolve_columns(node.args[1])
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])
        if len(node.args) > 3:
            # Fourth positional parameter of np.eye is the dtype.
            dtype_arg = node.args[3]

        for kw in node.keywords:
            if kw.arg == "M":
                M_arg, M_str = resolve_columns(kw.value)
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        M_runtime = self._shape_to_runtime_expr(M_arg)

        element_type = element_type_from_ast_node(dtype_arg, self.container_table)

        ptr_name = self._create_array_temp(
            [N_str, M_str],
            element_type,
            zero_init=True,
            shapes_runtime=[N_runtime, M_runtime],
        )

        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Only rows whose shifted diagonal column lies inside [0, M) get a one.
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        # The literal for "one" must match the element type family.
        val = "1.0"
        if element_type.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]:
            val = "1"

        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        # Row-major flat index of element (i, i + k).
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"
        subset = flat_index

        t_task = self.builder.add_tasklet(
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", subset)

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
1692

1693
    def _handle_numpy_binary_op(self, node, func_name):
        """Handle np.add, np.subtract, np.multiply, np.divide, etc.

        Args:
            node: ``ast.Call`` of the NumPy function; exactly two positional
                operands are required.
            func_name: NumPy function name, translated to the operation name
                understood by ``self.handle_array_binary_op``.

        Returns:
            The result of ``self.handle_array_binary_op``.

        Raises:
            NotImplementedError: If ``func_name`` has no mapping or the call
                does not have exactly two positional arguments.
        """
        op_map = {
            "add": "add",
            "subtract": "sub",
            "multiply": "mul",
            "divide": "div",
            "power": "pow",
            "minimum": "min",
            "maximum": "max",
        }
        # Reject unknown functions up front with the same exception type used
        # for the arity check, instead of leaking a KeyError after the
        # operands have already been visited.
        if func_name not in op_map:
            raise NotImplementedError(
                f"Numpy function {func_name} is not a supported binary operation"
            )

        args = [self.visit(arg) for arg in node.args]
        if len(args) != 2:
            raise NotImplementedError(
                f"Numpy function {func_name} requires 2 arguments"
            )

        return self.handle_array_binary_op(op_map[func_name], args[0], args[1])
1711

1712
    def _handle_numpy_unary_op(self, node, func_name):
        """Lower a one-operand NumPy call (np.exp, np.sqrt, np.abs, ...).

        The single positional argument is visited and forwarded to
        ``self.handle_array_unary_op`` together with the operation name, which
        is ``func_name`` itself except that ``"absolute"`` is spelled
        ``"abs"``.

        Raises:
            NotImplementedError: If the call does not have exactly one
                positional argument.
        """
        operands = [self.visit(operand) for operand in node.args]
        if len(operands) != 1:
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")

        (operand,) = operands
        operation = "abs" if func_name == "absolute" else func_name
        return self.handle_array_unary_op(operation, operand)
1723

1724
    def _handle_numpy_where(self, node, func_name):
        """Lower ``np.where(condition, x, y)`` to an elementwise selection.

        The result shape is taken from the first of ``condition``, ``y``, ``x``
        that has a non-empty shape in ``tensor_table``.  The result element
        type follows ``y`` when it is listed in ``container_table`` (pointee
        type for typed pointers, the scalar type itself for scalars) and is
        double otherwise.  A loop nest over the shape copies the condition
        element into a boolean scalar and then assigns from ``x`` or ``y``.

        Returns:
            Name of the temporary holding the result.

        Raises:
            NotImplementedError: If the call does not have exactly three
                positional arguments or none of them is an array.
        """
        if len(node.args) != 3:
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")

        cond_name, x_name, y_name = [self.visit(arg) for arg in node.args]

        shape = []
        dtype = Scalar(PrimitiveType.Double)

        # First array operand with a non-empty shape wins: condition, y, x.
        for candidate in (cond_name, y_name, x_name):
            if candidate in self.tensor_table:
                shape = self.tensor_table[candidate].shape
            if shape:
                break

        if not shape:
            raise NotImplementedError("np.where requires at least one array argument")

        if y_name in self.container_table:
            y_type = self.container_table[y_name]
            if isinstance(y_type, Pointer) and y_type.has_pointee_type():
                dtype = y_type.pointee_type
            elif isinstance(y_type, Scalar):
                dtype = y_type

        tmp_name = self._create_array_temp(shape, dtype)
        tmp_tensor = self.tensor_table[tmp_name]

        # One loop per output dimension, outermost first.
        loop_vars = []
        for axis, extent in enumerate(shape):
            index_var = f"_where_i{axis}_{self._get_unique_id()}"
            if not self.builder.exists(index_var):
                self.builder.add_container(
                    index_var, Scalar(PrimitiveType.Int64), False
                )
                self.container_table[index_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(index_var)
            self.builder.begin_for(index_var, "0", str(extent), "1")
        multi_dim_subset = ",".join(loop_vars)

        # Copy the current condition element into a boolean scalar that the
        # if/else below can branch on.
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
        self.container_table[cond_tmp] = Scalar(PrimitiveType.Bool)

        block_cond = self.builder.add_block()
        if cond_name in self.tensor_table:
            cond_src = self.builder.add_access(block_cond, cond_name)
            cond_in_args = (multi_dim_subset, self.tensor_table[cond_name])
        else:
            cond_src, cond_sub = self._add_read(block_cond, cond_name)
            cond_in_args = (cond_sub,)
        cond_dst = self.builder.add_access(block_cond, cond_tmp)
        cond_task = self.builder.add_tasklet(
            block_cond, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_cond, cond_src, "void", cond_task, "_in", *cond_in_args
        )
        self.builder.add_memlet(block_cond, cond_task, "_out", cond_dst, "void", "")

        def emit_branch(source_name):
            # Assign the current element of `source_name` (array element or
            # scalar read) to the matching element of the output temporary.
            block = self.builder.add_block()
            dst = self.builder.add_access(block, tmp_name)
            if source_name in self.tensor_table:
                source_tensor = self.tensor_table[source_name]
                src = self.builder.add_access(block, source_name)
                in_args = (multi_dim_subset, source_tensor)
            else:
                src, scalar_sub = self._add_read(block, source_name)
                in_args = (scalar_sub,)
            task = self.builder.add_tasklet(
                block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block, src, "void", task, "_in", *in_args)
            self.builder.add_memlet(
                block, task, "_out", dst, "void", multi_dim_subset, tmp_tensor
            )

        self.builder.begin_if(f"{cond_tmp} == true")
        emit_branch(x_name)  # condition holds: take x
        self.builder.begin_else()
        emit_branch(y_name)  # otherwise: take y
        self.builder.end_if()

        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
1878

1879
    def _handle_numpy_clip(self, node, func_name):
        """Handle np.clip(a, a_min, a_max) - elementwise clipping.

        Clipping is expressed as ``min(max(a, a_min), a_max)`` through
        ``self.handle_array_binary_op``.  The bounds may be passed
        positionally or by keyword (``a_min``/``a_max`` or ``min``/``max``);
        a bound that is omitted or the literal ``None`` is treated as "no
        bound" and its operation is skipped.

        Args:
            node: ``ast.Call`` of the ``np.clip`` invocation.  The array must
                be the first positional argument.
            func_name: Unused; present for the common handler signature.

        Returns:
            The result of the last emitted binary operation.

        Raises:
            NotImplementedError: If the array is not given positionally, more
                than three positional arguments are given, or neither bound
                is provided.
        """
        positional = list(node.args)
        by_keyword = {kw.arg: kw.value for kw in node.keywords}

        if not positional or len(positional) > 3:
            raise NotImplementedError("np.clip requires 3 arguments (a, a_min, a_max)")

        def pick_bound(position, *keyword_names):
            # Positional arguments take precedence over keywords.
            if len(positional) > position:
                return positional[position]
            for key in keyword_names:
                if key in by_keyword:
                    return by_keyword[key]
            return None

        def is_unbounded(bound_node):
            return bound_node is None or (
                isinstance(bound_node, ast.Constant) and bound_node.value is None
            )

        lower_node = pick_bound(1, "a_min", "min")
        upper_node = pick_bound(2, "a_max", "max")
        has_lower = not is_unbounded(lower_node)
        has_upper = not is_unbounded(upper_node)

        if not has_lower and not has_upper:
            raise NotImplementedError("np.clip requires 3 arguments (a, a_min, a_max)")

        # Visit every operand before emitting any operation, array first.
        result = self.visit(positional[0])
        lower = self.visit(lower_node) if has_lower else None
        upper = self.visit(upper_node) if has_upper else None

        if has_lower:
            result = self.handle_array_binary_op("max", result, lower)
        if has_upper:
            result = self.handle_array_binary_op("min", result, upper)

        return result
1892

1893
    def _handle_numpy_matmul(self, node, func_name):
        """Lower ``np.matmul(a, b)`` / ``np.dot(a, b)`` calls.

        Both operand nodes are forwarded unchanged to
        ``self._handle_matmul_helper``.

        Raises:
            NotImplementedError: If the call does not have exactly two
                positional arguments.
        """
        operands = node.args
        if len(operands) == 2:
            lhs, rhs = operands[0], operands[1]
            return self._handle_matmul_helper(lhs, rhs)
        raise NotImplementedError("matmul/dot requires 2 arguments")
1898

1899
    def handle_numpy_matmul_op(self, left_node, right_node):
        """Lower the ``left @ right`` operator.

        Thin public entry point: both operand nodes go straight to
        ``self._handle_matmul_helper`` and its result is returned.
        """
        product = self._handle_matmul_helper(left_node, right_node)
        return product
1902

1903
    def _handle_matmul_helper(self, left_node, right_node):
        """Helper for matrix multiplication operations.

        Resolves both operands through ``self.parse_arg``, derives the output
        shape from their ranks, allocates the result and delegates the actual
        product to ``self.handle_dot`` (1-D x 1-D) or ``self.handle_gemm``
        (everything else).  Operands of equal rank above two are handled by
        looping over the leading dimensions of the left operand and issuing
        one ``handle_gemm`` per slice.

        Args:
            left_node: AST node of the left operand.
            right_node: AST node of the right operand.

        Returns:
            Name of the container holding the result (a scalar container for
            the 1-D x 1-D case, an array temporary otherwise).

        Raises:
            NotImplementedError: If an operand cannot be resolved, or the rank
                combination is not supported.
        """
        res_a = self.parse_arg(left_node)
        res_b = self.parse_arg(right_node)

        # Operands parse_arg cannot resolve directly (empty first field) are
        # evaluated through visit() and re-parsed by the returned name.  The
        # synthesised nodes carry an explicit Load context so they look like
        # parser-produced nodes on every Python version.
        if not res_a[0]:
            left_name = self.visit(left_node)
            left_node = ast.Name(id=left_name, ctx=ast.Load())
            res_a = self.parse_arg(left_node)

        if not res_b[0]:
            right_name = self.visit(right_node)
            right_node = ast.Name(id=right_name, ctx=ast.Load())
            res_b = self.parse_arg(right_node)

        name_a, _, shape_a, _ = res_a
        name_b, _, shape_b, _ = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve matmul operands")

        ndim_a = len(shape_a)
        ndim_b = len(shape_b)

        output_shape = []
        is_scalar = False

        if ndim_a == 1 and ndim_b == 1:
            # Vector dot product: scalar result.
            is_scalar = True
        elif ndim_a == 2 and ndim_b == 2:
            output_shape = [shape_a[0], shape_b[1]]
        elif ndim_a == 2 and ndim_b == 1:
            output_shape = [shape_a[0]]
        elif ndim_a == 1 and ndim_b == 2:
            output_shape = [shape_b[1]]
        elif ndim_a > 2 or ndim_b > 2:
            if ndim_a == ndim_b:
                # Batched product: leading dims of a, then (rows of a, cols of b).
                output_shape = list(shape_a[:-2]) + [shape_a[-2], shape_b[-1]]
            else:
                raise NotImplementedError(
                    "Broadcasting with different ranks not fully supported yet"
                )
        else:
            raise NotImplementedError(
                f"Matmul with ranks {ndim_a} and {ndim_b} not supported"
            )

        dtype_a = self._ev._element_type(name_a)
        dtype_b = self._ev._element_type(name_b)
        dtype = promote_element_types(dtype_a, dtype_b)

        if is_scalar:
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
        else:
            tmp_name = self._create_array_temp(output_shape, dtype)

        if ndim_a > 2 or ndim_b > 2:
            batch_dims = ndim_a - 2
            loop_vars = []

            for i in range(batch_dims):
                loop_var = f"_i{self._get_unique_id()}"
                # NOTE(review): unlike the loop variables created by the other
                # handlers in this class, this one is not recorded in
                # self.container_table - confirm whether that is intended.
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                loop_vars.append(loop_var)
                self.builder.begin_for(loop_var, "0", str(shape_a[i]), "1")

            def make_slice(name, indices):
                # Build the AST for name[i0, ..., :, :] from index names and
                # ":" placeholders.
                elts = [
                    ast.Slice() if idx == ":" else ast.Name(id=idx, ctx=ast.Load())
                    for idx in indices
                ]
                return ast.Subscript(
                    value=ast.Name(id=name, ctx=ast.Load()),
                    slice=ast.Tuple(elts=elts, ctx=ast.Load()),
                    ctx=ast.Load(),
                )

            indices = loop_vars + [":", ":"]
            slice_a = make_slice(name_a, indices)
            slice_b = make_slice(name_b, indices)
            slice_c = make_slice(tmp_name, indices)

            self.handle_gemm(
                slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
            )

            for _ in range(batch_dims):
                self.builder.end_for()
        else:
            product = ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node)
            if is_scalar:
                self.handle_dot(tmp_name, product)
            else:
                self.handle_gemm(tmp_name, product)

        return tmp_name
2015

2016
    def _handle_numpy_outer(self, node, func_name):
        """Lower ``np.outer(a, b)`` into a 2-D temporary.

        Each operand is resolved with ``self.parse_arg``; an operand that does
        not resolve is first evaluated with ``self.visit`` and re-parsed by
        name.  The result has one row per element of ``a`` and one column per
        element of ``b`` (each operand's extent is the product of its shape
        entries), and the product itself is emitted by ``self.handle_outer``.

        Returns:
            Name of the temporary holding the outer product.

        Raises:
            NotImplementedError: If the call does not have exactly two
                positional arguments or an operand cannot be resolved.
        """
        if len(node.args) != 2:
            raise NotImplementedError("outer requires 2 arguments")

        lhs_node, rhs_node = node.args[0], node.args[1]

        lhs_info = self.parse_arg(lhs_node)
        rhs_info = self.parse_arg(rhs_node)

        if not lhs_info[0]:
            lhs_node = ast.Name(id=self.visit(lhs_node))
            lhs_info = self.parse_arg(lhs_node)

        if not rhs_info[0]:
            rhs_node = ast.Name(id=self.visit(rhs_node))
            rhs_info = self.parse_arg(rhs_node)

        name_a, _, shape_a, _ = lhs_info
        name_b, _, shape_b, _ = rhs_info

        if not (name_a and name_b):
            raise NotImplementedError("Could not resolve outer operands")

        def flat_extent(shape):
            # Product of all dimensions as an expression string; a running
            # value of "1" is replaced rather than multiplied.
            extent = "1"
            for dim in shape:
                extent = str(dim) if extent == "1" else f"({extent} * {str(dim)})"
            return extent

        m_expr = flat_extent(shape_a)
        n_expr = flat_extent(shape_b)

        result_type = promote_element_types(
            self._ev._element_type(name_a), self._ev._element_type(name_b)
        )
        tmp_name = self._create_array_temp([m_expr, n_expr], result_type)

        # Re-issue the call with the (possibly rewritten) operand nodes.
        resolved_call = ast.Call(
            func=node.func, args=[lhs_node, rhs_node], keywords=node.keywords
        )
        self.handle_outer(tmp_name, resolved_call)

        return tmp_name
2068

2069
    def handle_ufunc_outer(self, node, ufunc_name):
        """Lower ``np.<ufunc>.outer(a, b)`` (add, subtract, divide, minimum, maximum).

        ``multiply`` is forwarded to ``self._handle_numpy_outer``.  For the
        other supported ufuncs a two-level loop nest over the flattened
        extents of ``a`` and ``b`` applies the operation to every pair and
        stores it at flat index ``i * n + j`` of a new ``m x n`` temporary.

        Returns:
            Name of the temporary holding the result.

        Raises:
            NotImplementedError: If the call does not have exactly two
                positional arguments, the ufunc is not supported, or an
                operand cannot be resolved.
        """
        if len(node.args) != 2:
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")

        if ufunc_name == "multiply":
            return self._handle_numpy_outer(node, "outer")

        # ufunc name -> (short name, floating-point opcode, integer opcode)
        opcode_table = {
            "add": ("add", TaskletCode.fp_add, TaskletCode.int_add),
            "subtract": ("sub", TaskletCode.fp_sub, TaskletCode.int_sub),
            "divide": ("div", TaskletCode.fp_div, TaskletCode.int_sdiv),
            "minimum": ("min", CMathFunction.fmin, TaskletCode.int_smin),
            "maximum": ("max", CMathFunction.fmax, TaskletCode.int_smax),
        }

        if ufunc_name not in opcode_table:
            raise NotImplementedError(f"{ufunc_name}.outer not supported")

        _, fp_opcode, int_opcode = opcode_table[ufunc_name]

        lhs_node, rhs_node = node.args[0], node.args[1]

        lhs_info = self.parse_arg(lhs_node)
        rhs_info = self.parse_arg(rhs_node)

        # Operands that do not resolve directly are evaluated and re-parsed
        # by the name visit() returns.
        if not lhs_info[0]:
            lhs_node = ast.Name(id=self.visit(lhs_node))
            lhs_info = self.parse_arg(lhs_node)

        if not rhs_info[0]:
            rhs_node = ast.Name(id=self.visit(rhs_node))
            rhs_info = self.parse_arg(rhs_node)

        name_a, subset_a, shape_a, indices_a = lhs_info
        name_b, subset_b, shape_b, indices_b = rhs_info

        if not (name_a and name_b):
            raise NotImplementedError("Could not resolve ufunc outer operands")

        def flat_extent(shape):
            # Product of all dimensions as an expression string.
            if not shape:
                return "1"
            extent = str(shape[0])
            for dim in shape[1:]:
                extent = f"({extent} * {str(dim)})"
            return extent

        m_expr = flat_extent(shape_a)
        n_expr = flat_extent(shape_b)

        dtype = promote_element_types(
            self._ev._element_type(name_a), self._ev._element_type(name_b)
        )

        integer_kinds = (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        )
        is_int = dtype.primitive_type in integer_kinds

        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)

        i_var = self.builder.find_new_name("_outer_i_")
        j_var = self.builder.find_new_name("_outer_j_")

        for counter in (i_var, j_var):
            if not self.builder.exists(counter):
                self.builder.add_container(counter, Scalar(PrimitiveType.Int64), False)
                self.container_table[counter] = Scalar(PrimitiveType.Int64)

        def compute_linear_index(name, subset, indices, loop_var):
            # Flat index expression for one operand.  Without explicit indices
            # (or without a known shape) the loop variable addresses the
            # operand directly.
            if not indices:
                return loop_var

            shapes = self.tensor_table[name].shape if name in self.tensor_table else []
            ndim = len(shapes)
            if ndim == 0:
                return loop_var

            # Row-major strides, built innermost-first and then reversed.
            running = "1"
            reversed_strides = []
            for axis in reversed(range(ndim)):
                reversed_strides.append(running)
                if axis == 0:
                    continue
                extent = shapes[axis]
                if running == "1":
                    running = str(extent)
                else:
                    running = f"({running} * {extent})"
            strides = reversed_strides[::-1]

            terms = []
            for position, idx in enumerate(indices):
                stride = strides[position] if position < len(strides) else "1"
                start = subset[position] if position < len(subset) else "0"

                if isinstance(idx, ast.Slice):
                    # Sliced axis: the loop variable walks along it.
                    offset = f"({start} + {loop_var})"
                    term = offset if stride == "1" else f"({offset} * {stride})"
                else:
                    # Fixed index: contributes a constant offset only.
                    term = start if stride == "1" else f"({start} * {stride})"
                terms.append(term)

            if not terms:
                return loop_var

            total = terms[0]
            for term in terms[1:]:
                total = f"({total} + {term})"
            return total

        self.builder.begin_for(i_var, "0", m_expr, "1")
        self.builder.begin_for(j_var, "0", n_expr, "1")

        block = self.builder.add_block()

        t_a = self.builder.add_access(block, name_a)
        t_b = self.builder.add_access(block, name_b)
        t_c = self.builder.add_access(block, tmp_name)

        # Floating-point minimum/maximum go through a math-library node; every
        # other combination is a plain two-input tasklet.
        if ufunc_name in ("minimum", "maximum") and not is_int:
            t_task = self.builder.add_cmath(block, fp_opcode, dtype.primitive_type)
        else:
            t_task = self.builder.add_tasklet(
                block,
                int_opcode if is_int else fp_opcode,
                ["_in1", "_in2"],
                ["_out"],
            )

        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)

        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)

        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)

        self.builder.end_for()
        self.builder.end_for()

        return tmp_name
2241

2242
    def _handle_numpy_reduce(self, node, func_name):
        """Handle np.sum, np.max, np.min, np.mean, np.std.

        Args:
            node: The ``ast.Call`` node of the reduction. The first positional
                argument is the array; ``axis`` may be given as the second
                positional argument or as a keyword, ``keepdims`` as a keyword.
            func_name: Name of the reduction; forwarded unchanged to
                ``builder.add_reduce_op``.

        Returns:
            Name of the temporary container that receives the result.

        Raises:
            ValueError: If the input is not registered in ``tensor_table`` or
                an axis lies outside ``[-ndim, ndim)``.
            NotImplementedError: If an axis cannot be resolved to an integer
                while translating.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        array_node = args[0]
        array_name = self.visit(array_node)

        if array_name not in self.tensor_table:
            raise ValueError(f"Reduction input must be an array, got {array_name}")

        # For mean and std, we need float64 input and output (NumPy behavior)
        # Cast input to float64 if needed
        if func_name in ("mean", "std"):
            float64_type = Scalar(PrimitiveType.Double)
            array_name = self._cast_array(array_name, float64_type)

        input_tensor = self.tensor_table[array_name]
        input_shape = input_tensor.shape
        ndim = len(input_shape)

        # A positional axis takes precedence over the keyword form.
        axis = None
        if len(args) > 1:
            axis = args[1]
        elif "axis" in keywords:
            axis = keywords["axis"]

        # Only a literal keepdims is honoured; any other expression leaves the
        # default (False) in place.
        keepdims = False
        keepdims_node = keywords.get("keepdims")
        if isinstance(keepdims_node, ast.Constant):
            keepdims = bool(keepdims_node.value)

        axes = self._resolve_reduce_axes(axis, ndim)

        # Reduced dimensions disappear, or collapse to "1" with keepdims.
        output_shape = []
        for i in range(ndim):
            if i in axes:
                if keepdims:
                    output_shape.append("1")
            else:
                output_shape.append(input_shape[i])

        dtype = self._ev._element_type(array_name)

        if not output_shape:
            # Full reduction: the result is a scalar container.
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
            self.tensor_table[tmp_name] = Tensor(dtype, [])
        else:
            output_strides = self._compute_strides(output_shape, "C")
            tmp_name = self._create_array_temp(
                output_shape, dtype, strides=output_strides
            )

        output_tensor = self.tensor_table[tmp_name]
        self.builder.add_reduce_op(
            func_name, array_name, input_tensor, tmp_name, output_tensor, axes, keepdims
        )

        return tmp_name

    def _resolve_reduce_axes(self, axis, ndim):
        """Turn the ``axis`` AST node of a reduction into a list of axis indices.

        ``None`` (argument omitted) and a literal ``None`` both select every
        axis. A tuple is resolved element by element, so negative entries such
        as ``(0, -1)`` are handled like a single negative axis.
        """
        if axis is None or (isinstance(axis, ast.Constant) and axis.value is None):
            return list(range(ndim))
        if isinstance(axis, ast.Tuple):
            return [self._resolve_reduce_axis(elt, ndim) for elt in axis.elts]
        return [self._resolve_reduce_axis(axis, ndim)]

    def _resolve_reduce_axis(self, axis, ndim):
        """Resolve one axis AST node to a non-negative index below ``ndim``."""
        if isinstance(axis, ast.Constant):
            val = axis.value
        elif (
            isinstance(axis, ast.UnaryOp)
            and isinstance(axis.op, ast.USub)
            and isinstance(axis.operand, ast.Constant)
        ):
            # The parser represents "-1" as UnaryOp(USub, Constant(1)).
            val = -axis.operand.value
        else:
            try:
                val = int(self.visit(axis))
            except Exception as exc:
                raise NotImplementedError("Dynamic axis not supported") from exc

        if val < 0:
            val += ndim
        if not 0 <= val < ndim:
            raise ValueError(
                f"Reduction axis {val} is out of bounds for array of dimension {ndim}"
            )
        return val
4✔
2335

2336
    # ========== Einsum Operations ==========
2337

2338
    def _parse_einsum_subscripts(self, subscripts, operand_shapes):
4✔
2339
        """Parse einsum subscripts string and return parsed components.
2340

2341
        Args:
2342
            subscripts: Einsum notation string, e.g., "ij,jk->ik" or "ij,jk"
2343
            operand_shapes: List of shapes for each operand
2344

2345
        Returns:
2346
            Tuple of (input_subscripts, output_subscripts, index_to_dim)
2347
            - input_subscripts: List of index strings per operand, e.g., ["ij", "jk"]
2348
            - output_subscripts: Output index string, e.g., "ik"
2349
            - index_to_dim: Dict mapping index char to dimension size string
2350
        """
2351
        # Remove whitespace
2352
        subscripts = subscripts.replace(" ", "")
4✔
2353

2354
        # Split into inputs and output
2355
        if "->" in subscripts:
4✔
2356
            input_part, output_subscripts = subscripts.split("->")
4✔
2357
        else:
2358
            input_part = subscripts
×
2359
            output_subscripts = None  # Implicit output
×
2360

2361
        # Split inputs by comma
2362
        input_subscripts = input_part.split(",")
4✔
2363

2364
        if len(input_subscripts) != len(operand_shapes):
4✔
2365
            raise ValueError(
×
2366
                f"Number of operands ({len(operand_shapes)}) does not match "
2367
                f"number of subscripts ({len(input_subscripts)})"
2368
            )
2369

2370
        # Map each index to its dimension size
2371
        index_to_dim = {}
4✔
2372
        for subscript, shape in zip(input_subscripts, operand_shapes):
4✔
2373
            if len(subscript) != len(shape):
4✔
2374
                raise ValueError(
×
2375
                    f"Subscript '{subscript}' has {len(subscript)} indices but "
2376
                    f"operand has {len(shape)} dimensions"
2377
                )
2378
            for idx_char, dim_size in zip(subscript, shape):
4✔
2379
                if idx_char in index_to_dim:
4✔
2380
                    # Validate dimensions match (at least symbolically)
2381
                    existing = index_to_dim[idx_char]
4✔
2382
                    if str(existing) != str(dim_size):
4✔
2383
                        # Could be symbolic - just warn or trust the user
2384
                        pass
×
2385
                else:
2386
                    index_to_dim[idx_char] = dim_size
4✔
2387

2388
        # Compute implicit output if not provided
2389
        if output_subscripts is None:
4✔
2390
            output_subscripts = self._compute_implicit_output(input_subscripts)
×
2391

2392
        return input_subscripts, output_subscripts, index_to_dim
4✔
2393

2394
    def _compute_implicit_output(self, input_subscripts):
4✔
2395
        """Compute implicit output indices (sorted indices appearing exactly once).
2396

2397
        Args:
2398
            input_subscripts: List of index strings, e.g., ["ij", "jk"]
2399

2400
        Returns:
2401
            Output index string with sorted non-contracted indices, e.g., "ik"
2402
        """
2403
        counts = {}
×
2404
        for subscript in input_subscripts:
×
2405
            for idx in subscript:
×
2406
                counts[idx] = counts.get(idx, 0) + 1
×
2407

2408
        # Output = sorted indices with count == 1 (non-contracted)
2409
        return "".join(sorted(idx for idx, cnt in counts.items() if cnt == 1))
×
2410

2411
    def _handle_numpy_einsum(self, node, func_name):
        """Handle np.einsum(subscripts, *operands) calls.

        The subscripts literal is parsed into per-operand index lists, the
        output shape and element type are derived from the operands, a
        temporary output container is created, and the einsum is handed to
        ``builder.add_einsum``.

        Returns:
            Name of the temporary container holding the result.

        Raises:
            ValueError: If fewer than two arguments are given or an operand is
                not registered in ``tensor_table``.
            NotImplementedError: If the subscripts are not a string literal.
        """
        if len(node.args) < 2:
            raise ValueError("np.einsum requires at least subscripts and one operand")

        spec_node = node.args[0]
        is_string_literal = isinstance(spec_node, ast.Constant) and isinstance(
            spec_node.value, str
        )
        if not is_string_literal:
            raise NotImplementedError("np.einsum subscripts must be a string literal")

        # Everything after the subscripts is an operand.
        operand_names = [self.visit(operand) for operand in node.args[1:]]
        for operand in operand_names:
            if operand not in self.tensor_table:
                raise ValueError(
                    f"Einsum operand '{operand}' not found in tensor_table"
                )

        operand_shapes = [self.tensor_table[operand].shape for operand in operand_names]
        input_subscripts, output_subscripts, index_to_dim = (
            self._parse_einsum_subscripts(spec_node.value, operand_shapes)
        )

        # One (indvar, init, bound) triple per distinct index, ordered by first
        # appearance across the input subscripts.
        ordered_indices = list(
            dict.fromkeys(idx for subscript in input_subscripts for idx in subscript)
        )
        dims = [(idx, "0", str(index_to_dim[idx])) for idx in ordered_indices]

        out_indices = list(output_subscripts)
        in_indices = [list(subscript) for subscript in input_subscripts]
        output_shape = [str(index_to_dim[idx]) for idx in output_subscripts]

        # Result element type: promotion folded over all operand types.
        element_types = [self._ev._element_type(operand) for operand in operand_names]
        dtype = element_types[0]
        for other in element_types[1:]:
            dtype = promote_element_types(dtype, other)

        if not output_shape:
            # Scalar output
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
            self.tensor_table[tmp_name] = Tensor(dtype, [])
        else:
            tmp_name = self._create_array_temp(
                output_shape,
                dtype,
                strides=self._compute_strides(output_shape, "C"),
                zero_init=True,
            )

        input_types = [self.tensor_table[operand] for operand in operand_names]
        output_type = self.tensor_table[tmp_name]

        self.builder.add_einsum(
            operand_names,
            tmp_name,
            dims,
            out_indices,
            in_indices,
            input_types,
            output_type,
        )

        return tmp_name
4✔
2501

2502
    def handle_numpy_astype(self, node, array_name):
        """Handle numpy array.astype(dtype) method calls.

        The target dtype may be passed positionally (``a.astype(np.float32)``)
        or by keyword (``a.astype(dtype=np.float32)``). The result is always a
        fresh temporary array; its layout follows the input when the input is
        C- or F-contiguous and is C-order otherwise.

        Args:
            node: The ``ast.Call`` node of the ``astype`` call.
            array_name: Name of the source array.

        Returns:
            Name of the temporary array holding the converted values.

        Raises:
            ValueError: If no dtype is given or the array is not registered in
                ``tensor_table``.
            NotImplementedError: If ``copy=False`` is requested.
        """
        keywords = {kw.arg: kw.value for kw in node.keywords}

        if node.args:
            dtype_arg = node.args[0]
        elif "dtype" in keywords:
            dtype_arg = keywords["dtype"]
        else:
            raise ValueError("astype requires at least one argument (dtype)")

        # Check for copy=False which we don't support (we always copy)
        copy_node = keywords.get("copy")
        if isinstance(copy_node, ast.Constant) and copy_node.value is False:
            raise NotImplementedError("astype with copy=False is not supported")

        target_dtype = element_type_from_ast_node(dtype_arg, self.container_table)

        if array_name not in self.tensor_table:
            raise ValueError(f"Array {array_name} not found in tensor_table")

        input_tensor = self.tensor_table[array_name]
        input_shape = input_tensor.shape
        input_strides = getattr(input_tensor, "strides", None)

        # Shared with the other array helpers: keep F-order for F-contiguous
        # input, use C-order for everything else.
        output_strides = self._get_contiguous_output_strides(
            input_shape, input_strides
        )
        tmp_name = self._create_array_temp(
            input_shape, target_dtype, strides=output_strides
        )

        output_tensor = self.tensor_table[tmp_name]
        self.builder.add_cast_op(array_name, input_tensor, tmp_name, output_tensor)

        return tmp_name
4✔
2540

2541
    def handle_numpy_copy(self, node, array_name):
        """Handle numpy array.copy() method calls.

        Creates a temporary array with the input's shape and element type and
        fills it through ``builder.add_cast_op`` (used here as an assignment).
        F-contiguous inputs of rank >= 2 keep F-order; all others get C-order.

        Returns:
            Name of the temporary array holding the copy.

        Raises:
            ValueError: If the array is not registered in ``tensor_table``.
        """
        if array_name not in self.tensor_table:
            raise ValueError(f"Array {array_name} not found in tensor_table")

        source = self.tensor_table[array_name]
        shape = source.shape
        source_strides = getattr(source, "strides", None)

        # Element type comes from the registered pointer type when available;
        # otherwise double is assumed.
        element_type = Scalar(PrimitiveType.Double)
        if array_name in self.container_table:
            registered = self.container_table[array_name]
            if isinstance(registered, Pointer) and registered.has_pointee_type():
                element_type = registered.pointee_type

        layout = "C"
        if source_strides and len(source_strides) >= 2 and len(shape) >= 2:
            if source_strides == self._compute_strides(shape, "F"):
                layout = "F"

        tmp_name = self._create_array_temp(
            shape, element_type, strides=self._compute_strides(shape, layout)
        )

        # Workaround: "assign-op"
        self.builder.add_cast_op(
            array_name, source, tmp_name, self.tensor_table[tmp_name]
        )

        return tmp_name
4✔
2573

2574
    def _get_contiguous_output_strides(self, shape, input_strides):
4✔
2575
        """Get contiguous output strides, preserving C or F order if input is contiguous.
2576

2577
        For non-contiguous input strides (e.g., from slices), returns C-order strides.
2578
        This ensures output allocation matches the stride pattern.
2579

2580
        Args:
2581
            shape: Output shape
2582
            input_strides: Strides from input tensor
2583

2584
        Returns:
2585
            List of stride expressions for a contiguous output array
2586
        """
2587
        if not shape or not input_strides:
4✔
2588
            return self._compute_strides(shape, "C")
4✔
2589

2590
        # Preserve order if contiguous, otherwise default to C-order
2591
        c_strides = self._compute_strides(shape, "C")
4✔
2592
        if input_strides == c_strides:
4✔
2593
            return c_strides
4✔
2594
        f_strides = self._compute_strides(shape, "F")
4✔
2595
        if input_strides == f_strides:
4✔
2596
            return f_strides
×
2597
        return c_strides
4✔
2598

2599
    def _compute_strides(self, shape, order="C"):
4✔
2600
        """Compute strides for a given shape and memory order.
2601

2602
        Args:
2603
            shape: List of dimension sizes
2604
            order: "C" for row-major (default), "F" for column-major
2605

2606
        Returns:
2607
            List of stride expressions as strings
2608
        """
2609
        if not shape:
4✔
2610
            return []
4✔
2611

2612
        ndim = len(shape)
4✔
2613
        strides = []
4✔
2614

2615
        if order == "F":
4✔
2616
            # Column-major (Fortran order): stride[i] = product of shape[:i]
2617
            for dim_idx in range(ndim):
4✔
2618
                if dim_idx == 0:
4✔
2619
                    strides.append("1")
4✔
2620
                else:
2621
                    # Wrap each shape in parens to ensure correct precedence
2622
                    prefix_shapes = [f"({s})" for s in shape[:dim_idx]]
4✔
2623
                    if len(prefix_shapes) == 1:
4✔
2624
                        strides.append(prefix_shapes[0])
4✔
2625
                    else:
2626
                        strides.append("(" + " * ".join(prefix_shapes) + ")")
×
2627
        else:
2628
            # Row-major (C order): stride[i] = product of shape[i+1:]
2629
            for dim_idx in range(ndim):
4✔
2630
                if dim_idx == ndim - 1:
4✔
2631
                    strides.append("1")
4✔
2632
                else:
2633
                    # Wrap each shape in parens to ensure correct precedence
2634
                    suffix_shapes = [f"({s})" for s in shape[dim_idx + 1 :]]
4✔
2635
                    if len(suffix_shapes) == 1:
4✔
2636
                        strides.append(suffix_shapes[0])
4✔
2637
                    else:
2638
                        strides.append("(" + " * ".join(suffix_shapes) + ")")
4✔
2639

2640
        return strides
4✔
2641

2642
    def _is_contiguous(self, shape, strides):
4✔
2643
        """Check if strides represent a contiguous (C or F order) layout."""
2644
        if not shape or not strides:
4✔
2645
            return True
×
2646

2647
        def normalize(s):
4✔
2648
            # Normalize stride expression by removing spaces and outer parens
2649
            s = s.replace(" ", "")
4✔
2650
            while s.startswith("(") and s.endswith(")"):
4✔
2651
                # Only strip if balanced parens
2652
                inner = s[1:-1]
4✔
2653
                depth = 0
4✔
2654
                balanced = True
4✔
2655
                for c in inner:
4✔
2656
                    if c == "(":
4✔
2657
                        depth += 1
×
2658
                    elif c == ")":
4✔
2659
                        depth -= 1
×
2660
                        if depth < 0:
×
2661
                            balanced = False
×
2662
                            break
×
2663
                if balanced and depth == 0:
4✔
2664
                    s = inner
4✔
2665
                else:
2666
                    break
×
2667
            return s
4✔
2668

2669
        c_strides = self._compute_strides(shape, "C")
4✔
2670
        if all(
4✔
2671
            normalize(str(a)) == normalize(str(b)) for a, b in zip(strides, c_strides)
2672
        ):
2673
            return True
4✔
NEW
2674
        f_strides = self._compute_strides(shape, "F")
×
NEW
2675
        return all(
×
2676
            normalize(str(a)) == normalize(str(b)) for a, b in zip(strides, f_strides)
2677
        )
2678

2679
    def _create_array_temp(
        self,
        shape,
        dtype,
        zero_init=False,
        ones_init=False,
        shapes_runtime=None,
        strides=None,
    ):
        """Create a temporary array (or a scalar for an empty shape).

        Args:
            shape: List of dimension sizes; empty means a scalar temporary.
            dtype: Element type of the temporary.
            zero_init: Initialise the temporary with zeros.
            ones_init: Initialise the temporary with ones (checked only when
                ``zero_init`` is false).
            shapes_runtime: Optional value stored in ``shapes_runtime_info``
                under the temporary's name.
            strides: Stride expressions; C-order strides are computed when
                omitted.

        Returns:
            Name of the newly registered temporary.
        """
        tmp_name = f"_tmp_{self._get_unique_id()}"

        # Handle 0-dimensional arrays as scalars
        if not shape:
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
            self.tensor_table[tmp_name] = Tensor(dtype, [])

            if zero_init or ones_init:
                is_double = dtype.primitive_type == PrimitiveType.Double
                if zero_init:
                    literal = "0.0" if is_double else "0"
                else:
                    literal = "1.0" if is_double else "1"
                self.builder.add_assignment(tmp_name, literal)

            return tmp_name

        # Element count: every dimension is parenthesised so that expressions
        # such as "-2 + _s0" are parsed as a single factor.
        element_count = "1"
        for extent in shape:
            element_count = f"({element_count} * ({extent}))"

        total_size = f"({element_count} * {self.builder.get_sizeof(dtype)})"

        if strides is None:
            strides = self._compute_strides(shape, "C")

        # Register the pointer container and its tensor description.
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type
        tensor_entry = Tensor(dtype, shape, strides, "0")
        if shapes_runtime is not None:
            self.shapes_runtime_info[tmp_name] = shapes_runtime
        self.tensor_table[tmp_name] = tensor_entry

        # Try to hoist allocation to function entry. Ones-initialised arrays
        # are never offered to the memory handler and are always emitted in
        # place.
        if zero_init:
            init_type = ManagedMemoryHandler.INIT_ZERO
        else:
            init_type = ManagedMemoryHandler.INIT_NONE

        hoisted = not ones_init and self.memory_handler.allocate(
            tmp_name, ptr_type, total_size, init=init_type
        )
        if not hoisted:
            # Emit allocation immediately (size depends on loop variables or needs loop init)
            self._emit_malloc(
                tmp_name,
                total_size,
                ptr_type,
                zero_init,
                ones_init,
                element_count,
                dtype,
            )

        return tmp_name
4✔
2749

2750
    def _emit_malloc(
        self, tmp_name, total_size, ptr_type, zero_init, ones_init, size_str, dtype
    ):
        """Emit malloc and optional initialization for a temporary array.

        Args:
            tmp_name: Name of the pointer container receiving the allocation.
            total_size: Allocation size expression passed to malloc/memset.
            ptr_type: Pointer type of ``tmp_name``.
            zero_init: Follow the malloc with a memset to zero.
            ones_init: Otherwise, fill the array with ones in a loop.
            size_str: Element-count expression used as the fill-loop bound.
            dtype: Element type of the array.
        """
        alloc_block = self.builder.add_block()
        malloc_node = self.builder.add_malloc(alloc_block, total_size)
        alloc_target = self.builder.add_access(alloc_block, tmp_name)
        self.builder.add_memlet(
            alloc_block, malloc_node, "_ret", alloc_target, "void", "", ptr_type
        )

        if zero_init:
            memset_block = self.builder.add_block()
            memset_node = self.builder.add_memset(memset_block, "0", total_size)
            memset_source = self.builder.add_access(memset_block, tmp_name)
            self.builder.add_memlet(
                memset_block, memset_source, "void", memset_node, "_ptr", "", ptr_type
            )
            return

        if not ones_init:
            return

        # Ones fill: one assign tasklet per element inside a counting loop.
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", size_str, "1")

        integer_types = [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]
        one_literal = "1" if dtype.primitive_type in integer_types else "1.0"

        fill_block = self.builder.add_block()
        one_node = self.builder.add_constant(fill_block, one_literal, dtype)
        array_node = self.builder.add_access(fill_block, tmp_name)

        assign_node = self.builder.add_tasklet(
            fill_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            fill_block, one_node, "void", assign_node, "_in", "", dtype
        )
        self.builder.add_memlet(
            fill_block, assign_node, "_out", array_node, "void", loop_var
        )

        self.builder.end_for()
4✔
2802

2803
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
2804
        """Compute linear index from multi-dimensional indices.
2805

2806
        Uses strides from tensor_table if available (supporting F-order arrays),
2807
        otherwise falls back to computing strides assuming C-order.
2808
        """
2809
        if ndim == 0:
×
2810
            return "0"
×
2811

2812
        # Try to get strides from tensor_table
2813
        strides = None
×
2814
        if array_name in self.tensor_table:
×
2815
            tensor_info = self.tensor_table[array_name]
×
2816
            if hasattr(tensor_info, "strides") and tensor_info.strides:
×
2817
                strides = tensor_info.strides
×
2818

2819
        if strides and len(strides) == ndim:
×
2820
            # Use explicit strides from tensor_table
2821
            linear_index = ""
×
2822
            for i in range(ndim):
×
2823
                stride = strides[i]
×
2824
                if stride == "1":
×
2825
                    term = str(indices[i])
×
2826
                else:
2827
                    term = f"(({indices[i]}) * ({stride}))"
×
2828

2829
                if i == 0:
×
2830
                    linear_index = term
×
2831
                else:
2832
                    linear_index = f"({linear_index} + {term})"
×
2833
            return linear_index
×
2834
        else:
2835
            # Fall back to C-order (row-major) stride computation
2836
            linear_index = ""
×
2837
            for i in range(ndim):
×
2838
                term = str(indices[i])
×
2839
                for j in range(i + 1, ndim):
×
2840
                    shape_val = (
×
2841
                        shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
2842
                    )
2843
                    term = f"(({term}) * {shape_val})"
×
2844

2845
                if i == 0:
×
2846
                    linear_index = term
×
2847
                else:
2848
                    linear_index = f"({linear_index} + {term})"
×
2849

2850
            return linear_index
×
2851

2852
    def _compute_broadcast_shape(self, shape_a, shape_b):
4✔
2853
        """Compute the broadcast output shape following NumPy broadcasting rules."""
2854
        if not shape_a:
4✔
2855
            return shape_b
4✔
2856
        if not shape_b:
4✔
2857
            return shape_a
4✔
2858

2859
        max_ndim = max(len(shape_a), len(shape_b))
4✔
2860
        padded_a = ["1"] * (max_ndim - len(shape_a)) + [str(s) for s in shape_a]
4✔
2861
        padded_b = ["1"] * (max_ndim - len(shape_b)) + [str(s) for s in shape_b]
4✔
2862

2863
        result = []
4✔
2864
        for a, b in zip(padded_a, padded_b):
4✔
2865
            if a == "1":
4✔
2866
                result.append(b)
4✔
2867
            elif b == "1":
4✔
2868
                result.append(a)
4✔
2869
            elif a == b:
4✔
2870
                result.append(a)
4✔
2871
            else:
2872
                result.append(a)
4✔
2873

2874
        return result
4✔
2875

2876
    def _needs_broadcast(self, input_shape, output_shape):
4✔
2877
        """Check if input shape needs broadcasting to match output shape."""
2878
        if len(input_shape) != len(output_shape):
4✔
2879
            return True
4✔
2880
        for in_dim, out_dim in zip(input_shape, output_shape):
4✔
2881
            if str(in_dim) != str(out_dim):
4✔
2882
                return True
4✔
2883
        return False
4✔
2884

2885
    def _compute_broadcast_strides(self, input_shape, input_strides, output_shape):
4✔
2886
        """Compute strides for broadcasting input to output shape.
2887

2888
        For broadcast dimensions (size 1), stride is set to 0 so the same
2889
        value is repeated. This enables stride-based broadcasting without copying.
2890
        """
2891
        # Pad input shape and strides on the left to match output ndim
2892
        ndim_diff = len(output_shape) - len(input_shape)
4✔
2893
        padded_shape = ["1"] * ndim_diff + [str(s) for s in input_shape]
4✔
2894
        padded_strides = ["0"] * ndim_diff + [str(s) for s in input_strides]
4✔
2895

2896
        broadcast_strides = []
4✔
2897
        for in_dim, in_stride, out_dim in zip(
4✔
2898
            padded_shape, padded_strides, output_shape
2899
        ):
2900
            # Only use stride 0 when input dimension is exactly "1" (broadcast case).
2901
            # For other cases (including symbolic dimensions that may be equal at runtime),
2902
            # keep the original stride.
2903
            if str(in_dim) == "1" and str(out_dim) != "1":
4✔
2904
                # Broadcast dimension: use stride 0
2905
                broadcast_strides.append("0")
4✔
2906
            else:
2907
                # Non-broadcast dimension or potentially equal symbolic dimensions:
2908
                # keep original stride
2909
                broadcast_strides.append(in_stride)
4✔
2910

2911
        return broadcast_strides
4✔
2912

2913
    def _shape_to_runtime_expr(self, shape_node):
4✔
2914
        """Convert a shape expression AST node to a runtime-evaluable string."""
2915
        if isinstance(shape_node, ast.Constant):
4✔
2916
            return str(shape_node.value)
4✔
2917
        elif isinstance(shape_node, ast.Name):
4✔
2918
            return shape_node.id
4✔
2919
        elif isinstance(shape_node, ast.BinOp):
4✔
2920
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2921
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2922
            op = self.visit(shape_node.op)
4✔
2923
            return f"({left} {op} {right})"
4✔
2924
        elif isinstance(shape_node, ast.UnaryOp):
4✔
2925
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
2926
            if isinstance(shape_node.op, ast.USub):
×
2927
                return f"(-{operand})"
×
2928
            elif isinstance(shape_node.op, ast.UAdd):
×
2929
                return operand
×
2930
            else:
2931
                return self.visit(shape_node)
×
2932
        elif isinstance(shape_node, ast.Subscript):
4✔
2933
            val = shape_node.value
4✔
2934
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2935
                if isinstance(val.value, ast.Name):
4✔
2936
                    arr_name = val.value.id
4✔
2937
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2938
                        idx = shape_node.slice.value
4✔
2939
                        if arr_name in self.tensor_table:
4✔
2940
                            shapes = self.tensor_table[arr_name].shape
4✔
2941
                            if idx < len(shapes):
4✔
2942
                                return shapes[idx]
4✔
2943
                        return f"{arr_name}.shape[{idx}]"
×
2944
            return self.visit(shape_node)
×
2945
        elif isinstance(shape_node, ast.Tuple):
×
2946
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2947
        elif isinstance(shape_node, ast.List):
×
2948
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2949
        else:
2950
            return self.visit(shape_node)
×
2951

2952
    # ========== Type Casting Helpers ==========
2953

2954
    def _cast_scalar(self, name, target_type):
        """
        Convert a scalar to another element type via an assign tasklet.

        A new scalar container of the target type is created and the source
        value is assigned into it. Per the original note, the backend derives
        the concrete conversion (fpext, sitofp, etc.) from the type mismatch
        between input and output.

        Args:
            name: Name of the scalar to cast
            target_type: Target element type (Scalar)

        Returns:
            Name of the casted scalar (or original if no cast needed)
        """
        source_type = self._ev._element_type(name)
        if source_type.primitive_type == target_type.primitive_type:
            # Already the requested primitive type: nothing to emit.
            return name

        cast_name = f"_cast_{self._get_unique_id()}"
        self.builder.add_container(cast_name, target_type, False)
        self.container_table[cast_name] = target_type
        self.tensor_table[cast_name] = Tensor(target_type, [])

        cast_block = self.builder.add_block()
        source_node, source_subset = self._add_read(cast_block, name)
        target_node = self.builder.add_access(cast_block, cast_name)
        assign_node = self.builder.add_tasklet(
            cast_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            cast_block, source_node, "void", assign_node, "_in", source_subset
        )
        self.builder.add_memlet(
            cast_block, assign_node, "_out", target_node, "void", ""
        )

        return cast_name
4✔
2985

2986
    def _cast_array(self, name, target_type):
4✔
2987
        """
2988
        Cast an array to a different element type using the CastNode library node.
2989

2990
        This is an elementwise cast operation that creates a new array.
2991
        Reuses the same infrastructure as handle_numpy_astype().
2992

2993
        Args:
2994
            name: Name of the array to cast
2995
            target_type: Target element type (Scalar)
2996

2997
        Returns:
2998
            Name of the casted array (or original if no cast needed)
2999
        """
3000
        current_type = self._ev._element_type(name)
4✔
3001
        if current_type.primitive_type == target_type.primitive_type:
4✔
3002
            return name
4✔
3003

3004
        src_tensor = self.tensor_table[name]
4✔
3005

3006
        # Create output array with same shape but new dtype
3007
        # Preserve strides order (C or F contiguous)
3008
        output_strides = self._get_contiguous_output_strides(
4✔
3009
            src_tensor.shape, src_tensor.strides
3010
        )
3011
        tmp_name = self._create_array_temp(
4✔
3012
            src_tensor.shape, target_type, strides=output_strides
3013
        )
3014
        tmp_tensor = self.tensor_table[tmp_name]
4✔
3015

3016
        # Use existing cast infrastructure (CastNode)
3017
        self.builder.add_cast_op(name, src_tensor, tmp_name, tmp_tensor)
4✔
3018

3019
        return tmp_name
4✔
3020

3021
    def _cast_to_type(self, name, target_type):
4✔
3022
        """
3023
        Cast an operand (scalar or array) to the target type.
3024

3025
        Dispatches to _cast_scalar or _cast_array based on whether
3026
        the operand is in tensor_table (includes 0-d arrays).
3027

3028
        Args:
3029
            name: Name of the operand to cast
3030
            target_type: Target element type (Scalar)
3031

3032
        Returns:
3033
            Name of the casted operand (or original if no cast needed)
3034
        """
3035
        if name in self.tensor_table:
4✔
3036
            # In tensor_table means it's an array (including 0-d arrays)
3037
            return self._cast_array(name, target_type)
4✔
3038
        else:
3039
            # Not in tensor_table means it's a literal or Python scalar
3040
            return self._cast_scalar(name, target_type)
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc