• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23514597273

24 Mar 2026 10:08PM UTC coverage: 64.344% (+0.05%) from 64.295%
23514597273

Pull #611

github

web-flow
Merge 89f7b18f0 into e56781552
Pull Request #611: Updates rules to handle casts between numpy arrays and scalars

79 of 85 new or added lines in 3 files covered. (92.94%)

2 existing lines in 1 file now uncovered.

26715 of 41519 relevant lines covered (64.34%)

405.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.51
/python/docc/python/functions/numpy.py
1
import ast
4✔
2
from docc.sdfg import (
4✔
3
    Scalar,
4
    PrimitiveType,
5
    Pointer,
6
    DebugInfo,
7
    TaskletCode,
8
    CMathFunction,
9
    Tensor,
10
)
11
from docc.python.types import (
4✔
12
    element_type_from_ast_node,
13
    promote_element_types,
14
    numpy_promote_types,
15
)
16
from docc.python.ast_utils import get_debug_info
4✔
17
from docc.python.memory import ManagedMemoryHandler
4✔
18

19

20
class NumPyHandler:
4✔
21
    """
22
    Unified handler for NumPy operations including:
23
    - Array creation (empty, zeros, ones, eye, etc.)
24
    - Elementwise operations (add, subtract, multiply, etc.)
25
    - Linear algebra (matmul, dot, outer, gemm)
26
    - Array manipulation (transpose)
27
    - Reductions (sum, max, min, mean, std)
28
    """
29

30
    def __init__(self, expression_visitor):
        """Bind to the parent expression visitor and register the NumPy
        functions this handler can lower.

        Args:
            expression_visitor: Parent visitor providing the tensor/container
                tables, the SDFG builder, and recursive AST evaluation.
        """
        self._ev = expression_visitor
        # Counter reserved for generating unique names within this handler.
        self._unique_counter = 0
        # Dispatch table: NumPy function name -> lowering method.
        self.function_handlers = {
            # Array creation
            "empty": self._handle_numpy_alloc,
            "empty_like": self._handle_numpy_empty_like,
            "zeros": self._handle_numpy_alloc,
            "zeros_like": self._handle_numpy_zeros_like,
            "ones": self._handle_numpy_alloc,
            "ndarray": self._handle_numpy_alloc,
            "eye": self._handle_numpy_eye,
            # Elementwise binary operations
            "add": self._handle_numpy_binary_op,
            "subtract": self._handle_numpy_binary_op,
            "multiply": self._handle_numpy_binary_op,
            "divide": self._handle_numpy_binary_op,
            "power": self._handle_numpy_binary_op,
            # Elementwise unary operations
            "exp": self._handle_numpy_unary_op,
            "abs": self._handle_numpy_unary_op,
            "absolute": self._handle_numpy_unary_op,
            "sqrt": self._handle_numpy_unary_op,
            "tanh": self._handle_numpy_unary_op,
            # Reductions
            "sum": self._handle_numpy_reduce,
            "max": self._handle_numpy_reduce,
            "min": self._handle_numpy_reduce,
            "mean": self._handle_numpy_reduce,
            "std": self._handle_numpy_reduce,
            # Linear algebra
            "matmul": self._handle_numpy_matmul,
            "dot": self._handle_numpy_matmul,
            "matvec": self._handle_numpy_matmul,
            "outer": self._handle_numpy_outer,
            # Elementwise selection / clipping
            "minimum": self._handle_numpy_binary_op,
            "maximum": self._handle_numpy_binary_op,
            "where": self._handle_numpy_where,
            "clip": self._handle_numpy_clip,
            # Array manipulation
            "transpose": self._handle_numpy_transpose,
            "flip": self._handle_numpy_flip,
            "fliplr": self._handle_numpy_fliplr,
            "flipud": self._handle_numpy_flipud,
            "reshape": self._handle_numpy_reshape,
        }
70

71
    # Expose parent properties for convenience
72
    @property
4✔
73
    def tensor_table(self):
4✔
74
        return self._ev.tensor_table
4✔
75

76
    @property
4✔
77
    def builder(self):
4✔
78
        return self._ev.builder
4✔
79

80
    @property
4✔
81
    def container_table(self):
4✔
82
        return self._ev.container_table
4✔
83

84
    @property
4✔
85
    def globals_dict(self):
4✔
86
        return self._ev.globals_dict
×
87

88
    @property
4✔
89
    def shapes_runtime_info(self):
4✔
90
        return self._ev.shapes_runtime_info
4✔
91

92
    @property
4✔
93
    def memory_handler(self):
4✔
94
        """Access the memory handler owned by the parser."""
95
        return self._ev.memory_handler
4✔
96

97
    def _get_unique_id(self):
4✔
98
        return self._ev._get_unique_id()
4✔
99

100
    def _add_read(self, block, expr_str, debug_info=None):
4✔
101
        return self._ev._add_read(block, expr_str, debug_info)
4✔
102

103
    def _is_int(self, operand):
4✔
104
        return self._ev._is_int(operand)
4✔
105

106
    def visit(self, node):
4✔
107
        return self._ev.visit(node)
4✔
108

109
    # ========== Linear Algebra Helper Methods (from LinearAlgebraHandler) ==========
110

111
    def parse_arg(self, node):
        """Parse an array argument, returning (name, start_indices, slice_shape, indices).

        Accepts a bare array name (``A``) or a subscripted access (``A[i:j, k]``).

        Returns None for 0-d arrays since they are scalars, not valid array operands
        for linear algebra operations.

        Returns:
            name: container name, or None when the node is not a usable array ref.
            start_indices: symbolic start offset per dimension (strings).
            slice_shape: symbolic extent per *sliced* dimension (strings).
            indices: the raw AST index expressions.
        """
        if isinstance(node, ast.Name):
            if node.id in self.tensor_table:
                shape = self.tensor_table[node.id].shape
                # Reject 0-d arrays (scalars) - not valid for linalg ops
                if len(shape) == 0:
                    return None, None, None, None
                # Whole-array reference: no offsets, full shape.
                return node.id, [], shape, []
        elif isinstance(node, ast.Subscript):
            if isinstance(node.value, ast.Name) and node.value.id in self.tensor_table:
                name = node.value.id
                # Normalize the subscript into one index expression per dimension.
                indices = []
                if isinstance(node.slice, ast.Tuple):
                    indices = node.slice.elts
                else:
                    indices = [node.slice]

                start_indices = []
                slice_shape = []

                for i, idx in enumerate(indices):
                    if isinstance(idx, ast.Slice):
                        # Slice: start defaults to 0, stop to the dimension size.
                        # NOTE(review): idx.step is ignored — assumes unit step; confirm callers.
                        start = "0"
                        if idx.lower:
                            start = self._ev.visit(idx.lower)
                        start_indices.append(start)

                        shapes = self.tensor_table[name].shape
                        # Fall back to a symbolic shape name for dims beyond the known shape.
                        dim_size = (
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
                        )
                        stop = dim_size
                        if idx.upper:
                            stop = self._ev.visit(idx.upper)

                        # Extent of the sliced dimension as a symbolic expression.
                        size = f"({stop} - {start})"
                        slice_shape.append(size)
                    else:
                        if isinstance(idx, ast.Name) and idx.id in self.tensor_table:
                            # This is an array index (gather operation)
                            return None, None, None, None
                        # Point index: contributes an offset but no extent.
                        val = self._ev.visit(idx)
                        start_indices.append(val)

                return name, start_indices, slice_shape, indices

        return None, None, None, None
163

164
    def flatten_subset(self, name, start_indices):
4✔
165
        """Convert multi-dimensional start indices to a flattened linear offset."""
166
        if not start_indices:
4✔
167
            return []
4✔
168
        info = self.tensor_table[name]
4✔
169
        shapes = info.shape
4✔
170
        ndim = len(info.shape)
4✔
171

172
        if len(start_indices) != ndim:
4✔
173
            return start_indices
4✔
174

175
        strides = []
4✔
176
        current_stride = "1"
4✔
177
        strides.append(current_stride)
4✔
178
        for i in range(ndim - 1, 0, -1):
4✔
179
            dim_size = shapes[i]
4✔
180
            if current_stride == "1":
4✔
181
                current_stride = str(dim_size)
4✔
182
            else:
183
                current_stride = f"({current_stride} * {dim_size})"
4✔
184
            strides.append(current_stride)
4✔
185
        strides = list(reversed(strides))
4✔
186

187
        offset = "0"
4✔
188
        for i in range(ndim):
4✔
189
            idx = start_indices[i]
4✔
190
            stride = strides[i]
4✔
191
            term = f"({idx} * {stride})" if stride != "1" else idx
4✔
192
            if offset == "0":
4✔
193
                offset = term
4✔
194
            else:
195
                offset = f"({offset} + {term})"
4✔
196

197
        return [offset]
4✔
198

199
    def is_gemm(self, node):
4✔
200
        """Check if a node represents a GEMM operation (matrix multiplication)."""
201
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
202
            return True
4✔
203
        if isinstance(node, ast.Call):
4✔
204
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
4✔
205
                return True
×
206
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
4✔
207
                return True
×
208
            if isinstance(node.func, ast.Attribute) and node.func.attr == "matmul":
4✔
209
                return True
×
210
            if isinstance(node.func, ast.Name) and node.func.id == "matmul":
4✔
211
                return True
×
212
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
213
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
214
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
4✔
215
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
216
        return False
4✔
217

218
    def _is_stride_1(self, name, indices):
4✔
219
        """Check if the sliced dimension has stride 1 (contiguous access)."""
220
        if name not in self.tensor_table:
4✔
221
            return True
×
222
        info = self.tensor_table[name]
4✔
223
        ndim = len(info.shape)
4✔
224

225
        if not indices:
4✔
226
            return True
4✔
227

228
        sliced_dim = -1
×
229
        for i, idx in enumerate(indices):
×
230
            if isinstance(idx, ast.Slice):
×
231
                sliced_dim = i
×
232
                break
×
233

234
        if sliced_dim == -1:
×
235
            if len(indices) < ndim:
×
236
                sliced_dim = ndim - 1
×
237
            else:
238
                return True
×
239

240
        return sliced_dim == ndim - 1
×
241

242
    def _is_target(self, node, target_name):
4✔
243
        """Check if node refers to the target."""
244
        if isinstance(target_name, ast.AST):
4✔
245
            return self._ev.visit(node) == self._ev.visit(target_name)
4✔
246

247
        if isinstance(node, ast.Name) and node.id == target_name:
4✔
248
            return True
×
249
        if isinstance(node, ast.Subscript):
4✔
250
            if isinstance(node.value, ast.Name) and node.value.id == target_name:
4✔
251
                return True
4✔
252
        return False
4✔
253

254
    def _is_dot_call(self, node):
4✔
255
        """Check if node is a dot product call."""
256
        if isinstance(node, ast.Call):
4✔
257
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
×
258
                return True
×
259
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
×
260
                return True
×
261
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
262
            return True
4✔
263
        return False
4✔
264

265
    def handle_gemm(self, target, value_node):
        """Handle GEMM (General Matrix Multiply) operations: C = alpha * A @ B + beta * C.

        Pattern-matches *value_node* against (optionally scaled) matmul
        expressions and emits a single GEMM node via the builder.

        Args:
            target: assignment target (name string, or AST Name/Subscript).
            value_node: right-hand-side AST expression.

        Returns:
            True if a GEMM was emitted, False if the expression does not match.
        """
        # Resolve the output container name and, for subscripted targets,
        # its flattened start offset.
        target_name = None
        target_subset = []

        if isinstance(target, str):
            target_name = target
        elif isinstance(target, ast.Name):
            target_name = target.id
        elif isinstance(target, ast.Subscript):
            if isinstance(target.value, ast.Name):
                res = self.parse_arg(target)
                if res[0]:
                    target_name = res[0]
                    target_subset = self.flatten_subset(target_name, res[1])
                else:
                    target_name = target.value.id

        if not target_name or target_name not in self.tensor_table:
            return False

        alpha = "1.0"
        beta = "0.0"
        A = None
        B = None

        def extract_factor(node):
            # Split `scalar * operand` into (operand, scalar-string);
            # anything else is returned with factor "1.0".
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
                if self.is_gemm(node.left):
                    return node.left, self._ev.visit(node.right)
                if self.is_gemm(node.right):
                    return node.right, self._ev.visit(node.left)

                res = self.parse_arg(node.left)
                if res[0]:
                    return node.left, self._ev.visit(node.right)
                res = self.parse_arg(node.right)
                if res[0]:
                    return node.right, self._ev.visit(node.left)
            return node, "1.0"

        def parse_term(node):
            # Extract (A, B, alpha) from one additive term, or (None, None, None).
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
                l, l_f = extract_factor(node.left)
                r, r_f = extract_factor(node.right)
                # Combine scalar factors pulled off either matmul operand.
                f = "1.0"
                if l_f != "1.0":
                    f = l_f
                if r_f != "1.0":
                    if f == "1.0":
                        f = r_f
                    else:
                        f = f"({f} * {r_f})"
                return l, r, f

            # dot(...)/matmul(...) call forms (attribute or bare name).
            if isinstance(node, ast.Call):
                is_gemm_call = False
                if isinstance(node.func, ast.Attribute) and node.func.attr in [
                    "dot",
                    "matmul",
                ]:
                    is_gemm_call = True
                if isinstance(node.func, ast.Name) and node.func.id in [
                    "dot",
                    "matmul",
                ]:
                    is_gemm_call = True

                if is_gemm_call and len(node.args) == 2:
                    return node.args[0], node.args[1], "1.0"

            # alpha * (term) with the GEMM nested on either side.
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
                l, r, a = parse_term(node.left)
                if l:
                    return l, r, self._ev.visit(node.right)
                l, r, a = parse_term(node.right)
                if l:
                    return l, r, self._ev.visit(node.left)

            return None, None, None

        # Additive form: term + beta*C (or beta*C + term) — pull beta from the
        # operand that references the target.
        if isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
            l, r, a = parse_term(value_node.left)
            if l:
                A = l
                B = r
                alpha = a
                if isinstance(value_node.right, ast.BinOp) and isinstance(
                    value_node.right.op, ast.Mult
                ):
                    if self._is_target(value_node.right.left, target_name):
                        beta = self._ev.visit(value_node.right.right)
                    elif self._is_target(value_node.right.right, target_name):
                        beta = self._ev.visit(value_node.right.left)
                elif self._is_target(value_node.right, target_name):
                    beta = "1.0"
            else:
                # GEMM term on the right; beta factor on the left.
                l, r, a = parse_term(value_node.right)
                if l:
                    A = l
                    B = r
                    alpha = a
                    if isinstance(value_node.left, ast.BinOp) and isinstance(
                        value_node.left.op, ast.Mult
                    ):
                        if self._is_target(value_node.left.left, target_name):
                            beta = self._ev.visit(value_node.left.right)
                        elif self._is_target(value_node.left.right, target_name):
                            beta = self._ev.visit(value_node.left.left)
                    elif self._is_target(value_node.left, target_name):
                        beta = "1.0"
        else:
            # Plain (possibly scaled) matmul with no accumulation term.
            l, r, a = parse_term(value_node)
            if l:
                A = l
                B = r
                alpha = a

        if A is None or B is None:
            return False

        def get_name_and_trans(node):
            # Unwrap `X.T` into (X, transposed=True).
            if isinstance(node, ast.Attribute) and node.attr == "T":
                return node.value, True
            return node, False

        A_node, trans_a = get_name_and_trans(A)
        B_node, trans_b = get_name_and_trans(B)

        # Nested GEMM operands are evaluated into temporaries by the visitor,
        # then referenced by name.
        if self.is_gemm(A_node):
            tmp_name = self._ev.visit(A_node)
            A_node = ast.Name(id=tmp_name)

        if self.is_gemm(B_node):
            tmp_name = self._ev.visit(B_node)
            B_node = ast.Name(id=tmp_name)

        res_a = self.parse_arg(A_node)
        res_b = self.parse_arg(B_node)

        if not res_a[0] or not res_b[0]:
            return False

        A_name, subset_a, shape_a, indices_a = res_a
        B_name, subset_b, shape_b, indices_b = res_b

        flat_subset_a = self.flatten_subset(A_name, subset_a)
        flat_subset_b = self.flatten_subset(B_name, subset_b)

        def get_ndim(name):
            # Rank of a known container; unknown names are treated as vectors.
            if name not in self.tensor_table:
                return 1
            return len(self.tensor_table[name].shape)

        # GEMM dimensions m, n, k from the (possibly transposed) operand shapes.
        if len(shape_a) == 2:
            if not trans_a:
                m = shape_a[0]
                k = shape_a[1]
            else:
                m = shape_a[1]
                k = shape_a[0]
        else:
            # 1-D A: treat as a row vector (m = 1); the transpose flag encodes
            # whether the underlying access is contiguous.
            m = "1"
            k = shape_a[0]
            if self._is_stride_1(A_name, indices_a):
                if get_ndim(A_name) == 1:
                    trans_a = True
                else:
                    trans_a = False
            else:
                trans_a = True

        if len(shape_b) == 2:
            if not trans_b:
                n = shape_b[1]
            else:
                n = shape_b[0]
        else:
            # 1-D B: treat as a column vector (n = 1).
            n = "1"
            if self._is_stride_1(B_name, indices_b):
                if get_ndim(B_name) == 1:
                    trans_b = False
                else:
                    trans_b = True
            else:
                trans_b = False

        def get_ld(name):
            # Leading dimension: the column count of a 2-D array.
            # NOTE(review): assumes row-major (C-order) layout — confirm.
            if name not in self.tensor_table:
                return ""
            shapes = self.tensor_table[name].shape
            if len(shapes) >= 2:
                return str(shapes[1])
            return "1"

        lda = get_ld(A_name)
        ldb = get_ld(B_name)

        ldc = ""
        if target_name:
            # 1-D target of a vector-matrix product: ld is the output length.
            if get_ndim(target_name) == 1 and m == "1":
                ldc = n
            else:
                ldc = get_ld(target_name)

        self.builder.add_gemm(
            A_name,
            B_name,
            target_name,
            alpha,
            beta,
            m,
            n,
            k,
            trans_a,
            trans_b,
            flat_subset_a,
            flat_subset_b,
            target_subset,
            lda,
            ldb,
            ldc,
        )
        return True
489

490
    def handle_dot(self, target, value_node):
        """Handle dot product operations for 1D vectors.

        Matches ``x @ y`` / ``dot(x, y)`` (optionally accumulated into the
        target via ``target + dot``), emits a DOT node into a fresh scalar
        temporary, and assigns it to the target.

        Returns:
            True if a dot product was emitted, False otherwise.

        NOTE(review): the method-call form ``a.dot(b)`` passes ``_is_dot_call``
        but has only one positional arg, so it is rejected below — confirm
        whether that form is handled elsewhere.
        """
        dot_node = None
        is_accumulate = False

        # Locate the dot expression; detect `target + dot(...)` accumulation.
        if self._is_dot_call(value_node):
            dot_node = value_node
        elif isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
            if self._is_dot_call(value_node.left):
                dot_node = value_node.left
                if self._is_target(value_node.right, target):
                    is_accumulate = True
            elif self._is_dot_call(value_node.right):
                dot_node = value_node.right
                if self._is_target(value_node.left, target):
                    is_accumulate = True

        if not dot_node:
            return False

        # Extract the two vector operands.
        arg0 = None
        arg1 = None

        if isinstance(dot_node, ast.Call):
            args = dot_node.args
            if len(args) != 2:
                return False
            arg0 = args[0]
            arg1 = args[1]
        elif isinstance(dot_node, ast.BinOp) and isinstance(dot_node.op, ast.MatMult):
            arg0 = dot_node.left
            arg1 = dot_node.right

        res_a = self.parse_arg(arg0)
        res_b = self.parse_arg(arg1)

        if not res_a[0] or not res_b[0]:
            return False

        name_a, subset_a, shape_a, indices_a = res_a
        name_b, subset_b, shape_b, indices_b = res_b

        # Only 1-D (vector) operands qualify for a BLAS-style dot.
        if len(shape_a) != 1 or len(shape_b) != 1:
            return False

        n = shape_a[0]

        def get_stride(name, indices):
            # Element stride of the sliced dimension: product of the sizes of
            # all dimensions after it (row-major layout).
            if not indices:
                return "1"
            info = self.tensor_table[name]
            shapes = info.shape
            ndim = len(info.shape)

            sliced_dim = -1
            for i, idx in enumerate(indices):
                if isinstance(idx, ast.Slice):
                    sliced_dim = i
                    break

            if sliced_dim == -1:
                return "1"

            stride = "1"
            for i in range(sliced_dim + 1, ndim):
                dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
                if stride == "1":
                    stride = str(dim_size)
                else:
                    stride = f"({stride} * {dim_size})"
            return stride

        incx = get_stride(name_a, indices_a)
        incy = get_stride(name_b, indices_b)

        flat_subset_a = self.flatten_subset(name_a, subset_a)
        flat_subset_b = self.flatten_subset(name_b, subset_b)

        # Zero-initialize a fresh scalar temporary to receive the result.
        tmp_res = f"_dot_res_{self._get_unique_id()}"
        self.builder.add_container(tmp_res, Scalar(PrimitiveType.Double), False)
        block = self.builder.add_block()
        constant = self.builder.add_constant(block, "0.0", Scalar(PrimitiveType.Double))
        tasklet = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
        self.builder.add_memlet(
            block, constant, "", tasklet, "_in", "", Scalar(PrimitiveType.Double)
        )
        access = self.builder.add_access(block, tmp_res)
        self.builder.add_memlet(
            block, tasklet, "_out", access, "", "", Scalar(PrimitiveType.Double)
        )

        self.container_table[tmp_res] = Scalar(PrimitiveType.Double)

        self.builder.add_dot(
            name_a, name_b, tmp_res, n, incx, incy, flat_subset_a, flat_subset_b
        )

        target_str = target if isinstance(target, str) else self._ev.visit(target)

        # Create the target container on first use.
        if not self.builder.exists(target_str):
            self.builder.add_container(target_str, Scalar(PrimitiveType.Double), False)
            self.container_table[target_str] = Scalar(PrimitiveType.Double)

        # Either accumulate into or overwrite the target.
        if is_accumulate:
            self.builder.add_assignment(target_str, f"{target_str} + {tmp_res}")
        else:
            self.builder.add_assignment(target_str, tmp_res)

        return True
599

600
    def is_outer(self, node):
4✔
601
        """Check if a node represents an outer product operation."""
602
        if isinstance(node, ast.Call):
4✔
603
            if isinstance(node.func, ast.Attribute) and node.func.attr == "outer":
4✔
604
                return True
4✔
605
            if isinstance(node.func, ast.Name) and node.func.id == "outer":
4✔
606
                return True
×
607
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
608
            return self.is_outer(node.left) or self.is_outer(node.right)
4✔
609
        return False
4✔
610

611
    def handle_outer(self, target, value_node):
        """Handle outer product operations.

        Matches a sum of ``outer(a, b)`` calls, optionally including the
        target itself (accumulation), and lowers each call to a rank-1 GEMM
        (m x 1 times 1 x n).

        Returns:
            True if the expression was lowered, False otherwise.
        """
        # Resolve output name and (for subscripted targets) its flat offset.
        target_name = None
        target_subset = []

        if isinstance(target, str):
            target_name = target
        elif isinstance(target, ast.Name):
            target_name = target.id
        elif isinstance(target, ast.Subscript):
            res = self.parse_arg(target)
            if res[0]:
                target_name = res[0]
                target_subset = self.flatten_subset(target_name, res[1])
            else:
                if isinstance(target.value, ast.Name):
                    target_name = target.value.id

        if not target_name:
            return False

        outer_calls = []
        target_found = False
        terms = []

        def collect_terms(node):
            # Flatten a nested `+` tree into a list of additive terms.
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
                collect_terms(node.left)
                collect_terms(node.right)
            else:
                terms.append(node)

        collect_terms(value_node)

        # Every term must be either the target (accumulation) or an
        # outer(a, b) call; anything else disqualifies the expression.
        for term in terms:
            if self._is_target(term, target_name):
                target_found = True
            elif isinstance(term, ast.Call) and (
                (isinstance(term.func, ast.Attribute) and term.func.attr == "outer")
                or (isinstance(term.func, ast.Name) and term.func.id == "outer")
            ):
                if len(term.args) != 2:
                    return False
                outer_calls.append(term)
            else:
                return False

        if not outer_calls:
            return False

        # Parse all operands up front so we fail before emitting anything.
        parsed_outers = []
        for outer_node in outer_calls:
            arg0 = outer_node.args[0]
            arg1 = outer_node.args[1]

            res_a = self.parse_arg(arg0)
            res_b = self.parse_arg(arg1)

            if not res_a[0] or not res_b[0]:
                return False

            parsed_outers.append((res_a, res_b))

        alpha = "1.0"
        # beta = 1.0 keeps the existing target contents when accumulating.
        beta = "1.0" if target_found else "0.0"

        def get_flattened_size(name, indices, shapes):
            # Total element count of the operand as a symbolic product.
            size_expr = "1"
            for s in shapes:
                if size_expr == "1":
                    size_expr = str(s)
                else:
                    size_expr = f"({size_expr} * {str(s)})"
            return size_expr

        def get_ld_2d(name):
            # Leading dimension of the output (column count for 2-D arrays).
            if name in self.tensor_table:
                shapes = self.tensor_table[name].shape
                if len(shapes) >= 2:
                    return str(shapes[1])
            return "1"

        ldc = get_ld_2d(target_name)

        for res_a, res_b in parsed_outers:
            name_a, subset_a, shape_a, indices_a = res_a
            name_b, subset_b, shape_b, indices_b = res_b

            # outer(a, b) == (m x 1) @ (1 x n): column vector times row vector.
            m = get_flattened_size(name_a, indices_a, shape_a)
            n = get_flattened_size(name_b, indices_b, shape_b)
            k = "1"

            trans_a = False
            trans_b = True

            flat_subset_a = self.flatten_subset(name_a, subset_a)
            flat_subset_b = self.flatten_subset(name_b, subset_b)

            lda = "1"
            ldb = "1"

            self.builder.add_gemm(
                name_a,
                name_b,
                target_name,
                alpha,
                beta,
                m,
                n,
                k,
                trans_a,
                trans_b,
                flat_subset_a,
                flat_subset_b,
                target_subset,
                lda,
                ldb,
                ldc,
            )
            # Subsequent outer products accumulate onto the first result.
            beta = "1.0"

        return True
733

734
    # ========== Transpose Operations ==========
735

736
    def _parse_perm(self, node):
4✔
737
        """Parse a permutation list or tuple from an AST node."""
738
        if isinstance(node, (ast.List, ast.Tuple)):
4✔
739
            res = []
4✔
740
            for elt in node.elts:
4✔
741
                val = self._ev.visit(elt)
4✔
742
                res.append(int(val))
4✔
743
            return res
4✔
744
        return []
×
745

746
    def is_transpose(self, node):
4✔
747
        """Check if a node represents a transpose operation."""
748
        # Case 1: np.transpose(arr, ...)
749
        if isinstance(node, ast.Call):
4✔
750
            if isinstance(node.func, ast.Attribute) and node.func.attr == "transpose":
4✔
751
                return True
×
752
            if isinstance(node.func, ast.Name) and node.func.id == "transpose":
4✔
753
                return True
×
754

755
        # Case 2: arr.T
756
        if isinstance(node, ast.Attribute) and node.attr == "T":
4✔
757
            return True
4✔
758

759
        return False
4✔
760

761
    def handle_transpose(self, target, value_node):
        """Handle transpose operations including .T and np.transpose().

        Binds a stride-permuted *view* (reference memlet, no copy) of the
        source array to ``target``. Returns True when the node was handled,
        False when it is not a transpose or the input is unknown.
        """
        if not self.is_transpose(value_node):
            return False

        input_node = None
        perm = []

        if isinstance(value_node, ast.Attribute) and value_node.attr == "T":
            input_node = value_node.value
            perm = []  # Empty means reverse

        elif isinstance(value_node, ast.Call):
            args = value_node.args
            keywords = value_node.keywords

            # Distinguish np.transpose(arr, ...) from arr.transpose(...)
            is_numpy_func = False
            if isinstance(value_node.func, ast.Attribute):
                caller = ""
                if isinstance(value_node.func.value, ast.Name):
                    caller = value_node.func.value.id
                if caller in ["np", "numpy"]:
                    is_numpy_func = True
            elif isinstance(value_node.func, ast.Name):
                is_numpy_func = True

            if is_numpy_func:
                if len(args) < 1:
                    return False
                input_node = args[0]
                if len(args) > 1:
                    perm = self._parse_perm(args[1])
                for kw in keywords:
                    if kw.arg == "axes":
                        perm = self._parse_perm(kw.value)
            else:
                # Method form: the array is the attribute's receiver
                if isinstance(value_node.func, ast.Attribute):
                    input_node = value_node.func.value
                else:
                    return False
                if len(args) > 0:
                    perm = self._parse_perm(args[0])
                for kw in keywords:
                    if kw.arg == "axes":
                        perm = self._parse_perm(kw.value)

        input_name = self._ev.visit(input_node)
        if input_name not in self.tensor_table:
            return False

        in_info = self.tensor_table[input_name]
        in_shape = in_info.shape
        in_strings = [str(s) for s in in_shape]

        if not perm:
            perm = list(range(len(in_shape)))[::-1]

        out_shape = [in_strings[p] for p in perm]

        # Get input strides (default to C-order when absent)
        in_strides = (
            in_info.strides if hasattr(in_info, "strides") and in_info.strides else None
        )
        if in_strides is None:
            in_strides = self._compute_strides(in_shape, "C")

        # Always permute the input strides: the target is created below as a
        # reference (view) of the source buffer, so its strides must describe
        # the source's actual memory layout. Computing fresh C-order strides
        # for non-contiguous inputs would misread memory; this now matches
        # _create_transpose_view.
        out_strides = [in_strides[p] for p in perm]

        # Inherit offset from the input tensor (for chained views like flip->T)
        in_offset = getattr(in_info, "offset", "0") or "0"

        target_name = ""
        if isinstance(target, ast.Name):
            target_name = target.id
        elif isinstance(target, str):
            target_name = target

        # Element type: prefer the registered container type, default Double
        dtype = Scalar(PrimitiveType.Double)
        if input_name in self.container_table:
            input_type = self.container_table[input_name]
            if isinstance(input_type, Pointer):
                dtype = input_type.pointee_type
            else:
                dtype = input_type

        ptr_type = Pointer(dtype)

        # Create target container if it doesn't exist
        if not self.builder.exists(target_name):
            self.builder.add_container(target_name, ptr_type, False)
            self.container_table[target_name] = ptr_type
        self.tensor_table[target_name] = Tensor(dtype, out_shape, out_strides, in_offset)

        # Create reference memlet to alias the source array (view, not copy)
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, input_name)
        t_dst = self.builder.add_access(block, target_name)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return True
863

864
    def handle_transpose_expr(self, node):
        """Handle .T attribute access in expressions, returning a temp array name."""
        if not (isinstance(node, ast.Attribute) and node.attr == "T"):
            return None

        source = self._ev.visit(node.value)
        if source not in self.tensor_table:
            return None

        # .T reverses all axes
        ndim = len(self.tensor_table[source].shape)
        reversed_axes = list(reversed(range(ndim)))
        return self._create_transpose_view(source, reversed_axes)
878

879
    def _handle_numpy_transpose(self, node, func_name):
4✔
880
        """Handle np.transpose(arr, axes=...) function call."""
881
        if len(node.args) < 1:
4✔
882
            raise ValueError("np.transpose requires at least one argument")
×
883

884
        input_node = node.args[0]
4✔
885
        input_name = self.visit(input_node)
4✔
886

887
        if input_name not in self.tensor_table:
4✔
888
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
889

890
        in_info = self.tensor_table[input_name]
4✔
891
        in_shape = in_info.shape
4✔
892

893
        perm = []
4✔
894
        if len(node.args) > 1:
4✔
895
            perm = self._parse_perm(node.args[1])
×
896
        for kw in node.keywords:
4✔
897
            if kw.arg == "axes":
4✔
898
                perm = self._parse_perm(kw.value)
4✔
899

900
        if not perm:
4✔
901
            perm = list(range(len(in_shape)))[::-1]
4✔
902

903
        return self._create_transpose_view(input_name, perm)
4✔
904

905
    def _create_transpose_view(self, input_name, perm):
        """Create a stride-permuted (transposed) view of ``input_name``.

        No data is copied: a new pointer container is registered whose
        Tensor entry carries the permuted shape/strides and the source's
        offset, and a reference memlet aliases it to the source buffer.

        Parameters
        ----------
        input_name : str
            Name of the source array; must already be in tensor_table.
        perm : list[int]
            Axis permutation; perm[i] is the source axis for output axis i.

        Returns
        -------
        str
            Name of the fresh temp container holding the view.
        """
        in_info = self.tensor_table[input_name]
        in_shape = in_info.shape
        # Shapes may be symbolic expressions, so keep them as strings
        in_strings = [str(s) for s in in_shape]

        # Compute output shape by permuting
        out_shape = [in_strings[p] for p in perm]

        # Get input strides and check if input is contiguous
        in_strides = (
            in_info.strides if hasattr(in_info, "strides") and in_info.strides else None
        )
        if in_strides is None:
            in_strides = self._compute_strides(in_shape, "C")

        # Always permute input strides (works for both contiguous and view inputs)
        out_strides = [in_strides[p] for p in perm]

        # Inherit offset from input tensor (for chained views like flip->transpose)
        in_offset = getattr(in_info, "offset", "0") or "0"

        # Create new pointer container
        tmp_name = f"_tmp_{self._get_unique_id()}"
        ptr_type = Pointer(in_info.element_type)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type

        # Register tensor with permuted shape, strides, and inherited offset
        self.tensor_table[tmp_name] = Tensor(
            in_info.element_type, out_shape, out_strides, in_offset
        )

        # Create reference memlet to alias the source array
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, input_name)
        t_dst = self.builder.add_access(block, tmp_name)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return tmp_name
944

945
    def _handle_numpy_flip(self, node, func_name):
4✔
946
        """Handle np.flip(arr, axis=None) - flip array along specified axis.
947

948
        Uses negative strides and offset to create a view without copying.
949
        """
950
        if len(node.args) < 1:
4✔
951
            raise ValueError("np.flip requires at least one argument")
×
952

953
        input_name = self.visit(node.args[0])
4✔
954
        if input_name not in self.tensor_table:
4✔
955
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
956

957
        in_info = self.tensor_table[input_name]
4✔
958
        in_shape = in_info.shape
4✔
959
        ndim = len(in_shape)
4✔
960

961
        # Parse axis argument
962
        axis = None
4✔
963
        if len(node.args) > 1:
4✔
964
            axis_node = node.args[1]
×
965
            if isinstance(axis_node, ast.Constant):
×
966
                axis = axis_node.value
×
967
            elif isinstance(axis_node, ast.UnaryOp) and isinstance(
×
968
                axis_node.op, ast.USub
969
            ):
970
                if isinstance(axis_node.operand, ast.Constant):
×
971
                    axis = -axis_node.operand.value
×
972
        for kw in node.keywords:
4✔
973
            if kw.arg == "axis":
4✔
974
                if isinstance(kw.value, ast.Constant):
4✔
975
                    axis = kw.value.value
4✔
976
                elif isinstance(kw.value, ast.UnaryOp) and isinstance(
4✔
977
                    kw.value.op, ast.USub
978
                ):
979
                    if isinstance(kw.value.operand, ast.Constant):
4✔
980
                        axis = -kw.value.operand.value
4✔
981

982
        # Determine which axes to flip
983
        if axis is None:
4✔
984
            # Flip all axes
985
            axes_to_flip = list(range(ndim))
4✔
986
        else:
987
            if axis < 0:
4✔
988
                axis = ndim + axis
4✔
989
            axes_to_flip = [axis]
4✔
990

991
        return self._create_flip_view(input_name, axes_to_flip)
4✔
992

993
    def _handle_numpy_fliplr(self, node, func_name):
4✔
994
        """Handle np.fliplr(arr) - flip array left-right (axis=1)."""
995
        if len(node.args) < 1:
4✔
996
            raise ValueError("np.fliplr requires one argument")
×
997

998
        input_name = self.visit(node.args[0])
4✔
999
        if input_name not in self.tensor_table:
4✔
1000
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
1001

1002
        in_info = self.tensor_table[input_name]
4✔
1003
        if len(in_info.shape) < 2:
4✔
1004
            raise ValueError("np.fliplr requires array with ndim >= 2")
×
1005

1006
        return self._create_flip_view(input_name, [1])
4✔
1007

1008
    def _handle_numpy_flipud(self, node, func_name):
4✔
1009
        """Handle np.flipud(arr) - flip array up-down (axis=0)."""
1010
        if len(node.args) < 1:
4✔
1011
            raise ValueError("np.flipud requires one argument")
×
1012

1013
        input_name = self.visit(node.args[0])
4✔
1014
        if input_name not in self.tensor_table:
4✔
1015
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
1016

1017
        return self._create_flip_view(input_name, [0])
4✔
1018

1019
    def _create_flip_view(self, input_name, axes_to_flip):
        """Create a flipped view of an array using Tensor.flip().

        Uses the Tensor type's flip() method which computes the correct
        negative strides and offset adjustment.

        Parameters
        ----------
        input_name : str
            Name of the source array; must already be in tensor_table.
        axes_to_flip : list[int]
            Axes to reverse; flip() is applied once per listed axis.

        Returns
        -------
        str
            Name of the fresh temp container aliasing the flipped view.
        """
        in_tensor = self.tensor_table[input_name]

        # Apply flip for each axis
        flipped_tensor = in_tensor
        for axis in axes_to_flip:
            flipped_tensor = flipped_tensor.flip(axis)

        # Create new pointer container pointing to same data
        tmp_name = f"_tmp_{self._get_unique_id()}"
        ptr_type = Pointer(in_tensor.element_type)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type

        # Store the flipped tensor with its offset in tensor_table
        self.tensor_table[tmp_name] = flipped_tensor

        # Create reference memlet (offset is handled by tensor's offset property)
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, input_name)
        t_dst = self.builder.add_access(block, tmp_name)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return tmp_name
1048

1049
    def _handle_numpy_reshape(self, node, func_name):
        """Handle np.reshape(arr, newshape) - reshape array without copying.

        Only works for contiguous arrays; creates a view with new shape/strides.

        Raises
        ------
        ValueError
            If the call is malformed or the array is unknown.
        NotImplementedError
            If the input is neither C- nor F-contiguous (reshape would
            require a copy).
        """
        if len(node.args) < 2:
            raise ValueError("np.reshape requires array and new shape")

        input_name = self.visit(node.args[0])
        if input_name not in self.tensor_table:
            raise ValueError(f"Array {input_name} not found in tensor_table")

        in_info = self.tensor_table[input_name]
        in_shape = in_info.shape

        # Parse new shape
        # NOTE(review): a -1 wildcard dimension from _parse_shape is passed
        # through verbatim, which would make the stride computation below
        # wrong — confirm callers never pass reshape(..., -1).
        new_shape = self._parse_shape(node.args[1])

        # Get input strides (default to C-order when absent)
        in_strides = (
            in_info.strides if hasattr(in_info, "strides") and in_info.strides else None
        )
        if in_strides is None:
            in_strides = self._compute_strides(in_shape, "C")

        # Check if input is contiguous (C or F order)
        c_contig = self._is_contiguous(in_shape, in_strides)
        f_contig = self._is_contiguous_f(in_shape, in_strides)

        # The view keeps the input's memory order
        if c_contig:
            out_strides = self._compute_strides(new_shape, "C")
        elif f_contig:
            out_strides = self._compute_strides(new_shape, "F")
        else:
            # Non-contiguous array cannot be reshaped without copy
            raise NotImplementedError(
                "np.reshape on non-contiguous array not supported (would require copy)"
            )

        # Create new pointer container
        tmp_name = f"_tmp_{self._get_unique_id()}"
        ptr_type = Pointer(in_info.element_type)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type

        # Register tensor with new shape and computed strides
        self.tensor_table[tmp_name] = Tensor(
            in_info.element_type, new_shape, out_strides
        )

        # Create reference memlet to alias the source array (view, no copy)
        block = self.builder.add_block()
        t_src = self.builder.add_access(block, input_name)
        t_dst = self.builder.add_access(block, tmp_name)
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)

        return tmp_name
1106

1107
    def _parse_shape(self, shape_node):
4✔
1108
        """Parse a shape argument (tuple, list, or single int)."""
1109
        if isinstance(shape_node, ast.Tuple) or isinstance(shape_node, ast.List):
4✔
1110
            result = []
4✔
1111
            for elt in shape_node.elts:
4✔
1112
                if isinstance(elt, ast.Constant):
4✔
1113
                    result.append(str(elt.value))
4✔
1114
                elif isinstance(elt, ast.Name):
×
1115
                    result.append(elt.id)
×
1116
                elif isinstance(elt, ast.UnaryOp) and isinstance(elt.op, ast.USub):
×
1117
                    if isinstance(elt.operand, ast.Constant):
×
1118
                        result.append(str(-elt.operand.value))
×
1119
                else:
1120
                    result.append(self._shape_to_runtime_expr(elt))
×
1121
            return result
4✔
1122
        elif isinstance(shape_node, ast.Constant):
×
1123
            return [str(shape_node.value)]
×
1124
        elif isinstance(shape_node, ast.Name):
×
1125
            # Could be a variable holding a shape tuple - not supported yet
1126
            raise NotImplementedError("Shape variable not supported, use literal tuple")
×
1127
        else:
1128
            raise ValueError(f"Cannot parse shape: {ast.dump(shape_node)}")
×
1129

1130
    def _is_contiguous_f(self, shape, strides):
4✔
1131
        """Check if array is F-order contiguous."""
1132
        if not shape or not strides:
4✔
1133
            return True
×
1134
        f_strides = self._compute_strides(shape, "F")
4✔
1135
        return [str(s) for s in strides] == [str(s) for s in f_strides]
4✔
1136

1137
    def handle_numpy_call(self, node, func_name):
        """Dispatch a NumPy call node to its registered handler.

        Raises NotImplementedError for unregistered function names.
        """
        if func_name not in self.function_handlers:
            raise NotImplementedError(f"NumPy function {func_name} not supported")
        handler = self.function_handlers[func_name]
        return handler(node, func_name)
1141

1142
    def has_handler(self, func_name):
        """Return whether `func_name` has a registered NumPy handler."""
        return func_name in self.function_handlers
1144

1145
    def handle_array_unary_op(self, op_type, operand):
        """Apply a unary math function elementwise to `operand`.

        Scalar operands (empty shape) are routed through a single cmath
        tasklet; array operands use the builder's elementwise unary op.
        Returns the name of a new temp container holding the result.
        """
        dtype = self._ev._element_type(operand)
        # Operands not registered in tensor_table are treated as 0-d scalars
        if operand in self.tensor_table:
            tensor = self.tensor_table[operand]
        else:
            tensor = Tensor(dtype, [])

        if len(tensor.shape) == 0:
            # Scalar path: one cmath call, no loops
            tmp_name = self._create_array_temp([], dtype)

            func_map = {
                "sqrt": CMathFunction.sqrt,
                "abs": CMathFunction.fabs,
                "absolute": CMathFunction.fabs,
                "exp": CMathFunction.exp,
                "tanh": CMathFunction.tanh,
            }

            block = self.builder.add_block()
            t_src = self.builder.add_access(block, operand)
            t_dst = self.builder.add_access(block, tmp_name)
            # NOTE(review): an op_type missing from func_map raises KeyError
            # on this scalar path — confirm callers only pass the keys above.
            t_task = self.builder.add_cmath(
                block, func_map[op_type], dtype.primitive_type
            )

            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", "", dtype)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "", dtype)

            return tmp_name

        # Array path: contiguous output whose order follows the input strides
        output_strides = self._get_contiguous_output_strides(
            tensor.shape, tensor.strides
        )
        tmp_name = self._create_array_temp(tensor.shape, dtype, strides=output_strides)
        tmp_tensor = self.tensor_table[tmp_name]
        self.builder.add_elementwise_unary_op(
            op_type, operand, tensor, tmp_name, tmp_tensor
        )

        return tmp_name
1185

1186
    def handle_array_binary_op(self, op_type, left, right):
        """Elementwise binary op with NumPy-style promotion and broadcasting.

        `left`/`right` are container names; anything registered in
        tensor_table counts as an array (including 0-d arrays), everything
        else as a scalar. Broadcasting is stride-based (no data copy).
        Returns the name of a newly created temp array with the result.
        """
        # Determine if operands are arrays or scalars
        # NumPy 0-d arrays (shape=[]) ARE arrays for promotion purposes
        # Only literals and Python scalars (not in tensor_table) are treated as scalars
        left_is_array = left in self.tensor_table
        right_is_array = right in self.tensor_table

        dtype_left = self._ev._element_type(left)
        dtype_right = self._ev._element_type(right)

        # Use NumPy promotion rules: scalars adapt to arrays
        dtype = numpy_promote_types(
            dtype_left, left_is_array, dtype_right, right_is_array
        )

        # Cast operands to result type if needed
        real_left = self._cast_to_type(left, dtype)
        real_right = self._cast_to_type(right, dtype)

        # Get tensor info for the (possibly casted) operands
        if real_left in self.tensor_table:
            left_tensor = self.tensor_table[real_left]
        else:
            left_tensor = Tensor(dtype, [])

        if real_right in self.tensor_table:
            right_tensor = self.tensor_table[real_right]
        else:
            right_tensor = Tensor(dtype, [])

        left_shape = left_tensor.shape
        right_shape = right_tensor.shape

        # Compute broadcast output shape
        output_shape = self._compute_broadcast_shape(left_shape, right_shape)

        # Check if broadcasting is needed (scalars with empty shape never are)
        left_needs_broadcast = (
            self._needs_broadcast(left_shape, output_shape) if left_shape else False
        )
        right_needs_broadcast = (
            self._needs_broadcast(right_shape, output_shape) if right_shape else False
        )

        real_left_tensor = left_tensor
        real_right_tensor = right_tensor

        # Broadcast left operand if needed (stride-based, no copy)
        if left_needs_broadcast:
            left_strides = left_tensor.strides if left_tensor.strides else []
            broadcast_strides = self._compute_broadcast_strides(
                left_shape, left_strides, output_shape
            )
            # Create a new tensor view with broadcast shape and strides
            # Preserve the offset from the original tensor (important for views like flip)
            left_offset = left_tensor.offset if left_tensor.offset else "0"
            real_left_tensor = Tensor(
                dtype, output_shape, broadcast_strides, left_offset
            )

        # Broadcast right operand if needed (stride-based, no copy)
        if right_needs_broadcast:
            right_strides = right_tensor.strides if right_tensor.strides else []
            broadcast_strides = self._compute_broadcast_strides(
                right_shape, right_strides, output_shape
            )
            # Create a new tensor view with broadcast shape and strides
            # Preserve the offset from the original tensor (important for views like flip)
            right_offset = right_tensor.offset if right_tensor.offset else "0"
            real_right_tensor = Tensor(
                dtype, output_shape, broadcast_strides, right_offset
            )

        # Create output array with broadcast shape
        # Preserve F-order if both inputs are F-order and no broadcasting needed
        if not left_needs_broadcast and not right_needs_broadcast:
            # Use left tensor strides to determine output order
            output_strides = self._get_contiguous_output_strides(
                output_shape, left_tensor.strides
            )
        else:
            output_strides = self._compute_strides(output_shape, "C")
        tmp_name = self._create_array_temp(output_shape, dtype, strides=output_strides)
        tmp_tensor = self.tensor_table[tmp_name]

        self.builder.add_elementwise_op(
            op_type,
            real_left,
            real_left_tensor,
            real_right,
            real_right_tensor,
            tmp_name,
            tmp_tensor,
        )

        return tmp_name
1282

1283
    def handle_array_negate(self, operand):
        """Negate an array elementwise by computing (0 - operand).

        Materializes a scalar zero container, then reuses the elementwise
        "sub" op. Returns the name of the new temp result container.
        """
        operand_tensor = self.tensor_table[operand]
        dtype = self._ev._element_type(operand)

        # Output is contiguous, with order following the input strides
        output_strides = self._get_contiguous_output_strides(
            operand_tensor.shape, operand_tensor.strides
        )
        tmp_name = self._create_array_temp(
            operand_tensor.shape, dtype, strides=output_strides
        )
        tmp_tensor = self.tensor_table[tmp_name]

        # Scalar zero container used as the left operand of the subtraction
        zero_name = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(zero_name, dtype, False)
        self.container_table[zero_name] = dtype
        self.tensor_table[zero_name] = Tensor(dtype, [])

        zero_block = self.builder.add_block()
        # NOTE(review): the "0.0" literal is chosen only for Double; other
        # floating widths receive "0" — confirm the builder coerces that.
        t_const = self.builder.add_constant(
            zero_block,
            "0.0" if dtype.primitive_type == PrimitiveType.Double else "0",
            dtype,
        )
        t_zero = self.builder.add_access(zero_block, zero_name)
        t_assign = self.builder.add_tasklet(
            zero_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(zero_block, t_const, "void", t_assign, "_in", "")
        self.builder.add_memlet(zero_block, t_assign, "_out", t_zero, "void", "")

        # -x implemented as (0 - x)
        zero_tensor = self.tensor_table[zero_name]
        self.builder.add_elementwise_op(
            "sub", zero_name, zero_tensor, operand, operand_tensor, tmp_name, tmp_tensor
        )

        return tmp_name
1319

1320
    def handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Handle elementwise comparison of arrays, returning a boolean array.

        At least one operand must be an array (present in tensor_table);
        integer arrays use signed integer comparison tasklets, everything
        else uses ordered floating-point comparisons. Returns the name of
        a new Bool temp array with the broadcasted comparison result.
        """
        # The array operand determines the output shape
        if left_is_array:
            shape = self.tensor_table[left].shape
            arr_name = left
        else:
            shape = self.tensor_table[right].shape
            arr_name = right

        # Integer arrays compare with integer tasklets, else floating-point
        use_int_cmp = False
        arr_dtype = self._ev._element_type(arr_name)
        if arr_dtype.primitive_type in (PrimitiveType.Int32, PrimitiveType.Int64):
            use_int_cmp = True

        dtype = Scalar(PrimitiveType.Bool)
        tmp_name = self._create_array_temp(shape, dtype)

        # Map the Python comparison operator to the matching tasklet code
        if use_int_cmp:
            cmp_ops = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            cmp_ops = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in cmp_ops:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )

        tasklet_code = cmp_ops[op]

        # Identify the scalar operand, if any
        scalar_name = None
        if not left_is_array:
            scalar_name = left
        elif not right_is_array:
            scalar_name = right

        # Promote an integer scalar to a Double container so the fp
        # comparison tasklet sees matching operand types.
        if scalar_name is not None and not use_int_cmp:
            if self._is_int(scalar_name):
                # NOTE(review): appending ".0" assumes _is_int guarantees
                # scalar_name is a literal integer string — confirm.
                float_name = f"_tmp_{self._get_unique_id()}"
                self.builder.add_container(
                    float_name, Scalar(PrimitiveType.Double), False
                )
                self.container_table[float_name] = Scalar(PrimitiveType.Double)

                block_conv = self.builder.add_block()
                t_const = self.builder.add_constant(
                    block_conv, f"{scalar_name}.0", Scalar(PrimitiveType.Double)
                )
                t_float = self.builder.add_access(block_conv, float_name)
                t_assign = self.builder.add_tasklet(
                    block_conv, TaskletCode.assign, ["_in"], ["_out"]
                )
                self.builder.add_memlet(
                    block_conv, t_const, "void", t_assign, "_in", ""
                )
                self.builder.add_memlet(
                    block_conv, t_assign, "_out", t_float, "void", ""
                )

                if not left_is_array:
                    left = float_name
                else:
                    right = float_name

        # Get tensor info for array operands
        left_tensor = self.tensor_table.get(left) if left_is_array else None
        right_tensor = self.tensor_table.get(right) if right_is_array else None
        tmp_tensor = self.tensor_table[tmp_name]

        # One loop per output dimension, with a fresh index container each
        loop_vars = []
        for i, dim in enumerate(shape):
            loop_var = f"_cmp_i{i}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(dim), "1")

        # Multi-dimensional subset - TensorToPointerConversion handles strides/offset
        multi_dim_subset = ",".join(loop_vars)

        block = self.builder.add_block()

        # Array operands are indexed by the loop variables; scalar operands
        # are read whole via _add_read
        if left_is_array:
            t_left = self.builder.add_access(block, left)
            left_sub = multi_dim_subset
        else:
            t_left, left_sub = self._add_read(block, left)

        if right_is_array:
            t_right = self.builder.add_access(block, right)
            right_sub = multi_dim_subset
        else:
            t_right, right_sub = self._add_read(block, right)

        t_out = self.builder.add_access(block, tmp_name)

        t_task = self.builder.add_tasklet(
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
        )

        # Pass tensor type so TensorToPointerConversion uses correct strides/offset
        if left_is_array and left_tensor:
            self.builder.add_memlet(
                block, t_left, "void", t_task, "_in1", left_sub, left_tensor
            )
        else:
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)

        if right_is_array and right_tensor:
            self.builder.add_memlet(
                block, t_right, "void", t_task, "_in2", right_sub, right_tensor
            )
        else:
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)

        self.builder.add_memlet(
            block, t_task, "_out", t_out, "void", multi_dim_subset, tmp_tensor
        )

        # Close all loops opened above
        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
1457

1458
    # ========== NumPy Function Handlers ==========
1459

1460
    def _handle_numpy_alloc(self, node, func_name):
        """Handle np.empty, np.zeros, np.ones, np.ndarray.

        Parses the shape (tuple/list/scalar/shape-proxy), dtype, order and
        explicit byte strides, then allocates a temp array container;
        zeros/ones additionally request zero/one initialization.
        """
        shape_arg = node.args[0]
        dims = []
        dims_runtime = []
        # ast.Tuple and ast.List parse identically (previously two duplicated
        # branches)
        if isinstance(shape_arg, (ast.Tuple, ast.List)):
            dims = [self.visit(elt) for elt in shape_arg.elts]
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
        else:
            val = self.visit(shape_arg)
            runtime_val = self._shape_to_runtime_expr(shape_arg)
            if val.startswith("_shape_proxy_"):
                # Shape borrowed from another array (e.g. `a.shape`)
                array_name = val[len("_shape_proxy_") :]
                if array_name in self.tensor_table:
                    info = self.tensor_table[array_name]
                    dims = info.shape
                    dims_runtime = self.shapes_runtime_info.get(array_name, dims)
                else:
                    dims = [val]
                    dims_runtime = [runtime_val]
            else:
                # Single integer-like dimension
                dims = [val]
                dims_runtime = [runtime_val]

        dtype_arg = None
        order = "C"  # Default to C-order (row-major)
        explicit_strides = None
        if len(node.args) > 1:
            dtype_arg = node.args[1]

        for kw in node.keywords:
            if kw.arg == "dtype":
                dtype_arg = kw.value
            elif kw.arg == "order":
                if isinstance(kw.value, ast.Constant):
                    order = kw.value.value
            elif kw.arg == "strides":
                # Parse explicit strides tuple/list
                if isinstance(kw.value, (ast.Tuple, ast.List)):
                    explicit_strides = [
                        self._shape_to_runtime_expr(elt) for elt in kw.value.elts
                    ]

        element_type = element_type_from_ast_node(dtype_arg, self.container_table)

        # Use explicit strides if provided, otherwise compute from order
        if explicit_strides is not None:
            # Convert byte strides to element strides by dividing by element size
            element_size = self.builder.get_sizeof(element_type)
            strides = [f"(({s}) / {element_size})" for s in explicit_strides]
        else:
            strides = self._compute_strides(dims, order)

        return self._create_array_temp(
            dims,
            element_type,
            zero_init=(func_name == "zeros"),
            ones_init=(func_name == "ones"),
            shapes_runtime=dims_runtime,
            strides=strides,
        )
1524

1525
    def _handle_numpy_empty_like(self, node, func_name):
        """Lower ``np.empty_like``: allocate an uninitialized temporary that
        mirrors the prototype array's shape, defaulting to its element type
        unless a dtype override is supplied."""
        proto_node = node.args[0]
        proto_name = self.visit(proto_node)

        shape = []
        if proto_name in self.tensor_table:
            shape = self.tensor_table[proto_name].shape

        # dtype may arrive positionally or as a keyword; memory order
        # defaults to C (row-major).
        dtype_node = node.args[1] if len(node.args) > 1 else None
        layout = "C"
        for kw in node.keywords:
            if kw.arg == "dtype":
                dtype_node = kw.value
            elif kw.arg == "order":
                if isinstance(kw.value, ast.Constant):
                    layout = kw.value.value

        # Resolve the element type: explicit dtype wins, then the
        # prototype's pointee type, then float64 as the final fallback.
        elem_type = None
        if dtype_node:
            elem_type = element_type_from_ast_node(dtype_node, self.container_table)
        else:
            if proto_name in self.container_table:
                proto_type = self.container_table[proto_name]
                if isinstance(proto_type, Pointer) and proto_type.has_pointee_type():
                    elem_type = proto_type.pointee_type

        if elem_type is None:
            elem_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            shape,
            elem_type,
            zero_init=False,
            ones_init=False,
            strides=self._compute_strides(shape, layout),
        )
        )
1562

1563
    def _handle_numpy_zeros_like(self, node, func_name):
        """Lower ``np.zeros_like``: zero-initialized temporary mirroring the
        prototype array's shape and, unless overridden, its element type."""
        source_node = node.args[0]
        source_name = self.visit(source_node)

        dims = (
            self.tensor_table[source_name].shape
            if source_name in self.tensor_table
            else []
        )

        # Collect dtype (positional or keyword) and memory order.
        dtype_node = None
        order = "C"
        if len(node.args) > 1:
            dtype_node = node.args[1]
        for kw in node.keywords:
            if kw.arg == "dtype":
                dtype_node = kw.value
            elif kw.arg == "order":
                if isinstance(kw.value, ast.Constant):
                    order = kw.value.value

        # Element type preference: explicit dtype > prototype's pointee
        # type > float64 default.
        element_type = None
        if dtype_node:
            element_type = element_type_from_ast_node(dtype_node, self.container_table)
        else:
            if source_name in self.container_table:
                sym = self.container_table[source_name]
                if isinstance(sym, Pointer) and sym.has_pointee_type():
                    element_type = sym.pointee_type

        if element_type is None:
            element_type = Scalar(PrimitiveType.Double)

        strides = self._compute_strides(dims, order)
        return self._create_array_temp(
            dims, element_type, zero_init=True, ones_init=False, strides=strides
        )
1599
        )
1600

1601
    def _handle_numpy_eye(self, node, func_name):
        """Lower np.eye(N, M=None, k=0, dtype=...).

        Allocates a zero-initialized (N, M) temporary, then emits a loop that
        writes 1 along the k-th diagonal (guarded so the column index stays
        within [0, M)).

        Returns:
            The name (str) of the temporary array holding the identity-like
            matrix.
        """
        N_arg = node.args[0]
        N_str = self.visit(N_arg)
        N_runtime = self._shape_to_runtime_expr(N_arg)

        M_str = N_str
        M_arg = N_arg  # Default M = N
        if len(node.args) > 1:
            M_arg = node.args[1]
            M_str = self.visit(M_arg)

        k_str = "0"
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])

        dtype_arg = None
        for kw in node.keywords:
            if kw.arg == "M":
                M_arg = kw.value
                M_str = self.visit(M_arg)
                if M_str == "None":
                    # np.eye(N, M=None) means a square matrix.
                    M_str = N_str
                    M_arg = N_arg
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        M_runtime = self._shape_to_runtime_expr(M_arg)

        element_type = element_type_from_ast_node(dtype_arg, self.container_table)

        # Zero-initialized output; only the diagonal is written below.
        ptr_name = self._create_array_temp(
            [N_str, M_str],
            element_type,
            zero_init=True,
            shapes_runtime=[N_runtime, M_runtime],
        )

        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Guard: column index i + k must lie inside [0, M).
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        # Diagonal value: integer "1" for integer dtypes, "1.0" otherwise.
        val = "1.0"
        if element_type.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]:
            val = "1"

        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        # Row-major flat index of element (i, i + k).
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"
        subset = flat_index

        t_task = self.builder.add_tasklet(
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", subset)

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
4✔
1682

1683
    def _handle_numpy_binary_op(self, node, func_name):
4✔
1684
        """Handle np.add, np.subtract, np.multiply, np.divide, etc."""
1685
        args = [self.visit(arg) for arg in node.args]
4✔
1686
        if len(args) != 2:
4✔
1687
            raise NotImplementedError(
×
1688
                f"Numpy function {func_name} requires 2 arguments"
1689
            )
1690

1691
        op_map = {
4✔
1692
            "add": "add",
1693
            "subtract": "sub",
1694
            "multiply": "mul",
1695
            "divide": "div",
1696
            "power": "pow",
1697
            "minimum": "min",
1698
            "maximum": "max",
1699
        }
1700
        return self.handle_array_binary_op(op_map[func_name], args[0], args[1])
4✔
1701

1702
    def _handle_numpy_unary_op(self, node, func_name):
4✔
1703
        """Handle np.exp, np.sqrt, np.abs, etc."""
1704
        args = [self.visit(arg) for arg in node.args]
4✔
1705
        if len(args) != 1:
4✔
1706
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
1707

1708
        op_name = func_name
4✔
1709
        if op_name == "absolute":
4✔
1710
            op_name = "abs"
×
1711

1712
        return self.handle_array_unary_op(op_name, args[0])
4✔
1713

1714
    def _handle_numpy_where(self, node, func_name):
        """Handle np.where(condition, x, y) - elementwise ternary selection.

        Emits a loop nest over the output shape; each iteration copies the
        condition element into a scalar temp, then branches to assign either
        x[...] or y[...] into the output. Scalar (non-array) operands are read
        via ``_add_read`` with the same value in every iteration.

        Returns:
            The name (str) of the temporary output array.
        """
        if len(node.args) != 3:
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")

        cond_name = self.visit(node.args[0])
        x_name = self.visit(node.args[1])
        y_name = self.visit(node.args[2])

        shape = []
        dtype = Scalar(PrimitiveType.Double)

        # Output shape: first array operand found among condition, y, x
        # (no general broadcasting here).
        if cond_name in self.tensor_table:
            shape = self.tensor_table[cond_name].shape

        if not shape and y_name in self.tensor_table:
            shape = self.tensor_table[y_name].shape

        if not shape and x_name in self.tensor_table:
            shape = self.tensor_table[x_name].shape

        if not shape:
            raise NotImplementedError("np.where requires at least one array argument")

        # Output dtype follows y when its type is known; else float64 default.
        if y_name in self.container_table:
            y_type = self.container_table[y_name]
            if isinstance(y_type, Pointer) and y_type.has_pointee_type():
                dtype = y_type.pointee_type
            elif isinstance(y_type, Scalar):
                dtype = y_type

        tmp_name = self._create_array_temp(shape, dtype)
        tmp_tensor = self.tensor_table[tmp_name]

        # One loop per output dimension.
        loop_vars = []
        for i, dim in enumerate(shape):
            loop_var = f"_where_i{i}_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)
            loop_vars.append(loop_var)
            self.builder.begin_for(loop_var, "0", str(dim), "1")
        multi_dim_subset = ",".join(loop_vars)

        # Scalar temp holding the current condition element.
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
        self.container_table[cond_tmp] = Scalar(PrimitiveType.Bool)

        # Copy condition[indices] (or the scalar condition) into cond_tmp.
        block_cond = self.builder.add_block()
        if cond_name in self.tensor_table:
            cond_tensor = self.tensor_table[cond_name]
            t_cond_arr = self.builder.add_access(block_cond, cond_name)
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
            t_cond_task = self.builder.add_tasklet(
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_cond,
                t_cond_arr,
                "void",
                t_cond_task,
                "_in",
                multi_dim_subset,
                cond_tensor,
            )
            self.builder.add_memlet(
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
            )
        else:
            t_cond_src, cond_sub = self._add_read(block_cond, cond_name)
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
            t_cond_task = self.builder.add_tasklet(
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_cond, t_cond_src, "void", t_cond_task, "_in", cond_sub
            )
            self.builder.add_memlet(
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
            )

        self.builder.begin_if(f"{cond_tmp} == true")

        # True branch: read from x, write to output.
        block_true = self.builder.add_block()
        t_out_true = self.builder.add_access(block_true, tmp_name)
        if x_name in self.tensor_table:
            x_tensor = self.tensor_table[x_name]
            t_x = self.builder.add_access(block_true, x_name)
            t_task_true = self.builder.add_tasklet(
                block_true, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_true, t_x, "void", t_task_true, "_in", multi_dim_subset, x_tensor
            )
        else:
            t_x, x_sub = self._add_read(block_true, x_name)
            t_task_true = self.builder.add_tasklet(
                block_true, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(block_true, t_x, "void", t_task_true, "_in", x_sub)
        self.builder.add_memlet(
            block_true,
            t_task_true,
            "_out",
            t_out_true,
            "void",
            multi_dim_subset,
            tmp_tensor,
        )

        self.builder.begin_else()

        # False branch: read from y, write to output
        block_false = self.builder.add_block()
        t_out_false = self.builder.add_access(block_false, tmp_name)
        if y_name in self.tensor_table:
            y_tensor = self.tensor_table[y_name]
            t_y = self.builder.add_access(block_false, y_name)
            t_task_false = self.builder.add_tasklet(
                block_false, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_false,
                t_y,
                "void",
                t_task_false,
                "_in",
                multi_dim_subset,
                y_tensor,
            )
        else:
            t_y, y_sub = self._add_read(block_false, y_name)
            t_task_false = self.builder.add_tasklet(
                block_false, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_false, t_y, "void", t_task_false, "_in", y_sub
            )
        self.builder.add_memlet(
            block_false,
            t_task_false,
            "_out",
            t_out_false,
            "void",
            multi_dim_subset,
            tmp_tensor,
        )

        self.builder.end_if()

        for _ in loop_vars:
            self.builder.end_for()

        return tmp_name
4✔
1868

1869
    def _handle_numpy_clip(self, node, func_name):
4✔
1870
        """Handle np.clip(a, a_min, a_max) - elementwise clipping."""
1871
        if len(node.args) != 3:
4✔
1872
            raise NotImplementedError("np.clip requires 3 arguments (a, a_min, a_max)")
×
1873

1874
        arr_name = self.visit(node.args[0])
4✔
1875
        a_min = self.visit(node.args[1])
4✔
1876
        a_max = self.visit(node.args[2])
4✔
1877

1878
        tmp1 = self.handle_array_binary_op("max", arr_name, a_min)
4✔
1879
        result = self.handle_array_binary_op("min", tmp1, a_max)
4✔
1880

1881
        return result
4✔
1882

1883
    def _handle_numpy_matmul(self, node, func_name):
4✔
1884
        """Handle np.matmul, np.dot."""
1885
        if len(node.args) != 2:
4✔
1886
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
1887
        return self._handle_matmul_helper(node.args[0], node.args[1])
4✔
1888

1889
    def handle_numpy_matmul_op(self, left_node, right_node):
4✔
1890
        """Handle the @ operator for matrix multiplication."""
1891
        return self._handle_matmul_helper(left_node, right_node)
4✔
1892

1893
    def _handle_matmul_helper(self, left_node, right_node):
        """Helper for matrix multiplication operations.

        Resolves both operands via ``parse_arg`` (materializing complex
        expressions into named temporaries when needed), determines the
        output shape from the operand ranks, and dispatches to:
          - ``handle_dot``  for 1-D @ 1-D (scalar result),
          - ``handle_gemm`` for the 2-D cases,
          - a batched loop of ``handle_gemm`` over leading dims when both
            operands have equal rank > 2.

        Returns:
            The name (str) of the result container (scalar or array temp).
        """
        res_a = self.parse_arg(left_node)
        res_b = self.parse_arg(right_node)

        # If parse_arg could not resolve a name directly, evaluate the
        # expression into a temporary and retry on the plain Name node.
        if not res_a[0]:
            left_name = self.visit(left_node)
            left_node = ast.Name(id=left_name)
            res_a = self.parse_arg(left_node)

        if not res_b[0]:
            right_name = self.visit(right_node)
            right_node = ast.Name(id=right_name)
            res_b = self.parse_arg(right_node)

        name_a, subset_a, shape_a, indices_a = res_a
        name_b, subset_b, shape_b, indices_b = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve matmul operands")

        real_shape_a = shape_a
        real_shape_b = shape_b

        ndim_a = len(real_shape_a)
        ndim_b = len(real_shape_b)

        # Output shape per NumPy matmul semantics for the supported cases.
        output_shape = []
        is_scalar = False

        if ndim_a == 1 and ndim_b == 1:
            # vector . vector -> scalar
            is_scalar = True
            output_shape = []
        elif ndim_a == 2 and ndim_b == 2:
            output_shape = [real_shape_a[0], real_shape_b[1]]
        elif ndim_a == 2 and ndim_b == 1:
            output_shape = [real_shape_a[0]]
        elif ndim_a == 1 and ndim_b == 2:
            output_shape = [real_shape_b[1]]
        elif ndim_a > 2 or ndim_b > 2:
            if ndim_a == ndim_b:
                # Batched matmul: leading dims taken from the left operand.
                output_shape = list(real_shape_a[:-2]) + [
                    real_shape_a[-2],
                    real_shape_b[-1],
                ]
            else:
                raise NotImplementedError(
                    "Broadcasting with different ranks not fully supported yet"
                )
        else:
            raise NotImplementedError(
                f"Matmul with ranks {ndim_a} and {ndim_b} not supported"
            )

        # Result dtype follows the promotion of both element types.
        dtype_a = self._ev._element_type(name_a)
        dtype_b = self._ev._element_type(name_b)
        dtype = promote_element_types(dtype_a, dtype_b)

        if is_scalar:
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
        else:
            tmp_name = self._create_array_temp(output_shape, dtype)

        if ndim_a > 2 or ndim_b > 2:
            # Batched case: loop over leading (batch) dimensions and emit one
            # 2-D GEMM per iteration via synthesized subscript AST nodes.
            batch_dims = ndim_a - 2
            loop_vars = []

            for i in range(batch_dims):
                loop_var = f"_i{self._get_unique_id()}"
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                loop_vars.append(loop_var)
                dim_size = real_shape_a[i]
                self.builder.begin_for(loop_var, "0", str(dim_size), "1")

            def make_slice(name, indices):
                # Build ``name[i0, ..., :, :]`` as an AST subscript node.
                elts = []
                for idx in indices:
                    if idx == ":":
                        elts.append(ast.Slice())
                    else:
                        elts.append(ast.Name(id=idx))

                return ast.Subscript(
                    value=ast.Name(id=name), slice=ast.Tuple(elts=elts), ctx=ast.Load()
                )

            indices = loop_vars + [":", ":"]
            slice_a = make_slice(name_a, indices)
            slice_b = make_slice(name_b, indices)
            slice_c = make_slice(tmp_name, indices)

            self.handle_gemm(
                slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
            )

            for _ in range(batch_dims):
                self.builder.end_for()
        else:
            if is_scalar:
                self.handle_dot(
                    tmp_name,
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
                )
            else:
                self.handle_gemm(
                    tmp_name,
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
                )

        return tmp_name
4✔
2005

2006
    def _handle_numpy_outer(self, node, func_name):
        """Handle np.outer.

        Resolves both operands, computes their flattened lengths M and N as
        symbolic product expressions, allocates an (M, N) temporary of the
        promoted element type, and delegates the actual computation to
        ``handle_outer``.

        Returns:
            The name (str) of the (M, N) temporary output array.
        """
        if len(node.args) != 2:
            raise NotImplementedError("outer requires 2 arguments")

        arg0 = node.args[0]
        arg1 = node.args[1]

        res_a = self.parse_arg(arg0)
        res_b = self.parse_arg(arg1)

        # If parse_arg could not resolve a name directly, evaluate the
        # expression into a temporary and retry on the plain Name node.
        if not res_a[0]:
            left_name = self.visit(arg0)
            arg0 = ast.Name(id=left_name)
            res_a = self.parse_arg(arg0)

        if not res_b[0]:
            right_name = self.visit(arg1)
            arg1 = ast.Name(id=right_name)
            res_b = self.parse_arg(arg1)

        name_a, subset_a, shape_a, indices_a = res_a
        name_b, subset_b, shape_b, indices_b = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve outer operands")

        # Total element count as a symbolic product expression.
        # Previously this helper declared two unused parameters
        # (name, indices); unified with the identical helper in
        # handle_ufunc_outer.
        def get_flattened_size_expr(shapes):
            if not shapes:
                return "1"
            size_expr = str(shapes[0])
            for s in shapes[1:]:
                size_expr = f"({size_expr} * {str(s)})"
            return size_expr

        m_expr = get_flattened_size_expr(shape_a)
        n_expr = get_flattened_size_expr(shape_b)

        # Result dtype follows the promotion of both element types.
        dtype_a = self._ev._element_type(name_a)
        dtype_b = self._ev._element_type(name_b)
        dtype = promote_element_types(dtype_a, dtype_b)

        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)

        # Re-wrap with the (possibly materialized) operand nodes so
        # handle_outer sees resolvable arguments.
        new_call_node = ast.Call(
            func=node.func, args=[arg0, arg1], keywords=node.keywords
        )

        self.handle_outer(tmp_name, new_call_node)

        return tmp_name
4✔
2058

2059
    def handle_ufunc_outer(self, node, ufunc_name):
        """Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc.

        ``multiply.outer`` reuses the np.outer lowering; the other supported
        ufuncs emit an explicit (M, N) double loop combining every element of
        the flattened left operand with every element of the flattened right
        operand via an elementwise tasklet.

        Returns:
            The name (str) of the (M, N) temporary output array.
        """
        if len(node.args) != 2:
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")

        if ufunc_name == "multiply":
            # multiply.outer is exactly np.outer.
            return self._handle_numpy_outer(node, "outer")

        # (internal op name, float opcode/cmath fn, integer opcode) per ufunc.
        op_map = {
            "add": ("add", TaskletCode.fp_add, TaskletCode.int_add),
            "subtract": ("sub", TaskletCode.fp_sub, TaskletCode.int_sub),
            "divide": ("div", TaskletCode.fp_div, TaskletCode.int_sdiv),
            "minimum": ("min", CMathFunction.fmin, TaskletCode.int_smin),
            "maximum": ("max", CMathFunction.fmax, TaskletCode.int_smax),
        }

        if ufunc_name not in op_map:
            raise NotImplementedError(f"{ufunc_name}.outer not supported")

        op_name, fp_opcode, int_opcode = op_map[ufunc_name]

        arg0 = node.args[0]
        arg1 = node.args[1]

        res_a = self.parse_arg(arg0)
        res_b = self.parse_arg(arg1)

        # Materialize unresolvable expressions into temporaries and retry.
        if not res_a[0]:
            left_name = self.visit(arg0)
            arg0 = ast.Name(id=left_name)
            res_a = self.parse_arg(arg0)

        if not res_b[0]:
            right_name = self.visit(arg1)
            arg1 = ast.Name(id=right_name)
            res_b = self.parse_arg(arg1)

        name_a, subset_a, shape_a, indices_a = res_a
        name_b, subset_b, shape_b, indices_b = res_b

        if not name_a or not name_b:
            raise NotImplementedError("Could not resolve ufunc outer operands")

        # Total element count as a symbolic product expression.
        def get_flattened_size_expr(shapes):
            if not shapes:
                return "1"
            size_expr = str(shapes[0])
            for s in shapes[1:]:
                size_expr = f"({size_expr} * {str(s)})"
            return size_expr

        m_expr = get_flattened_size_expr(shape_a)
        n_expr = get_flattened_size_expr(shape_b)

        # Result dtype is the promotion of both operand element types;
        # it selects between the integer and floating-point opcodes.
        dtype_left = self._ev._element_type(name_a)
        dtype_right = self._ev._element_type(name_b)
        dtype = promote_element_types(dtype_left, dtype_right)

        is_int = dtype.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]

        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)

        i_var = self.builder.find_new_name("_outer_i_")
        j_var = self.builder.find_new_name("_outer_j_")

        if not self.builder.exists(i_var):
            self.builder.add_container(i_var, Scalar(PrimitiveType.Int64), False)
            self.container_table[i_var] = Scalar(PrimitiveType.Int64)
        if not self.builder.exists(j_var):
            self.builder.add_container(j_var, Scalar(PrimitiveType.Int64), False)
            self.container_table[j_var] = Scalar(PrimitiveType.Int64)

        def compute_linear_index(name, subset, indices, loop_var):
            # Build the flat index expression for one operand element:
            # sliced dimensions advance with loop_var, fixed indices use
            # their subset start offset, each scaled by the row-major stride.
            if not indices:
                return loop_var

            if name in self.tensor_table:
                info = self.tensor_table[name]
                shapes = info.shape
                ndim = len(shapes)
            else:
                shapes = []
                ndim = 0

            if ndim == 0:
                return loop_var

            # Row-major strides, innermost stride "1".
            strides = []
            current_stride = "1"
            for i in range(ndim - 1, -1, -1):
                strides.insert(0, current_stride)
                if i > 0:
                    dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
                    if current_stride == "1":
                        current_stride = str(dim_size)
                    else:
                        current_stride = f"({current_stride} * {dim_size})"

            terms = []
            # NOTE(review): loop_var_used is set but never read below —
            # looks like dead state; confirm before removing.
            loop_var_used = False

            for i, idx in enumerate(indices):
                stride = strides[i] if i < len(strides) else "1"
                start = subset[i] if i < len(subset) else "0"

                if isinstance(idx, ast.Slice):
                    if stride == "1":
                        term = f"({start} + {loop_var})"
                    else:
                        term = f"(({start} + {loop_var}) * {stride})"
                    loop_var_used = True
                else:
                    if stride == "1":
                        term = start
                    else:
                        term = f"({start} * {stride})"

                terms.append(term)

            if not terms:
                return loop_var

            result = terms[0]
            for t in terms[1:]:
                result = f"({result} + {t})"

            return result

        # out[i, j] = a_flat[i] OP b_flat[j]
        self.builder.begin_for(i_var, "0", m_expr, "1")
        self.builder.begin_for(j_var, "0", n_expr, "1")

        block = self.builder.add_block()

        t_a = self.builder.add_access(block, name_a)
        t_b = self.builder.add_access(block, name_b)
        t_c = self.builder.add_access(block, tmp_name)

        if ufunc_name in ["minimum", "maximum"]:
            # min/max use a cmath call for floats, a tasklet for ints.
            if is_int:
                t_task = self.builder.add_tasklet(
                    block, int_opcode, ["_in1", "_in2"], ["_out"]
                )
            else:
                t_task = self.builder.add_cmath(block, fp_opcode, dtype.primitive_type)
        else:
            tasklet_code = int_opcode if is_int else fp_opcode
            t_task = self.builder.add_tasklet(
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
            )

        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)

        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)

        # Row-major flat index into the (M, N) output.
        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)

        self.builder.end_for()
        self.builder.end_for()

        return tmp_name
4✔
2231

2232
    def _handle_numpy_reduce(self, node, func_name):
        """Handle np.sum, np.max, np.min, np.mean, np.std.

        Resolves the reduced axes from the ``axis`` argument (positional or
        keyword; int constant, negative literal, or tuple of constants),
        honors ``keepdims``, and emits a reduce library node into the graph.

        Args:
            node: The ``ast.Call`` node of the reduction call.
            func_name: Reduction name ("sum", "max", "min", "mean", "std").

        Returns:
            Name of the temporary container holding the result (a scalar
            container for full reductions, an array temp otherwise).

        Raises:
            ValueError: If the reduced operand is not a known array.
            NotImplementedError: If the axis cannot be resolved statically.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        array_node = args[0]
        array_name = self.visit(array_node)

        if array_name not in self.tensor_table:
            raise ValueError(f"Reduction input must be an array, got {array_name}")

        # For mean and std, we need float64 input and output (NumPy behavior)
        # Cast input to float64 if needed
        if func_name in ("mean", "std"):
            float64_type = Scalar(PrimitiveType.Double)
            array_name = self._cast_array(array_name, float64_type)

        input_tensor = self.tensor_table[array_name]
        input_shape = input_tensor.shape
        ndim = len(input_shape)

        # Axis may be given positionally (second arg) or as a keyword.
        axis = None
        if len(args) > 1:
            axis = args[1]
        elif "axis" in keywords:
            axis = keywords["axis"]

        keepdims = False
        if "keepdims" in keywords:
            keepdims_node = keywords["keepdims"]
            if isinstance(keepdims_node, ast.Constant):
                keepdims = bool(keepdims_node.value)

        # Resolve the set of reduced axes; negative axes are normalized.
        axes = []
        if axis is None:
            axes = list(range(ndim))
        elif isinstance(axis, ast.Constant):
            val = axis.value
            if val < 0:
                val += ndim
            axes = [val]
        elif isinstance(axis, ast.Tuple):
            for elt in axis.elts:
                if isinstance(elt, ast.Constant):
                    val = elt.value
                    if val < 0:
                        val += ndim
                    axes.append(val)
        elif (
            isinstance(axis, ast.UnaryOp)
            and isinstance(axis.op, ast.USub)
            and isinstance(axis.operand, ast.Constant)
        ):
            # A negative literal such as ``axis=-1`` parses as USub(Constant).
            val = -axis.operand.value
            if val < 0:
                val += ndim
            axes = [val]
        else:
            # Last resort: try to fold the axis expression to an int.
            # Was a bare ``except:``; narrowed so KeyboardInterrupt/SystemExit
            # are not swallowed, and the original cause is chained.
            try:
                val = int(self.visit(axis))
                if val < 0:
                    val += ndim
                axes = [val]
            except Exception as exc:
                raise NotImplementedError("Dynamic axis not supported") from exc

        # Build the output shape: reduced dims drop out (or become "1" with
        # keepdims); the remaining dims are carried over unchanged.
        output_shape = []
        for i in range(ndim):
            if i in axes:
                if keepdims:
                    output_shape.append("1")
            else:
                output_shape.append(input_shape[i])

        dtype = self._ev._element_type(array_name)

        if not output_shape:
            # Full reduction collapses to a plain scalar container.
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
            self.tensor_table[tmp_name] = Tensor(dtype, [])
        else:
            output_strides = self._compute_strides(output_shape, "C")
            tmp_name = self._create_array_temp(
                output_shape, dtype, strides=output_strides
            )

        output_tensor = self.tensor_table[tmp_name]
        self.builder.add_reduce_op(
            func_name, array_name, input_tensor, tmp_name, output_tensor, axes, keepdims
        )

        return tmp_name
2325

2326
    def handle_numpy_astype(self, node, array_name):
        """Handle numpy array.astype(dtype) method calls.

        Always makes a copy: allocates a fresh array with the requested
        element type and emits a cast library node from the source array.
        The output keeps F-order when the input is F-contiguous, else C-order.
        """
        if len(node.args) < 1:
            raise ValueError("astype requires at least one argument (dtype)")

        # We always copy, so an explicit copy=False cannot be honored.
        for kw in node.keywords:
            if kw.arg != "copy":
                continue
            if isinstance(kw.value, ast.Constant) and kw.value.value is False:
                raise NotImplementedError("astype with copy=False is not supported")

        target_dtype = element_type_from_ast_node(node.args[0], self.container_table)

        if array_name not in self.tensor_table:
            raise ValueError(f"Array {array_name} not found in tensor_table")

        src_tensor = self.tensor_table[array_name]
        src_shape = src_tensor.shape
        src_strides = getattr(src_tensor, "strides", None)

        # Preserve column-major layout when the source is F-contiguous.
        order = "C"
        if src_strides and len(src_strides) >= 2 and len(src_shape) >= 2:
            if src_strides == self._compute_strides(src_shape, "F"):
                order = "F"

        out_strides = self._compute_strides(src_shape, order)
        tmp_name = self._create_array_temp(
            src_shape, target_dtype, strides=out_strides
        )

        self.builder.add_cast_op(
            array_name, src_tensor, tmp_name, self.tensor_table[tmp_name]
        )

        return tmp_name
2364

2365
    def handle_numpy_copy(self, node, array_name):
        """Handle numpy array.copy() method calls using memcpy.

        Allocates a new array matching the source layout and emits a cast
        node (same dtype, so effectively an elementwise copy).
        """
        if array_name not in self.tensor_table:
            raise ValueError(f"Array {array_name} not found in tensor_table")

        src_tensor = self.tensor_table[array_name]
        src_shape = src_tensor.shape
        src_strides = getattr(src_tensor, "strides", None)

        # Recover the element type from the container's pointer type;
        # fall back to double when it is not recorded.
        elem_type = Scalar(PrimitiveType.Double)
        if array_name in self.container_table:
            sym_type = self.container_table[array_name]
            if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                elem_type = sym_type.pointee_type

        # Preserve column-major layout when the source is F-contiguous.
        order = "C"
        if src_strides and len(src_strides) >= 2 and len(src_shape) >= 2:
            if src_strides == self._compute_strides(src_shape, "F"):
                order = "F"

        out_strides = self._compute_strides(src_shape, order)
        tmp_name = self._create_array_temp(
            src_shape, elem_type, strides=out_strides
        )

        # Workaround: "assign-op"
        self.builder.add_cast_op(
            array_name, src_tensor, tmp_name, self.tensor_table[tmp_name]
        )

        return tmp_name
2397

2398
    def _get_contiguous_output_strides(self, shape, input_strides):
4✔
2399
        """Get contiguous output strides, preserving C or F order if input is contiguous.
2400

2401
        For non-contiguous input strides (e.g., from slices), returns C-order strides.
2402
        This ensures output allocation matches the stride pattern.
2403

2404
        Args:
2405
            shape: Output shape
2406
            input_strides: Strides from input tensor
2407

2408
        Returns:
2409
            List of stride expressions for a contiguous output array
2410
        """
2411
        if not shape or not input_strides:
4✔
2412
            return self._compute_strides(shape, "C")
4✔
2413

2414
        # Preserve order if contiguous, otherwise default to C-order
2415
        c_strides = self._compute_strides(shape, "C")
4✔
2416
        if input_strides == c_strides:
4✔
2417
            return c_strides
4✔
2418
        f_strides = self._compute_strides(shape, "F")
4✔
2419
        if input_strides == f_strides:
4✔
2420
            return f_strides
×
2421
        return c_strides
4✔
2422

2423
    def _compute_strides(self, shape, order="C"):
4✔
2424
        """Compute strides for a given shape and memory order.
2425

2426
        Args:
2427
            shape: List of dimension sizes
2428
            order: "C" for row-major (default), "F" for column-major
2429

2430
        Returns:
2431
            List of stride expressions as strings
2432
        """
2433
        if not shape:
4✔
2434
            return []
4✔
2435

2436
        ndim = len(shape)
4✔
2437
        strides = []
4✔
2438

2439
        if order == "F":
4✔
2440
            # Column-major (Fortran order): stride[i] = product of shape[:i]
2441
            for dim_idx in range(ndim):
4✔
2442
                if dim_idx == 0:
4✔
2443
                    strides.append("1")
4✔
2444
                else:
2445
                    # Wrap each shape in parens to ensure correct precedence
2446
                    prefix_shapes = [f"({s})" for s in shape[:dim_idx]]
4✔
2447
                    if len(prefix_shapes) == 1:
4✔
2448
                        strides.append(prefix_shapes[0])
4✔
2449
                    else:
2450
                        strides.append("(" + " * ".join(prefix_shapes) + ")")
×
2451
        else:
2452
            # Row-major (C order): stride[i] = product of shape[i+1:]
2453
            for dim_idx in range(ndim):
4✔
2454
                if dim_idx == ndim - 1:
4✔
2455
                    strides.append("1")
4✔
2456
                else:
2457
                    # Wrap each shape in parens to ensure correct precedence
2458
                    suffix_shapes = [f"({s})" for s in shape[dim_idx + 1 :]]
4✔
2459
                    if len(suffix_shapes) == 1:
4✔
2460
                        strides.append(suffix_shapes[0])
4✔
2461
                    else:
2462
                        strides.append("(" + " * ".join(suffix_shapes) + ")")
4✔
2463

2464
        return strides
4✔
2465

2466
    def _is_contiguous(self, shape, strides):
        """Check if strides represent a contiguous (C or F order) layout.

        Stride expressions are compared textually after normalization
        (whitespace removed, redundant outer parentheses stripped), so
        "(3)" and "3" count as the same stride.

        Args:
            shape: List of dimension-size expressions.
            strides: Stride expressions to classify.

        Returns:
            True when the strides match C-order or F-order strides for
            ``shape`` (also when shape or strides is empty), else False.
        """
        if not shape or not strides:
            return True

        def normalize(s):
            # Normalize stride expression by removing spaces and outer parens
            s = s.replace(" ", "")
            while s.startswith("(") and s.endswith(")"):
                # Only strip if balanced parens
                inner = s[1:-1]
                depth = 0
                balanced = True
                for c in inner:
                    if c == "(":
                        depth += 1
                    elif c == ")":
                        depth -= 1
                        if depth < 0:
                            # e.g. "(a)*(b)": the leading "(" does not pair
                            # with the trailing ")", so stripping is wrong.
                            balanced = False
                            break
                if balanced and depth == 0:
                    s = inner
                else:
                    break
            return s

        # Compare against canonical C-order strides first, then F-order.
        c_strides = self._compute_strides(shape, "C")
        if all(
            normalize(str(a)) == normalize(str(b)) for a, b in zip(strides, c_strides)
        ):
            return True
        f_strides = self._compute_strides(shape, "F")
        return all(
            normalize(str(a)) == normalize(str(b)) for a, b in zip(strides, f_strides)
        )
2502

2503
    def _create_array_temp(
        self,
        shape,
        dtype,
        zero_init=False,
        ones_init=False,
        shapes_runtime=None,
        strides=None,
    ):
        """Create a temporary array.

        A 0-dimensional (empty) shape yields a plain scalar container;
        otherwise a pointer container is registered and heap storage is
        allocated, hoisted to function entry via the memory handler when
        possible.

        Args:
            shape: List of dimension-size expressions; empty for scalars.
            dtype: Element type of the array.
            zero_init: Zero-fill the allocation.
            ones_init: Fill the allocation with ones (emitted as a loop).
            shapes_runtime: Optional runtime shape info recorded for the temp.
            strides: Optional stride expressions; defaults to C-order.

        Returns:
            Name of the newly created temporary container.
        """
        tmp_name = f"_tmp_{self._get_unique_id()}"

        # Handle 0-dimensional arrays as scalars
        if not shape or (len(shape) == 0):
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
            self.tensor_table[tmp_name] = Tensor(dtype, [])

            # Scalars are initialized with a direct assignment, not a loop.
            if zero_init:
                self.builder.add_assignment(
                    tmp_name,
                    "0.0" if dtype.primitive_type == PrimitiveType.Double else "0",
                )
            elif ones_init:
                self.builder.add_assignment(
                    tmp_name,
                    "1.0" if dtype.primitive_type == PrimitiveType.Double else "1",
                )

            return tmp_name

        # Calculate size - wrap each dimension in parentheses to ensure correct
        # parsing when dimensions are expressions like "-2 + _s0"
        size_str = "1"
        for dim in shape:
            size_str = f"({size_str} * ({dim}))"

        element_size = self.builder.get_sizeof(dtype)
        total_size = f"({size_str} * {element_size})"

        # Use provided strides or compute C-order strides
        if strides is None:
            strides = self._compute_strides(shape, "C")

        # Create pointer
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type
        tensor_entry = Tensor(dtype, shape, strides, "0")
        if shapes_runtime is not None:
            self.shapes_runtime_info[tmp_name] = shapes_runtime
        self.tensor_table[tmp_name] = tensor_entry

        # Try to hoist allocation to function entry
        init_type = (
            ManagedMemoryHandler.INIT_ZERO
            if zero_init
            else ManagedMemoryHandler.INIT_NONE
        )
        # ones_init always takes the immediate path below: the ones fill is a
        # loop that must be emitted at this point in the program.
        if not ones_init and self.memory_handler.allocate(
            tmp_name, ptr_type, total_size, init=init_type
        ):
            pass  # Allocation registered for hoisting
        else:
            # Emit allocation immediately (size depends on loop variables or needs loop init)
            self._emit_malloc(
                tmp_name, total_size, ptr_type, zero_init, ones_init, size_str, dtype
            )

        return tmp_name
2573

2574
    def _emit_malloc(
        self, tmp_name, total_size, ptr_type, zero_init, ones_init, size_str, dtype
    ):
        """Emit malloc and optional initialization for a temporary array.

        Emits a malloc block writing the returned pointer into ``tmp_name``,
        then (optionally) a memset-to-zero block, or a fill loop assigning
        one to each of the ``size_str`` elements.

        Args:
            tmp_name: Container receiving the allocated pointer.
            total_size: Byte-size expression for the allocation.
            ptr_type: Pointer type of the container.
            zero_init: Emit a memset(0) over the allocation.
            ones_init: Emit an element-wise fill loop with ones.
            size_str: Element-count expression (used as the fill loop bound).
            dtype: Element type (selects "1" vs "1.0" for the fill constant).
        """
        # malloc block: _ret of the malloc node flows into the pointer.
        block1 = self.builder.add_block()
        t_malloc = self.builder.add_malloc(block1, total_size)
        t_ptr1 = self.builder.add_access(block1, tmp_name)
        self.builder.add_memlet(block1, t_malloc, "_ret", t_ptr1, "void", "", ptr_type)

        if zero_init:
            # Zero-initialization is a single memset over the whole buffer.
            block2 = self.builder.add_block()
            t_memset = self.builder.add_memset(block2, "0", total_size)
            t_ptr2 = self.builder.add_access(block2, tmp_name)
            self.builder.add_memlet(
                block2, t_memset, "_ptr", t_ptr2, "void", "", ptr_type
            )
        elif ones_init:
            # Ones-initialization needs an explicit loop assigning a constant
            # to every element.
            loop_var = f"_i_{self._get_unique_id()}"
            if not self.builder.exists(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(loop_var, "0", size_str, "1")

            # Integer dtypes get "1"; everything else gets "1.0".
            val = "1.0"
            if dtype.primitive_type in [
                PrimitiveType.Int64,
                PrimitiveType.Int32,
                PrimitiveType.Int8,
                PrimitiveType.Int16,
                PrimitiveType.UInt64,
                PrimitiveType.UInt32,
                PrimitiveType.UInt8,
                PrimitiveType.UInt16,
            ]:
                val = "1"

            # constant -> assign tasklet -> arr[loop_var]
            block_assign = self.builder.add_block()
            t_const = self.builder.add_constant(block_assign, val, dtype)
            t_arr = self.builder.add_access(block_assign, tmp_name)

            t_task = self.builder.add_tasklet(
                block_assign, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                block_assign, t_const, "void", t_task, "_in", "", dtype
            )
            self.builder.add_memlet(
                block_assign, t_task, "_out", t_arr, "void", loop_var
            )

            self.builder.end_for()

2627
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
2628
        """Compute linear index from multi-dimensional indices.
2629

2630
        Uses strides from tensor_table if available (supporting F-order arrays),
2631
        otherwise falls back to computing strides assuming C-order.
2632
        """
2633
        if ndim == 0:
×
2634
            return "0"
×
2635

2636
        # Try to get strides from tensor_table
2637
        strides = None
×
2638
        if array_name in self.tensor_table:
×
2639
            tensor_info = self.tensor_table[array_name]
×
2640
            if hasattr(tensor_info, "strides") and tensor_info.strides:
×
2641
                strides = tensor_info.strides
×
2642

2643
        if strides and len(strides) == ndim:
×
2644
            # Use explicit strides from tensor_table
2645
            linear_index = ""
×
2646
            for i in range(ndim):
×
2647
                stride = strides[i]
×
2648
                if stride == "1":
×
2649
                    term = str(indices[i])
×
2650
                else:
2651
                    term = f"(({indices[i]}) * ({stride}))"
×
2652

2653
                if i == 0:
×
2654
                    linear_index = term
×
2655
                else:
2656
                    linear_index = f"({linear_index} + {term})"
×
2657
            return linear_index
×
2658
        else:
2659
            # Fall back to C-order (row-major) stride computation
2660
            linear_index = ""
×
2661
            for i in range(ndim):
×
2662
                term = str(indices[i])
×
2663
                for j in range(i + 1, ndim):
×
2664
                    shape_val = (
×
2665
                        shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
2666
                    )
2667
                    term = f"(({term}) * {shape_val})"
×
2668

2669
                if i == 0:
×
2670
                    linear_index = term
×
2671
                else:
2672
                    linear_index = f"({linear_index} + {term})"
×
2673

2674
            return linear_index
×
2675

2676
    def _compute_broadcast_shape(self, shape_a, shape_b):
4✔
2677
        """Compute the broadcast output shape following NumPy broadcasting rules."""
2678
        if not shape_a:
4✔
2679
            return shape_b
4✔
2680
        if not shape_b:
4✔
2681
            return shape_a
4✔
2682

2683
        max_ndim = max(len(shape_a), len(shape_b))
4✔
2684
        padded_a = ["1"] * (max_ndim - len(shape_a)) + [str(s) for s in shape_a]
4✔
2685
        padded_b = ["1"] * (max_ndim - len(shape_b)) + [str(s) for s in shape_b]
4✔
2686

2687
        result = []
4✔
2688
        for a, b in zip(padded_a, padded_b):
4✔
2689
            if a == "1":
4✔
2690
                result.append(b)
4✔
2691
            elif b == "1":
4✔
2692
                result.append(a)
4✔
2693
            elif a == b:
4✔
2694
                result.append(a)
4✔
2695
            else:
2696
                result.append(a)
4✔
2697

2698
        return result
4✔
2699

2700
    def _needs_broadcast(self, input_shape, output_shape):
4✔
2701
        """Check if input shape needs broadcasting to match output shape."""
2702
        if len(input_shape) != len(output_shape):
4✔
2703
            return True
4✔
2704
        for in_dim, out_dim in zip(input_shape, output_shape):
4✔
2705
            if str(in_dim) != str(out_dim):
4✔
2706
                return True
4✔
2707
        return False
4✔
2708

2709
    def _compute_broadcast_strides(self, input_shape, input_strides, output_shape):
4✔
2710
        """Compute strides for broadcasting input to output shape.
2711

2712
        For broadcast dimensions (size 1), stride is set to 0 so the same
2713
        value is repeated. This enables stride-based broadcasting without copying.
2714
        """
2715
        # Pad input shape and strides on the left to match output ndim
2716
        ndim_diff = len(output_shape) - len(input_shape)
4✔
2717
        padded_shape = ["1"] * ndim_diff + [str(s) for s in input_shape]
4✔
2718
        padded_strides = ["0"] * ndim_diff + [str(s) for s in input_strides]
4✔
2719

2720
        broadcast_strides = []
4✔
2721
        for in_dim, in_stride, out_dim in zip(
4✔
2722
            padded_shape, padded_strides, output_shape
2723
        ):
2724
            # Only use stride 0 when input dimension is exactly "1" (broadcast case).
2725
            # For other cases (including symbolic dimensions that may be equal at runtime),
2726
            # keep the original stride.
2727
            if str(in_dim) == "1" and str(out_dim) != "1":
4✔
2728
                # Broadcast dimension: use stride 0
2729
                broadcast_strides.append("0")
4✔
2730
            else:
2731
                # Non-broadcast dimension or potentially equal symbolic dimensions:
2732
                # keep original stride
2733
                broadcast_strides.append(in_stride)
4✔
2734

2735
        return broadcast_strides
4✔
2736

2737
    def _shape_to_runtime_expr(self, shape_node):
4✔
2738
        """Convert a shape expression AST node to a runtime-evaluable string."""
2739
        if isinstance(shape_node, ast.Constant):
4✔
2740
            return str(shape_node.value)
4✔
2741
        elif isinstance(shape_node, ast.Name):
4✔
2742
            return shape_node.id
4✔
2743
        elif isinstance(shape_node, ast.BinOp):
4✔
2744
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2745
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2746
            op = self.visit(shape_node.op)
4✔
2747
            return f"({left} {op} {right})"
4✔
2748
        elif isinstance(shape_node, ast.UnaryOp):
4✔
2749
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
2750
            if isinstance(shape_node.op, ast.USub):
×
2751
                return f"(-{operand})"
×
2752
            elif isinstance(shape_node.op, ast.UAdd):
×
2753
                return operand
×
2754
            else:
2755
                return self.visit(shape_node)
×
2756
        elif isinstance(shape_node, ast.Subscript):
4✔
2757
            val = shape_node.value
4✔
2758
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2759
                if isinstance(val.value, ast.Name):
4✔
2760
                    arr_name = val.value.id
4✔
2761
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2762
                        idx = shape_node.slice.value
4✔
2763
                        if arr_name in self.tensor_table:
4✔
2764
                            shapes = self.tensor_table[arr_name].shape
4✔
2765
                            if idx < len(shapes):
4✔
2766
                                return shapes[idx]
4✔
2767
                        return f"{arr_name}.shape[{idx}]"
×
2768
            return self.visit(shape_node)
×
2769
        elif isinstance(shape_node, ast.Tuple):
×
2770
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2771
        elif isinstance(shape_node, ast.List):
×
2772
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2773
        else:
2774
            return self.visit(shape_node)
×
2775

2776
    # ========== Type Casting Helpers ==========
2777

2778
    def _cast_scalar(self, name, target_type):
        """
        Cast a scalar value to a different type using an assign tasklet.

        The backend detects the specific conversion (fpext, sitofp, etc.)
        from the type mismatch between input and output.

        Args:
            name: Name of the scalar to cast
            target_type: Target element type (Scalar)

        Returns:
            Name of the casted scalar (or original if no cast needed)
        """
        # No-op when the primitive types already agree.
        if self._ev._element_type(name).primitive_type == target_type.primitive_type:
            return name

        # Register a fresh scalar container for the converted value.
        result_name = f"_cast_{self._get_unique_id()}"
        self.builder.add_container(result_name, target_type, False)
        self.container_table[result_name] = target_type
        self.tensor_table[result_name] = Tensor(target_type, [])

        # Wire source -> assign tasklet -> destination; the type mismatch on
        # the memlets tells the backend which conversion to emit.
        block = self.builder.add_block()
        src_node, src_subset = self._add_read(block, name)
        dst_node = self.builder.add_access(block, result_name)
        assign = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
        self.builder.add_memlet(block, src_node, "void", assign, "_in", src_subset)
        self.builder.add_memlet(block, assign, "_out", dst_node, "void", "")

        return result_name
2809

2810
    def _cast_array(self, name, target_type):
        """
        Cast an array to a different element type using the CastNode library node.

        This is an elementwise cast operation that creates a new array.
        Reuses the same infrastructure as handle_numpy_astype().

        Args:
            name: Name of the array to cast
            target_type: Target element type (Scalar)

        Returns:
            Name of the casted array (or original if no cast needed)
        """
        # No-op when the primitive types already agree.
        if self._ev._element_type(name).primitive_type == target_type.primitive_type:
            return name

        source = self.tensor_table[name]

        # Allocate the destination with the same shape and a contiguous
        # layout matching the source's order (C or F).
        dest_strides = self._get_contiguous_output_strides(
            source.shape, source.strides
        )
        dest_name = self._create_array_temp(
            source.shape, target_type, strides=dest_strides
        )

        # Emit the elementwise CastNode from source to destination.
        self.builder.add_cast_op(
            name, source, dest_name, self.tensor_table[dest_name]
        )

        return dest_name
2844

2845
    def _cast_to_type(self, name, target_type):
4✔
2846
        """
2847
        Cast an operand (scalar or array) to the target type.
2848

2849
        Dispatches to _cast_scalar or _cast_array based on whether
2850
        the operand is in tensor_table (includes 0-d arrays).
2851

2852
        Args:
2853
            name: Name of the operand to cast
2854
            target_type: Target element type (Scalar)
2855

2856
        Returns:
2857
            Name of the casted operand (or original if no cast needed)
2858
        """
2859
        if name in self.tensor_table:
4✔
2860
            # In tensor_table means it's an array (including 0-d arrays)
2861
            return self._cast_array(name, target_type)
4✔
2862
        else:
2863
            # Not in tensor_table means it's a literal or Python scalar
2864
            return self._cast_scalar(name, target_type)
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc