• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22941657111

11 Mar 2026 04:56AM UTC coverage: 64.596% (-0.03%) from 64.621%
22941657111

Pull #566

github

web-flow
Merge 8761bd9e4 into af8bb4c54
Pull Request #566: Instrument python code

24682 of 38210 relevant lines covered (64.6%)

375.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.03
/python/docc/python/functions/numpy.py
1
import ast
4✔
2
from docc.sdfg import (
4✔
3
    Scalar,
4
    PrimitiveType,
5
    Pointer,
6
    DebugInfo,
7
    TaskletCode,
8
    CMathFunction,
9
    Tensor,
10
)
11
from docc.python.types import (
4✔
12
    element_type_from_ast_node,
13
    promote_element_types,
14
)
15
from docc.python.ast_utils import get_debug_info
4✔
16
from docc.python.memory import ManagedMemoryHandler
4✔
17

18

19
class NumPyHandler:
4✔
20
    """
21
    Unified handler for NumPy operations including:
22
    - Array creation (empty, zeros, ones, eye, etc.)
23
    - Elementwise operations (add, subtract, multiply, etc.)
24
    - Linear algebra (matmul, dot, outer, gemm)
25
    - Array manipulation (transpose)
26
    - Reductions (sum, max, min, mean, std)
27
    """
28

29
    def __init__(self, expression_visitor):
        """Store the owning expression visitor and build the dispatch table.

        Args:
            expression_visitor: Object this handler delegates to for
                visiting AST nodes and for access to the shared tables.
        """
        self._ev = expression_visitor
        self._unique_counter = 0

        # Handlers that serve several NumPy function names are bound once.
        alloc = self._handle_numpy_alloc
        binary_op = self._handle_numpy_binary_op
        unary_op = self._handle_numpy_unary_op
        reduce_op = self._handle_numpy_reduce
        matmul = self._handle_numpy_matmul

        # Maps a NumPy function name to the method that lowers it.
        self.function_handlers = dict(
            empty=alloc,
            empty_like=self._handle_numpy_empty_like,
            zeros=alloc,
            zeros_like=self._handle_numpy_zeros_like,
            ones=alloc,
            ndarray=alloc,
            eye=self._handle_numpy_eye,
            add=binary_op,
            subtract=binary_op,
            multiply=binary_op,
            divide=binary_op,
            power=binary_op,
            exp=unary_op,
            abs=unary_op,
            absolute=unary_op,
            sqrt=unary_op,
            tanh=unary_op,
            sum=reduce_op,
            max=reduce_op,
            min=reduce_op,
            mean=reduce_op,
            std=reduce_op,
            matmul=matmul,
            dot=matmul,
            matvec=matmul,
            outer=self._handle_numpy_outer,
            minimum=binary_op,
            maximum=binary_op,
            where=self._handle_numpy_where,
            clip=self._handle_numpy_clip,
            transpose=self._handle_numpy_transpose,
            flip=self._handle_numpy_flip,
            fliplr=self._handle_numpy_fliplr,
            flipud=self._handle_numpy_flipud,
            reshape=self._handle_numpy_reshape,
        )
69

70
    # Expose parent properties for convenience
71
    @property
4✔
72
    def tensor_table(self):
4✔
73
        return self._ev.tensor_table
4✔
74

75
    @property
4✔
76
    def builder(self):
4✔
77
        return self._ev.builder
4✔
78

79
    @property
4✔
80
    def container_table(self):
4✔
81
        return self._ev.container_table
4✔
82

83
    @property
4✔
84
    def globals_dict(self):
4✔
85
        return self._ev.globals_dict
×
86

87
    @property
4✔
88
    def shapes_runtime_info(self):
4✔
89
        return self._ev.shapes_runtime_info
4✔
90

91
    @property
4✔
92
    def memory_handler(self):
4✔
93
        """Access the memory handler owned by the parser."""
94
        return self._ev.memory_handler
4✔
95

96
    def _get_unique_id(self):
4✔
97
        return self._ev._get_unique_id()
4✔
98

99
    def _add_read(self, block, expr_str, debug_info=None):
4✔
100
        return self._ev._add_read(block, expr_str, debug_info)
4✔
101

102
    def _is_int(self, operand):
4✔
103
        return self._ev._is_int(operand)
4✔
104

105
    def visit(self, node):
4✔
106
        return self._ev.visit(node)
4✔
107

108
    # ========== Linear Algebra Helper Methods (from LinearAlgebraHandler) ==========
109

110
    def parse_arg(self, node):
4✔
111
        """Parse an array argument, returning (name, start_indices, slice_shape, indices)."""
112
        if isinstance(node, ast.Name):
4✔
113
            if node.id in self.tensor_table:
4✔
114
                return node.id, [], self.tensor_table[node.id].shape, []
4✔
115
        elif isinstance(node, ast.Subscript):
4✔
116
            if isinstance(node.value, ast.Name) and node.value.id in self.tensor_table:
4✔
117
                name = node.value.id
4✔
118
                indices = []
4✔
119
                if isinstance(node.slice, ast.Tuple):
4✔
120
                    indices = node.slice.elts
4✔
121
                else:
122
                    indices = [node.slice]
4✔
123

124
                start_indices = []
4✔
125
                slice_shape = []
4✔
126

127
                for i, idx in enumerate(indices):
4✔
128
                    if isinstance(idx, ast.Slice):
4✔
129
                        start = "0"
4✔
130
                        if idx.lower:
4✔
131
                            start = self._ev.visit(idx.lower)
4✔
132
                        start_indices.append(start)
4✔
133

134
                        shapes = self.tensor_table[name].shape
4✔
135
                        dim_size = (
4✔
136
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
137
                        )
138
                        stop = dim_size
4✔
139
                        if idx.upper:
4✔
140
                            stop = self._ev.visit(idx.upper)
4✔
141

142
                        size = f"({stop} - {start})"
4✔
143
                        slice_shape.append(size)
4✔
144
                    else:
145
                        if isinstance(idx, ast.Name) and idx.id in self.tensor_table:
4✔
146
                            # This is an array index (gather operation)
147
                            return None, None, None, None
×
148
                        val = self._ev.visit(idx)
4✔
149
                        start_indices.append(val)
4✔
150

151
                return name, start_indices, slice_shape, indices
4✔
152

153
        return None, None, None, None
4✔
154

155
    def flatten_subset(self, name, start_indices):
4✔
156
        """Convert multi-dimensional start indices to a flattened linear offset."""
157
        if not start_indices:
4✔
158
            return []
4✔
159
        info = self.tensor_table[name]
4✔
160
        shapes = info.shape
4✔
161
        ndim = len(info.shape)
4✔
162

163
        if len(start_indices) != ndim:
4✔
164
            return start_indices
4✔
165

166
        strides = []
4✔
167
        current_stride = "1"
4✔
168
        strides.append(current_stride)
4✔
169
        for i in range(ndim - 1, 0, -1):
4✔
170
            dim_size = shapes[i]
4✔
171
            if current_stride == "1":
4✔
172
                current_stride = str(dim_size)
4✔
173
            else:
174
                current_stride = f"({current_stride} * {dim_size})"
4✔
175
            strides.append(current_stride)
4✔
176
        strides = list(reversed(strides))
4✔
177

178
        offset = "0"
4✔
179
        for i in range(ndim):
4✔
180
            idx = start_indices[i]
4✔
181
            stride = strides[i]
4✔
182
            term = f"({idx} * {stride})" if stride != "1" else idx
4✔
183
            if offset == "0":
4✔
184
                offset = term
4✔
185
            else:
186
                offset = f"({offset} + {term})"
4✔
187

188
        return [offset]
4✔
189

190
    def is_gemm(self, node):
4✔
191
        """Check if a node represents a GEMM operation (matrix multiplication)."""
192
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
193
            return True
4✔
194
        if isinstance(node, ast.Call):
4✔
195
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
4✔
196
                return True
×
197
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
4✔
198
                return True
×
199
            if isinstance(node.func, ast.Attribute) and node.func.attr == "matmul":
4✔
200
                return True
×
201
            if isinstance(node.func, ast.Name) and node.func.id == "matmul":
4✔
202
                return True
×
203
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
204
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
205
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
4✔
206
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
207
        return False
4✔
208

209
    def _is_stride_1(self, name, indices):
4✔
210
        """Check if the sliced dimension has stride 1 (contiguous access)."""
211
        if name not in self.tensor_table:
4✔
212
            return True
×
213
        info = self.tensor_table[name]
4✔
214
        ndim = len(info.shape)
4✔
215

216
        if not indices:
4✔
217
            return True
4✔
218

219
        sliced_dim = -1
×
220
        for i, idx in enumerate(indices):
×
221
            if isinstance(idx, ast.Slice):
×
222
                sliced_dim = i
×
223
                break
×
224

225
        if sliced_dim == -1:
×
226
            if len(indices) < ndim:
×
227
                sliced_dim = ndim - 1
×
228
            else:
229
                return True
×
230

231
        return sliced_dim == ndim - 1
×
232

233
    def _is_target(self, node, target_name):
4✔
234
        """Check if node refers to the target."""
235
        if isinstance(target_name, ast.AST):
4✔
236
            return self._ev.visit(node) == self._ev.visit(target_name)
4✔
237

238
        if isinstance(node, ast.Name) and node.id == target_name:
4✔
239
            return True
×
240
        if isinstance(node, ast.Subscript):
4✔
241
            if isinstance(node.value, ast.Name) and node.value.id == target_name:
4✔
242
                return True
4✔
243
        return False
4✔
244

245
    def _is_dot_call(self, node):
4✔
246
        """Check if node is a dot product call."""
247
        if isinstance(node, ast.Call):
4✔
248
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
×
249
                return True
×
250
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
×
251
                return True
×
252
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
253
            return True
4✔
254
        return False
4✔
255

256
    def handle_gemm(self, target, value_node):
4✔
257
        """Handle GEMM (General Matrix Multiply) operations: C = alpha * A @ B + beta * C."""
258
        target_name = None
4✔
259
        target_subset = []
4✔
260

261
        if isinstance(target, str):
4✔
262
            target_name = target
4✔
263
        elif isinstance(target, ast.Name):
4✔
264
            target_name = target.id
4✔
265
        elif isinstance(target, ast.Subscript):
4✔
266
            if isinstance(target.value, ast.Name):
4✔
267
                res = self.parse_arg(target)
4✔
268
                if res[0]:
4✔
269
                    target_name = res[0]
4✔
270
                    target_subset = self.flatten_subset(target_name, res[1])
4✔
271
                else:
272
                    target_name = target.value.id
×
273

274
        if not target_name or target_name not in self.tensor_table:
4✔
275
            return False
4✔
276

277
        alpha = "1.0"
4✔
278
        beta = "0.0"
4✔
279
        A = None
4✔
280
        B = None
4✔
281

282
        def extract_factor(node):
4✔
283
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
4✔
284
                if self.is_gemm(node.left):
×
285
                    return node.left, self._ev.visit(node.right)
×
286
                if self.is_gemm(node.right):
×
287
                    return node.right, self._ev.visit(node.left)
×
288

289
                res = self.parse_arg(node.left)
×
290
                if res[0]:
×
291
                    return node.left, self._ev.visit(node.right)
×
292
                res = self.parse_arg(node.right)
×
293
                if res[0]:
×
294
                    return node.right, self._ev.visit(node.left)
×
295
            return node, "1.0"
4✔
296

297
        def parse_term(node):
4✔
298
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
299
                l, l_f = extract_factor(node.left)
4✔
300
                r, r_f = extract_factor(node.right)
4✔
301
                f = "1.0"
4✔
302
                if l_f != "1.0":
4✔
303
                    f = l_f
×
304
                if r_f != "1.0":
4✔
305
                    if f == "1.0":
×
306
                        f = r_f
×
307
                    else:
308
                        f = f"({f} * {r_f})"
×
309
                return l, r, f
4✔
310

311
            if isinstance(node, ast.Call):
×
312
                is_gemm_call = False
×
313
                if isinstance(node.func, ast.Attribute) and node.func.attr in [
×
314
                    "dot",
315
                    "matmul",
316
                ]:
317
                    is_gemm_call = True
×
318
                if isinstance(node.func, ast.Name) and node.func.id in [
×
319
                    "dot",
320
                    "matmul",
321
                ]:
322
                    is_gemm_call = True
×
323

324
                if is_gemm_call and len(node.args) == 2:
×
325
                    return node.args[0], node.args[1], "1.0"
×
326

327
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
×
328
                l, r, a = parse_term(node.left)
×
329
                if l:
×
330
                    return l, r, self._ev.visit(node.right)
×
331
                l, r, a = parse_term(node.right)
×
332
                if l:
×
333
                    return l, r, self._ev.visit(node.left)
×
334

335
            return None, None, None
×
336

337
        if isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
4✔
338
            l, r, a = parse_term(value_node.left)
×
339
            if l:
×
340
                A = l
×
341
                B = r
×
342
                alpha = a
×
343
                if isinstance(value_node.right, ast.BinOp) and isinstance(
×
344
                    value_node.right.op, ast.Mult
345
                ):
346
                    if self._is_target(value_node.right.left, target_name):
×
347
                        beta = self._ev.visit(value_node.right.right)
×
348
                    elif self._is_target(value_node.right.right, target_name):
×
349
                        beta = self._ev.visit(value_node.right.left)
×
350
                elif self._is_target(value_node.right, target_name):
×
351
                    beta = "1.0"
×
352
            else:
353
                l, r, a = parse_term(value_node.right)
×
354
                if l:
×
355
                    A = l
×
356
                    B = r
×
357
                    alpha = a
×
358
                    if isinstance(value_node.left, ast.BinOp) and isinstance(
×
359
                        value_node.left.op, ast.Mult
360
                    ):
361
                        if self._is_target(value_node.left.left, target_name):
×
362
                            beta = self._ev.visit(value_node.left.right)
×
363
                        elif self._is_target(value_node.left.right, target_name):
×
364
                            beta = self._ev.visit(value_node.left.left)
×
365
                    elif self._is_target(value_node.left, target_name):
×
366
                        beta = "1.0"
×
367
        else:
368
            l, r, a = parse_term(value_node)
4✔
369
            if l:
4✔
370
                A = l
4✔
371
                B = r
4✔
372
                alpha = a
4✔
373

374
        if A is None or B is None:
4✔
375
            return False
×
376

377
        def get_name_and_trans(node):
4✔
378
            if isinstance(node, ast.Attribute) and node.attr == "T":
4✔
379
                return node.value, True
×
380
            return node, False
4✔
381

382
        A_node, trans_a = get_name_and_trans(A)
4✔
383
        B_node, trans_b = get_name_and_trans(B)
4✔
384

385
        if self.is_gemm(A_node):
4✔
386
            tmp_name = self._ev.visit(A_node)
×
387
            A_node = ast.Name(id=tmp_name)
×
388

389
        if self.is_gemm(B_node):
4✔
390
            tmp_name = self._ev.visit(B_node)
×
391
            B_node = ast.Name(id=tmp_name)
×
392

393
        res_a = self.parse_arg(A_node)
4✔
394
        res_b = self.parse_arg(B_node)
4✔
395

396
        if not res_a[0] or not res_b[0]:
4✔
397
            return False
×
398

399
        A_name, subset_a, shape_a, indices_a = res_a
4✔
400
        B_name, subset_b, shape_b, indices_b = res_b
4✔
401

402
        flat_subset_a = self.flatten_subset(A_name, subset_a)
4✔
403
        flat_subset_b = self.flatten_subset(B_name, subset_b)
4✔
404

405
        def get_ndim(name):
4✔
406
            if name not in self.tensor_table:
4✔
407
                return 1
×
408
            return len(self.tensor_table[name].shape)
4✔
409

410
        if len(shape_a) == 2:
4✔
411
            if not trans_a:
4✔
412
                m = shape_a[0]
4✔
413
                k = shape_a[1]
4✔
414
            else:
415
                m = shape_a[1]
×
416
                k = shape_a[0]
×
417
        else:
418
            m = "1"
×
419
            k = shape_a[0]
×
420
            if self._is_stride_1(A_name, indices_a):
×
421
                if get_ndim(A_name) == 1:
×
422
                    trans_a = True
×
423
                else:
424
                    trans_a = False
×
425
            else:
426
                trans_a = True
×
427

428
        if len(shape_b) == 2:
4✔
429
            if not trans_b:
4✔
430
                n = shape_b[1]
4✔
431
            else:
432
                n = shape_b[0]
×
433
        else:
434
            n = "1"
4✔
435
            if self._is_stride_1(B_name, indices_b):
4✔
436
                if get_ndim(B_name) == 1:
4✔
437
                    trans_b = False
4✔
438
                else:
439
                    trans_b = True
×
440
            else:
441
                trans_b = False
×
442

443
        def get_ld(name):
4✔
444
            if name not in self.tensor_table:
4✔
445
                return ""
×
446
            shapes = self.tensor_table[name].shape
4✔
447
            if len(shapes) >= 2:
4✔
448
                return str(shapes[1])
4✔
449
            return "1"
4✔
450

451
        lda = get_ld(A_name)
4✔
452
        ldb = get_ld(B_name)
4✔
453

454
        ldc = ""
4✔
455
        if target_name:
4✔
456
            if get_ndim(target_name) == 1 and m == "1":
4✔
457
                ldc = n
×
458
            else:
459
                ldc = get_ld(target_name)
4✔
460

461
        self.builder.add_gemm(
4✔
462
            A_name,
463
            B_name,
464
            target_name,
465
            alpha,
466
            beta,
467
            m,
468
            n,
469
            k,
470
            trans_a,
471
            trans_b,
472
            flat_subset_a,
473
            flat_subset_b,
474
            target_subset,
475
            lda,
476
            ldb,
477
            ldc,
478
        )
479
        return True
4✔
480

481
    def handle_dot(self, target, value_node):
4✔
482
        """Handle dot product operations for 1D vectors."""
483
        dot_node = None
4✔
484
        is_accumulate = False
4✔
485

486
        if self._is_dot_call(value_node):
4✔
487
            dot_node = value_node
4✔
488
        elif isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
4✔
489
            if self._is_dot_call(value_node.left):
4✔
490
                dot_node = value_node.left
4✔
491
                if self._is_target(value_node.right, target):
4✔
492
                    is_accumulate = True
×
493
            elif self._is_dot_call(value_node.right):
×
494
                dot_node = value_node.right
×
495
                if self._is_target(value_node.left, target):
×
496
                    is_accumulate = True
×
497

498
        if not dot_node:
4✔
499
            return False
×
500

501
        arg0 = None
4✔
502
        arg1 = None
4✔
503

504
        if isinstance(dot_node, ast.Call):
4✔
505
            args = dot_node.args
×
506
            if len(args) != 2:
×
507
                return False
×
508
            arg0 = args[0]
×
509
            arg1 = args[1]
×
510
        elif isinstance(dot_node, ast.BinOp) and isinstance(dot_node.op, ast.MatMult):
4✔
511
            arg0 = dot_node.left
4✔
512
            arg1 = dot_node.right
4✔
513

514
        res_a = self.parse_arg(arg0)
4✔
515
        res_b = self.parse_arg(arg1)
4✔
516

517
        if not res_a[0] or not res_b[0]:
4✔
518
            return False
×
519

520
        name_a, subset_a, shape_a, indices_a = res_a
4✔
521
        name_b, subset_b, shape_b, indices_b = res_b
4✔
522

523
        if len(shape_a) != 1 or len(shape_b) != 1:
4✔
524
            return False
4✔
525

526
        n = shape_a[0]
4✔
527

528
        def get_stride(name, indices):
4✔
529
            if not indices:
4✔
530
                return "1"
4✔
531
            info = self.tensor_table[name]
4✔
532
            shapes = info.shape
4✔
533
            ndim = len(info.shape)
4✔
534

535
            sliced_dim = -1
4✔
536
            for i, idx in enumerate(indices):
4✔
537
                if isinstance(idx, ast.Slice):
4✔
538
                    sliced_dim = i
4✔
539
                    break
4✔
540

541
            if sliced_dim == -1:
4✔
542
                return "1"
×
543

544
            stride = "1"
4✔
545
            for i in range(sliced_dim + 1, ndim):
4✔
546
                dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
×
547
                if stride == "1":
×
548
                    stride = str(dim_size)
×
549
                else:
550
                    stride = f"({stride} * {dim_size})"
×
551
            return stride
4✔
552

553
        incx = get_stride(name_a, indices_a)
4✔
554
        incy = get_stride(name_b, indices_b)
4✔
555

556
        flat_subset_a = self.flatten_subset(name_a, subset_a)
4✔
557
        flat_subset_b = self.flatten_subset(name_b, subset_b)
4✔
558

559
        tmp_res = f"_dot_res_{self._get_unique_id()}"
4✔
560
        self.builder.add_container(tmp_res, Scalar(PrimitiveType.Double), False)
4✔
561
        block = self.builder.add_block()
4✔
562
        constant = self.builder.add_constant(block, "0.0", Scalar(PrimitiveType.Double))
4✔
563
        tasklet = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
4✔
564
        self.builder.add_memlet(
4✔
565
            block, constant, "", tasklet, "_in", "", Scalar(PrimitiveType.Double)
566
        )
567
        access = self.builder.add_access(block, tmp_res)
4✔
568
        self.builder.add_memlet(
4✔
569
            block, tasklet, "_out", access, "", "", Scalar(PrimitiveType.Double)
570
        )
571

572
        self.container_table[tmp_res] = Scalar(PrimitiveType.Double)
4✔
573

574
        self.builder.add_dot(
4✔
575
            name_a, name_b, tmp_res, n, incx, incy, flat_subset_a, flat_subset_b
576
        )
577

578
        target_str = target if isinstance(target, str) else self._ev.visit(target)
4✔
579

580
        if not self.builder.exists(target_str):
4✔
581
            self.builder.add_container(target_str, Scalar(PrimitiveType.Double), False)
×
582
            self.container_table[target_str] = Scalar(PrimitiveType.Double)
×
583

584
        if is_accumulate:
4✔
585
            self.builder.add_assignment(target_str, f"{target_str} + {tmp_res}")
×
586
        else:
587
            self.builder.add_assignment(target_str, tmp_res)
4✔
588

589
        return True
4✔
590

591
    def is_outer(self, node):
4✔
592
        """Check if a node represents an outer product operation."""
593
        if isinstance(node, ast.Call):
4✔
594
            if isinstance(node.func, ast.Attribute) and node.func.attr == "outer":
4✔
595
                return True
4✔
596
            if isinstance(node.func, ast.Name) and node.func.id == "outer":
4✔
597
                return True
×
598
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
599
            return self.is_outer(node.left) or self.is_outer(node.right)
4✔
600
        return False
4✔
601

602
    def handle_outer(self, target, value_node):
4✔
603
        """Handle outer product operations."""
604
        target_name = None
4✔
605
        target_subset = []
4✔
606

607
        if isinstance(target, str):
4✔
608
            target_name = target
4✔
609
        elif isinstance(target, ast.Name):
4✔
610
            target_name = target.id
4✔
611
        elif isinstance(target, ast.Subscript):
4✔
612
            res = self.parse_arg(target)
4✔
613
            if res[0]:
4✔
614
                target_name = res[0]
4✔
615
                target_subset = self.flatten_subset(target_name, res[1])
4✔
616
            else:
617
                if isinstance(target.value, ast.Name):
×
618
                    target_name = target.value.id
×
619

620
        if not target_name:
4✔
621
            return False
×
622

623
        outer_calls = []
4✔
624
        target_found = False
4✔
625
        terms = []
4✔
626

627
        def collect_terms(node):
4✔
628
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
629
                collect_terms(node.left)
4✔
630
                collect_terms(node.right)
4✔
631
            else:
632
                terms.append(node)
4✔
633

634
        collect_terms(value_node)
4✔
635

636
        for term in terms:
4✔
637
            if self._is_target(term, target_name):
4✔
638
                target_found = True
4✔
639
            elif isinstance(term, ast.Call) and (
4✔
640
                (isinstance(term.func, ast.Attribute) and term.func.attr == "outer")
641
                or (isinstance(term.func, ast.Name) and term.func.id == "outer")
642
            ):
643
                if len(term.args) != 2:
4✔
644
                    return False
×
645
                outer_calls.append(term)
4✔
646
            else:
647
                return False
×
648

649
        if not outer_calls:
4✔
650
            return False
×
651

652
        parsed_outers = []
4✔
653
        for outer_node in outer_calls:
4✔
654
            arg0 = outer_node.args[0]
4✔
655
            arg1 = outer_node.args[1]
4✔
656

657
            res_a = self.parse_arg(arg0)
4✔
658
            res_b = self.parse_arg(arg1)
4✔
659

660
            if not res_a[0] or not res_b[0]:
4✔
661
                return False
×
662

663
            parsed_outers.append((res_a, res_b))
4✔
664

665
        alpha = "1.0"
4✔
666
        beta = "1.0" if target_found else "0.0"
4✔
667

668
        def get_flattened_size(name, indices, shapes):
4✔
669
            size_expr = "1"
4✔
670
            for s in shapes:
4✔
671
                if size_expr == "1":
4✔
672
                    size_expr = str(s)
4✔
673
                else:
674
                    size_expr = f"({size_expr} * {str(s)})"
×
675
            return size_expr
4✔
676

677
        def get_ld_2d(name):
4✔
678
            if name in self.tensor_table:
4✔
679
                shapes = self.tensor_table[name].shape
4✔
680
                if len(shapes) >= 2:
4✔
681
                    return str(shapes[1])
4✔
682
            return "1"
4✔
683

684
        ldc = get_ld_2d(target_name)
4✔
685

686
        for res_a, res_b in parsed_outers:
4✔
687
            name_a, subset_a, shape_a, indices_a = res_a
4✔
688
            name_b, subset_b, shape_b, indices_b = res_b
4✔
689

690
            m = get_flattened_size(name_a, indices_a, shape_a)
4✔
691
            n = get_flattened_size(name_b, indices_b, shape_b)
4✔
692
            k = "1"
4✔
693

694
            trans_a = False
4✔
695
            trans_b = True
4✔
696

697
            flat_subset_a = self.flatten_subset(name_a, subset_a)
4✔
698
            flat_subset_b = self.flatten_subset(name_b, subset_b)
4✔
699

700
            lda = "1"
4✔
701
            ldb = "1"
4✔
702

703
            self.builder.add_gemm(
4✔
704
                name_a,
705
                name_b,
706
                target_name,
707
                alpha,
708
                beta,
709
                m,
710
                n,
711
                k,
712
                trans_a,
713
                trans_b,
714
                flat_subset_a,
715
                flat_subset_b,
716
                target_subset,
717
                lda,
718
                ldb,
719
                ldc,
720
            )
721
            beta = "1.0"
4✔
722

723
        return True
4✔
724

725
    # ========== Transpose Operations ==========
726

727
    def _parse_perm(self, node):
4✔
728
        """Parse a permutation list or tuple from an AST node."""
729
        if isinstance(node, (ast.List, ast.Tuple)):
4✔
730
            res = []
4✔
731
            for elt in node.elts:
4✔
732
                val = self._ev.visit(elt)
4✔
733
                res.append(int(val))
4✔
734
            return res
4✔
735
        return []
×
736

737
    def is_transpose(self, node):
4✔
738
        """Check if a node represents a transpose operation."""
739
        # Case 1: np.transpose(arr, ...)
740
        if isinstance(node, ast.Call):
4✔
741
            if isinstance(node.func, ast.Attribute) and node.func.attr == "transpose":
4✔
742
                return True
×
743
            if isinstance(node.func, ast.Name) and node.func.id == "transpose":
4✔
744
                return True
×
745

746
        # Case 2: arr.T
747
        if isinstance(node, ast.Attribute) and node.attr == "T":
4✔
748
            return True
4✔
749

750
        return False
4✔
751

752
    def handle_transpose(self, target, value_node):
4✔
753
        """Handle transpose operations including .T and np.transpose()."""
754
        if not self.is_transpose(value_node):
4✔
755
            return False
×
756

757
        input_node = None
4✔
758
        perm = []
4✔
759

760
        if isinstance(value_node, ast.Attribute) and value_node.attr == "T":
4✔
761
            input_node = value_node.value
4✔
762
            perm = []  # Empty means reverse
4✔
763

764
        elif isinstance(value_node, ast.Call):
×
765
            args = value_node.args
×
766
            keywords = value_node.keywords
×
767

768
            is_numpy_func = False
×
769
            if isinstance(value_node.func, ast.Attribute):
×
770
                caller = ""
×
771
                if isinstance(value_node.func.value, ast.Name):
×
772
                    caller = value_node.func.value.id
×
773
                if caller in ["np", "numpy"]:
×
774
                    is_numpy_func = True
×
775
            elif isinstance(value_node.func, ast.Name):
×
776
                is_numpy_func = True
×
777

778
            if is_numpy_func:
×
779
                if len(args) < 1:
×
780
                    return False
×
781
                input_node = args[0]
×
782
                if len(args) > 1:
×
783
                    perm = self._parse_perm(args[1])
×
784
                for kw in keywords:
×
785
                    if kw.arg == "axes":
×
786
                        perm = self._parse_perm(kw.value)
×
787
            else:
788
                if isinstance(value_node.func, ast.Attribute):
×
789
                    input_node = value_node.func.value
×
790
                else:
791
                    return False
×
792
                if len(args) > 0:
×
793
                    perm = self._parse_perm(args[0])
×
794
                for kw in keywords:
×
795
                    if kw.arg == "axes":
×
796
                        perm = self._parse_perm(kw.value)
×
797

798
        input_name = self._ev.visit(input_node)
4✔
799
        if input_name not in self.tensor_table:
4✔
800
            return False
×
801

802
        in_info = self.tensor_table[input_name]
4✔
803
        in_shape = in_info.shape
4✔
804
        in_strings = [str(s) for s in in_shape]
4✔
805

806
        if not perm:
4✔
807
            perm = list(range(len(in_shape)))[::-1]
4✔
808

809
        out_shape = [in_strings[p] for p in perm]
4✔
810

811
        # Get input strides and check if input is contiguous
812
        in_strides = (
4✔
813
            in_info.strides if hasattr(in_info, "strides") and in_info.strides else None
814
        )
815
        if in_strides is None:
4✔
816
            in_strides = self._compute_strides(in_shape, "C")
×
817

818
        if self._is_contiguous(in_shape, in_strides):
4✔
819
            # For contiguous inputs, output strides are permuted input strides
820
            out_strides = [in_strides[p] for p in perm]
4✔
821
        else:
822
            # For non-contiguous inputs, output is C-order for the new shape
823
            out_strides = self._compute_strides(out_shape, "C")
×
824

825
        target_name = ""
4✔
826
        if isinstance(target, ast.Name):
4✔
827
            target_name = target.id
4✔
828
        elif isinstance(target, str):
×
829
            target_name = target
×
830

831
        dtype = Scalar(PrimitiveType.Double)
4✔
832
        if input_name in self.container_table:
4✔
833
            input_type = self.container_table[input_name]
4✔
834
            if isinstance(input_type, Pointer):
4✔
835
                dtype = input_type.pointee_type
4✔
836
            else:
837
                dtype = input_type
×
838

839
        ptr_type = Pointer(dtype)
4✔
840

841
        # Create target container if it doesn't exist
842
        if not self.builder.exists(target_name):
4✔
843
            self.builder.add_container(target_name, ptr_type, False)
4✔
844
            self.container_table[target_name] = ptr_type
4✔
845
        self.tensor_table[target_name] = Tensor(dtype, out_shape, out_strides)
4✔
846

847
        # Create reference memlet to alias the source array (view, not copy)
848
        block = self.builder.add_block()
4✔
849
        t_src = self.builder.add_access(block, input_name)
4✔
850
        t_dst = self.builder.add_access(block, target_name)
4✔
851
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)
4✔
852

853
        return True
4✔
854

855
    def handle_transpose_expr(self, node):
4✔
856
        """Handle .T attribute access in expressions, returning a temp array name."""
857
        if not isinstance(node, ast.Attribute) or node.attr != "T":
4✔
858
            return None
×
859

860
        input_name = self._ev.visit(node.value)
4✔
861
        if input_name not in self.tensor_table:
4✔
862
            return None
×
863

864
        in_info = self.tensor_table[input_name]
4✔
865
        in_shape = in_info.shape
4✔
866
        perm = list(range(len(in_shape)))[::-1]
4✔
867

868
        return self._create_transpose_view(input_name, perm)
4✔
869

870
    def _handle_numpy_transpose(self, node, func_name):
4✔
871
        """Handle np.transpose(arr, axes=...) function call."""
872
        if len(node.args) < 1:
4✔
873
            raise ValueError("np.transpose requires at least one argument")
×
874

875
        input_node = node.args[0]
4✔
876
        input_name = self.visit(input_node)
4✔
877

878
        if input_name not in self.tensor_table:
4✔
879
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
880

881
        in_info = self.tensor_table[input_name]
4✔
882
        in_shape = in_info.shape
4✔
883

884
        perm = []
4✔
885
        if len(node.args) > 1:
4✔
886
            perm = self._parse_perm(node.args[1])
×
887
        for kw in node.keywords:
4✔
888
            if kw.arg == "axes":
4✔
889
                perm = self._parse_perm(kw.value)
4✔
890

891
        if not perm:
4✔
892
            perm = list(range(len(in_shape)))[::-1]
4✔
893

894
        return self._create_transpose_view(input_name, perm)
4✔
895

896
    def _create_transpose_view(self, input_name, perm):
        """Register a temporary that views ``input_name`` with permuted axes.

        A new pointer container is created and connected to the input through
        a reference memlet; the permutation is expressed through the shape and
        strides stored in ``tensor_table``.

        Args:
            input_name: Name of a tensor registered in ``tensor_table``.
            perm: Sequence of axis indices; output axis ``i`` is input axis
                ``perm[i]``.

        Returns:
            Name of the newly created temporary container.
        """
        source = self.tensor_table[input_name]
        dims = source.shape

        # Output shape: stringified input extents in permuted order.
        dim_strings = [str(extent) for extent in dims]
        view_shape = [dim_strings[axis] for axis in perm]

        # Fall back to C-order strides when the input carries none.
        src_strides = (
            source.strides if hasattr(source, "strides") and source.strides else None
        )
        if src_strides is None:
            src_strides = self._compute_strides(dims, "C")

        # Permuting the strides works for contiguous and view inputs alike.
        view_strides = [src_strides[axis] for axis in perm]

        # Chained views (e.g. flip -> transpose) keep the input's offset.
        inherited_offset = getattr(source, "offset", "0") or "0"

        view_name = f"_tmp_{self._get_unique_id()}"
        elem_type = source.element_type
        view_ptr = Pointer(elem_type)
        self.builder.add_container(view_name, view_ptr, False)
        self.container_table[view_name] = view_ptr

        self.tensor_table[view_name] = Tensor(
            source.element_type, view_shape, view_strides, inherited_offset
        )

        # Alias the source storage instead of copying it.
        alias_block = self.builder.add_block()
        src_access = self.builder.add_access(alias_block, input_name)
        dst_access = self.builder.add_access(alias_block, view_name)
        self.builder.add_reference_memlet(
            alias_block, src_access, dst_access, "0", view_ptr
        )

        return view_name
4✔
935

936
    def _handle_numpy_flip(self, node, func_name):
4✔
937
        """Handle np.flip(arr, axis=None) - flip array along specified axis.
938

939
        Uses negative strides and offset to create a view without copying.
940
        """
941
        if len(node.args) < 1:
4✔
942
            raise ValueError("np.flip requires at least one argument")
×
943

944
        input_name = self.visit(node.args[0])
4✔
945
        if input_name not in self.tensor_table:
4✔
946
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
947

948
        in_info = self.tensor_table[input_name]
4✔
949
        in_shape = in_info.shape
4✔
950
        ndim = len(in_shape)
4✔
951

952
        # Parse axis argument
953
        axis = None
4✔
954
        if len(node.args) > 1:
4✔
955
            axis_node = node.args[1]
×
956
            if isinstance(axis_node, ast.Constant):
×
957
                axis = axis_node.value
×
958
            elif isinstance(axis_node, ast.UnaryOp) and isinstance(
×
959
                axis_node.op, ast.USub
960
            ):
961
                if isinstance(axis_node.operand, ast.Constant):
×
962
                    axis = -axis_node.operand.value
×
963
        for kw in node.keywords:
4✔
964
            if kw.arg == "axis":
4✔
965
                if isinstance(kw.value, ast.Constant):
4✔
966
                    axis = kw.value.value
4✔
967
                elif isinstance(kw.value, ast.UnaryOp) and isinstance(
4✔
968
                    kw.value.op, ast.USub
969
                ):
970
                    if isinstance(kw.value.operand, ast.Constant):
4✔
971
                        axis = -kw.value.operand.value
4✔
972

973
        # Determine which axes to flip
974
        if axis is None:
4✔
975
            # Flip all axes
976
            axes_to_flip = list(range(ndim))
4✔
977
        else:
978
            if axis < 0:
4✔
979
                axis = ndim + axis
4✔
980
            axes_to_flip = [axis]
4✔
981

982
        return self._create_flip_view(input_name, axes_to_flip)
4✔
983

984
    def _handle_numpy_fliplr(self, node, func_name):
4✔
985
        """Handle np.fliplr(arr) - flip array left-right (axis=1)."""
986
        if len(node.args) < 1:
4✔
987
            raise ValueError("np.fliplr requires one argument")
×
988

989
        input_name = self.visit(node.args[0])
4✔
990
        if input_name not in self.tensor_table:
4✔
991
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
992

993
        in_info = self.tensor_table[input_name]
4✔
994
        if len(in_info.shape) < 2:
4✔
995
            raise ValueError("np.fliplr requires array with ndim >= 2")
×
996

997
        return self._create_flip_view(input_name, [1])
4✔
998

999
    def _handle_numpy_flipud(self, node, func_name):
4✔
1000
        """Handle np.flipud(arr) - flip array up-down (axis=0)."""
1001
        if len(node.args) < 1:
4✔
1002
            raise ValueError("np.flipud requires one argument")
×
1003

1004
        input_name = self.visit(node.args[0])
4✔
1005
        if input_name not in self.tensor_table:
4✔
1006
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
1007

1008
        return self._create_flip_view(input_name, [0])
4✔
1009

1010
    def _create_flip_view(self, input_name, axes_to_flip):
        """Register a temporary that views ``input_name`` flipped along axes.

        The flipped tensor description is obtained by calling ``flip(axis)``
        on the input's ``tensor_table`` entry once per requested axis; the new
        container is connected to the input through a reference memlet.

        Args:
            input_name: Name of a tensor registered in ``tensor_table``.
            axes_to_flip: Iterable of axis indices, applied in order.

        Returns:
            Name of the newly created temporary container.
        """
        source = self.tensor_table[input_name]

        # Chain the per-axis flips; each call yields the next description.
        view = source
        for flipped_axis in axes_to_flip:
            view = view.flip(flipped_axis)

        # New pointer container that refers to the same data.
        view_name = f"_tmp_{self._get_unique_id()}"
        view_ptr = Pointer(source.element_type)
        self.builder.add_container(view_name, view_ptr, False)
        self.container_table[view_name] = view_ptr

        # The flipped description (including its offset) lives in tensor_table.
        self.tensor_table[view_name] = view

        # The memlet itself uses subset "0"; the offset is carried by the tensor.
        alias_block = self.builder.add_block()
        src_access = self.builder.add_access(alias_block, input_name)
        dst_access = self.builder.add_access(alias_block, view_name)
        self.builder.add_reference_memlet(
            alias_block, src_access, dst_access, "0", view_ptr
        )

        return view_name
4✔
1039

1040
    def _handle_numpy_reshape(self, node, func_name):
4✔
1041
        """Handle np.reshape(arr, newshape) - reshape array without copying.
1042

1043
        Only works for contiguous arrays; creates a view with new shape/strides.
1044
        """
1045
        if len(node.args) < 2:
4✔
1046
            raise ValueError("np.reshape requires array and new shape")
×
1047

1048
        input_name = self.visit(node.args[0])
4✔
1049
        if input_name not in self.tensor_table:
4✔
1050
            raise ValueError(f"Array {input_name} not found in tensor_table")
×
1051

1052
        in_info = self.tensor_table[input_name]
4✔
1053
        in_shape = in_info.shape
4✔
1054

1055
        # Parse new shape
1056
        new_shape = self._parse_shape(node.args[1])
4✔
1057

1058
        # Get input strides
1059
        in_strides = (
4✔
1060
            in_info.strides if hasattr(in_info, "strides") and in_info.strides else None
1061
        )
1062
        if in_strides is None:
4✔
1063
            in_strides = self._compute_strides(in_shape, "C")
×
1064

1065
        # Check if input is contiguous (C or F order)
1066
        c_contig = self._is_contiguous(in_shape, in_strides)
4✔
1067
        f_contig = self._is_contiguous_f(in_shape, in_strides)
4✔
1068

1069
        if c_contig:
4✔
1070
            out_strides = self._compute_strides(new_shape, "C")
4✔
1071
        elif f_contig:
×
1072
            out_strides = self._compute_strides(new_shape, "F")
×
1073
        else:
1074
            # Non-contiguous array cannot be reshaped without copy
1075
            raise NotImplementedError(
×
1076
                "np.reshape on non-contiguous array not supported (would require copy)"
1077
            )
1078

1079
        # Create new pointer container
1080
        tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1081
        ptr_type = Pointer(in_info.element_type)
4✔
1082
        self.builder.add_container(tmp_name, ptr_type, False)
4✔
1083
        self.container_table[tmp_name] = ptr_type
4✔
1084

1085
        # Register tensor with new shape and computed strides
1086
        self.tensor_table[tmp_name] = Tensor(
4✔
1087
            in_info.element_type, new_shape, out_strides
1088
        )
1089

1090
        # Create reference memlet to alias the source array (view, no copy)
1091
        block = self.builder.add_block()
4✔
1092
        t_src = self.builder.add_access(block, input_name)
4✔
1093
        t_dst = self.builder.add_access(block, tmp_name)
4✔
1094
        self.builder.add_reference_memlet(block, t_src, t_dst, "0", ptr_type)
4✔
1095

1096
        return tmp_name
4✔
1097

1098
    def _parse_shape(self, shape_node):
4✔
1099
        """Parse a shape argument (tuple, list, or single int)."""
1100
        if isinstance(shape_node, ast.Tuple) or isinstance(shape_node, ast.List):
4✔
1101
            result = []
4✔
1102
            for elt in shape_node.elts:
4✔
1103
                if isinstance(elt, ast.Constant):
4✔
1104
                    result.append(str(elt.value))
4✔
1105
                elif isinstance(elt, ast.Name):
×
1106
                    result.append(elt.id)
×
1107
                elif isinstance(elt, ast.UnaryOp) and isinstance(elt.op, ast.USub):
×
1108
                    if isinstance(elt.operand, ast.Constant):
×
1109
                        result.append(str(-elt.operand.value))
×
1110
                else:
1111
                    result.append(self._shape_to_runtime_expr(elt))
×
1112
            return result
4✔
1113
        elif isinstance(shape_node, ast.Constant):
×
1114
            return [str(shape_node.value)]
×
1115
        elif isinstance(shape_node, ast.Name):
×
1116
            # Could be a variable holding a shape tuple - not supported yet
1117
            raise NotImplementedError("Shape variable not supported, use literal tuple")
×
1118
        else:
1119
            raise ValueError(f"Cannot parse shape: {ast.dump(shape_node)}")
×
1120

1121
    def _is_contiguous_f(self, shape, strides):
4✔
1122
        """Check if array is F-order contiguous."""
1123
        if not shape or not strides:
4✔
1124
            return True
×
1125
        f_strides = self._compute_strides(shape, "F")
4✔
1126
        return [str(s) for s in strides] == [str(s) for s in f_strides]
4✔
1127

1128
    def handle_numpy_call(self, node, func_name):
4✔
1129
        if func_name in self.function_handlers:
4✔
1130
            return self.function_handlers[func_name](node, func_name)
4✔
1131
        raise NotImplementedError(f"NumPy function {func_name} not supported")
×
1132

1133
    def has_handler(self, func_name):
4✔
1134
        return func_name in self.function_handlers
4✔
1135

1136
    def handle_array_unary_op(self, op_type, operand):
        """Emit a unary math operation on ``operand`` and return the result name.

        Operands with a non-empty shape go through the builder's elementwise
        unary op. Operands with an empty shape (including names that are not
        in ``tensor_table``) are lowered to a single cmath call; only ``sqrt``,
        ``abs``, ``absolute``, ``exp`` and ``tanh`` are mapped on that path.

        Args:
            op_type: Operation name.
            operand: Name of the operand container.

        Returns:
            Name of the temporary holding the result.
        """
        dtype = self._ev._element_type(operand)
        if operand in self.tensor_table:
            tensor = self.tensor_table[operand]
        else:
            # Unknown names are treated as rank-0 values of their element type.
            tensor = Tensor(dtype, [])

        if len(tensor.shape) != 0:
            # Array path: allocate the output and let the builder emit the op.
            result_strides = self._get_contiguous_output_strides(
                tensor.shape, tensor.strides
            )
            result = self._create_array_temp(
                tensor.shape, dtype, strides=result_strides
            )
            result_tensor = self.tensor_table[result]
            self.builder.add_elementwise_unary_op(
                op_type, operand, tensor, result, result_tensor
            )
            return result

        # Scalar path: one block with a single cmath node.
        result = self._create_array_temp([], dtype)

        scalar_functions = {
            "sqrt": CMathFunction.sqrt,
            "abs": CMathFunction.fabs,
            "absolute": CMathFunction.fabs,
            "exp": CMathFunction.exp,
            "tanh": CMathFunction.tanh,
        }

        scalar_block = self.builder.add_block()
        src_access = self.builder.add_access(scalar_block, operand)
        dst_access = self.builder.add_access(scalar_block, result)
        math_node = self.builder.add_cmath(
            scalar_block, scalar_functions[op_type], dtype.primitive_type
        )

        self.builder.add_memlet(
            scalar_block, src_access, "void", math_node, "_in1", "", dtype
        )
        self.builder.add_memlet(
            scalar_block, math_node, "_out", dst_access, "void", "", dtype
        )

        return result
4✔
1176

1177
    def handle_array_binary_op(self, op_type, left, right):
        """Emit an elementwise binary operation and return the result name.

        The result element type is the promotion of both operand element
        types and the result shape comes from ``_compute_broadcast_shape``.
        An operand that needs broadcasting is not copied: it is described by a
        tensor with the output shape and strides from
        ``_compute_broadcast_strides``, keeping the operand's own element type
        and offset because it still refers to the operand's storage.

        Args:
            op_type: Operation name forwarded to the builder.
            left: Name of the left operand.
            right: Name of the right operand.

        Returns:
            Name of the temporary array holding the result.
        """
        dtype_left = self._ev._element_type(left)
        dtype_right = self._ev._element_type(right)
        dtype = promote_element_types(dtype_left, dtype_right)

        # Names missing from tensor_table are treated as rank-0 values.
        if left in self.tensor_table:
            left_tensor = self.tensor_table[left]
        else:
            left_tensor = Tensor(dtype, [])

        if right in self.tensor_table:
            right_tensor = self.tensor_table[right]
        else:
            right_tensor = Tensor(dtype, [])

        left_shape = left_tensor.shape
        right_shape = right_tensor.shape

        # Compute broadcast output shape
        output_shape = self._compute_broadcast_shape(left_shape, right_shape)

        # Rank-0 operands never get a broadcast view.
        left_needs_broadcast = (
            self._needs_broadcast(left_shape, output_shape) if left_shape else False
        )
        right_needs_broadcast = (
            self._needs_broadcast(right_shape, output_shape) if right_shape else False
        )

        def broadcast_view(tensor):
            """Describe ``tensor`` with the output shape (stride-based, no copy)."""
            strides = tensor.strides if tensor.strides else []
            view_strides = self._compute_broadcast_strides(
                tensor.shape, strides, output_shape
            )
            # Preserve the offset of the original tensor (important for views
            # like flip) and its element type: the view reads the operand's
            # memory, not a converted copy.
            offset = tensor.offset if tensor.offset else "0"
            return Tensor(tensor.element_type, output_shape, view_strides, offset)

        real_left_tensor = (
            broadcast_view(left_tensor) if left_needs_broadcast else left_tensor
        )
        real_right_tensor = (
            broadcast_view(right_tensor) if right_needs_broadcast else right_tensor
        )

        # Create output array with broadcast shape
        if not left_needs_broadcast and not right_needs_broadcast:
            # Use left tensor strides to determine output order
            output_strides = self._get_contiguous_output_strides(
                output_shape, left_tensor.strides
            )
        else:
            output_strides = self._compute_strides(output_shape, "C")
        tmp_name = self._create_array_temp(output_shape, dtype, strides=output_strides)
        tmp_tensor = self.tensor_table[tmp_name]

        self.builder.add_elementwise_op(
            op_type,
            left,
            real_left_tensor,
            right,
            real_right_tensor,
            tmp_name,
            tmp_tensor,
        )

        return tmp_name
4✔
1260

1261
    def handle_array_negate(self, operand):
        """Emit ``-operand`` as ``0 - operand`` and return the result name.

        A scalar temporary is initialised to zero with an assign tasklet and
        then passed, together with ``operand``, to the builder's elementwise
        ``"sub"`` op.

        Args:
            operand: Name of a tensor registered in ``tensor_table``.

        Returns:
            Name of the temporary array holding the negated values.
        """
        source = self.tensor_table[operand]
        dtype = self._ev._element_type(operand)

        # Output array with the operand's shape.
        result_strides = self._get_contiguous_output_strides(
            source.shape, source.strides
        )
        result = self._create_array_temp(source.shape, dtype, strides=result_strides)
        result_tensor = self.tensor_table[result]

        # Scalar container that will hold the zero constant.
        zero = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(zero, dtype, False)
        self.container_table[zero] = dtype
        self.tensor_table[zero] = Tensor(dtype, [])

        # Double gets a floating literal; every other type the literal "0".
        if dtype.primitive_type == PrimitiveType.Double:
            zero_literal = "0.0"
        else:
            zero_literal = "0"

        init_block = self.builder.add_block()
        const_node = self.builder.add_constant(init_block, zero_literal, dtype)
        zero_access = self.builder.add_access(init_block, zero)
        assign_node = self.builder.add_tasklet(
            init_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(init_block, const_node, "void", assign_node, "_in", "")
        self.builder.add_memlet(
            init_block, assign_node, "_out", zero_access, "void", ""
        )

        self.builder.add_elementwise_op(
            "sub",
            zero,
            self.tensor_table[zero],
            operand,
            source,
            result,
            result_tensor,
        )

        return result
4✔
1297

1298
    def handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Emit an elementwise comparison and return the boolean result array.

        The iteration space and the choice between integer and floating-point
        comparison tasklets are taken from the left operand when it is an
        array, otherwise from the right one. Int32/Int64 arrays use the signed
        integer tasklets; every other element type uses the ordered
        floating-point ones.

        Args:
            left: Name of the left operand.
            op: Comparison symbol, one of ``>``, ``>=``, ``<``, ``<=``, ``==``,
                ``!=``.
            right: Name of the right operand.
            left_is_array: Whether ``left`` is an array operand.
            right_is_array: Whether ``right`` is an array operand.

        Returns:
            Name of the temporary Bool array holding the result.

        Raises:
            NotImplementedError: if ``op`` is not one of the listed symbols.
        """
        array_operand = left if left_is_array else right
        shape = self.tensor_table[array_operand].shape

        operand_type = self._ev._element_type(array_operand)
        integer_compare = operand_type.primitive_type in (
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )

        result_name = self._create_array_temp(shape, Scalar(PrimitiveType.Bool))

        if integer_compare:
            opcode_by_symbol = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            opcode_by_symbol = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in opcode_by_symbol:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )
        opcode = opcode_by_symbol[op]

        # At most one side is a scalar; the left side is checked first.
        if not left_is_array:
            scalar_operand = left
        elif not right_is_array:
            scalar_operand = right
        else:
            scalar_operand = None

        # An integer scalar compared against a non-integer array is first
        # stored into a Double temporary, using the constant "<scalar>.0".
        if (
            scalar_operand is not None
            and not integer_compare
            and self._is_int(scalar_operand)
        ):
            promoted = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(promoted, Scalar(PrimitiveType.Double), False)
            self.container_table[promoted] = Scalar(PrimitiveType.Double)

            conv_block = self.builder.add_block()
            literal_node = self.builder.add_constant(
                conv_block, f"{scalar_operand}.0", Scalar(PrimitiveType.Double)
            )
            promoted_access = self.builder.add_access(conv_block, promoted)
            assign_node = self.builder.add_tasklet(
                conv_block, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(
                conv_block, literal_node, "void", assign_node, "_in", ""
            )
            self.builder.add_memlet(
                conv_block, assign_node, "_out", promoted_access, "void", ""
            )

            if left_is_array:
                right = promoted
            else:
                left = promoted

        # Tensor descriptions of the array operands and of the output.
        left_tensor = self.tensor_table.get(left) if left_is_array else None
        right_tensor = self.tensor_table.get(right) if right_is_array else None
        result_tensor = self.tensor_table[result_name]

        # One loop per dimension, outermost first.
        indices = []
        for position, extent in enumerate(shape):
            index_name = f"_cmp_i{position}_{self._get_unique_id()}"
            if not self.builder.exists(index_name):
                self.builder.add_container(
                    index_name, Scalar(PrimitiveType.Int64), False
                )
                self.container_table[index_name] = Scalar(PrimitiveType.Int64)
            indices.append(index_name)
            self.builder.begin_for(index_name, "0", str(extent), "1")

        # Multi-dimensional subset - TensorToPointerConversion handles strides/offset
        element_subset = ",".join(indices)

        body = self.builder.add_block()

        if left_is_array:
            left_access = self.builder.add_access(body, left)
            left_subset = element_subset
        else:
            left_access, left_subset = self._add_read(body, left)

        if right_is_array:
            right_access = self.builder.add_access(body, right)
            right_subset = element_subset
        else:
            right_access, right_subset = self._add_read(body, right)

        out_access = self.builder.add_access(body, result_name)

        compare_node = self.builder.add_tasklet(
            body, opcode, ["_in1", "_in2"], ["_out"]
        )

        def connect_input(access, connector, subset, tensor):
            """Wire one operand; pass the tensor type only for array operands."""
            if tensor:
                self.builder.add_memlet(
                    body, access, "void", compare_node, connector, subset, tensor
                )
            else:
                self.builder.add_memlet(
                    body, access, "void", compare_node, connector, subset
                )

        connect_input(left_access, "_in1", left_subset, left_tensor)
        connect_input(right_access, "_in2", right_subset, right_tensor)

        self.builder.add_memlet(
            body, compare_node, "_out", out_access, "void", element_subset, result_tensor
        )

        # Close the loops opened above.
        for _ in indices:
            self.builder.end_for()

        return result_name
4✔
1435

1436
    # ========== NumPy Function Handlers ==========
1437

1438
    def _handle_numpy_alloc(self, node, func_name):
4✔
1439
        """Handle np.empty, np.zeros, np.ones, np.ndarray."""
1440
        shape_arg = node.args[0]
4✔
1441
        dims = []
4✔
1442
        dims_runtime = []
4✔
1443
        if isinstance(shape_arg, ast.Tuple):
4✔
1444
            dims = [self.visit(elt) for elt in shape_arg.elts]
4✔
1445
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
4✔
1446
        elif isinstance(shape_arg, ast.List):
4✔
1447
            dims = [self.visit(elt) for elt in shape_arg.elts]
×
1448
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
×
1449
        else:
1450
            val = self.visit(shape_arg)
4✔
1451
            runtime_val = self._shape_to_runtime_expr(shape_arg)
4✔
1452
            if val.startswith("_shape_proxy_"):
4✔
1453
                array_name = val[len("_shape_proxy_") :]
×
1454
                if array_name in self.tensor_table:
×
1455
                    info = self.tensor_table[array_name]
×
1456
                    dims = info.shape
×
1457
                    dims_runtime = self.shapes_runtime_info.get(array_name, dims)
×
1458
                else:
1459
                    dims = [val]
×
1460
                    dims_runtime = [runtime_val]
×
1461
            else:
1462
                dims = [val]
4✔
1463
                dims_runtime = [runtime_val]
4✔
1464

1465
        dtype_arg = None
4✔
1466
        order = "C"  # Default to C-order (row-major)
4✔
1467
        explicit_strides = None
4✔
1468
        if len(node.args) > 1:
4✔
1469
            dtype_arg = node.args[1]
×
1470

1471
        for kw in node.keywords:
4✔
1472
            if kw.arg == "dtype":
4✔
1473
                dtype_arg = kw.value
4✔
1474
            elif kw.arg == "order":
4✔
1475
                if isinstance(kw.value, ast.Constant):
4✔
1476
                    order = kw.value.value
4✔
1477
            elif kw.arg == "strides":
4✔
1478
                # Parse explicit strides tuple/list
1479
                if isinstance(kw.value, (ast.Tuple, ast.List)):
4✔
1480
                    explicit_strides = [
4✔
1481
                        self._shape_to_runtime_expr(elt) for elt in kw.value.elts
1482
                    ]
1483

1484
        element_type = element_type_from_ast_node(dtype_arg, self.container_table)
4✔
1485

1486
        # Use explicit strides if provided, otherwise compute from order
1487
        if explicit_strides is not None:
4✔
1488
            # Convert byte strides to element strides by dividing by element size
1489
            element_size = self.builder.get_sizeof(element_type)
4✔
1490
            strides = [f"(({s}) / {element_size})" for s in explicit_strides]
4✔
1491
        else:
1492
            strides = self._compute_strides(dims, order)
4✔
1493

1494
        return self._create_array_temp(
4✔
1495
            dims,
1496
            element_type,
1497
            zero_init=(func_name == "zeros"),
1498
            ones_init=(func_name == "ones"),
1499
            shapes_runtime=dims_runtime,
1500
            strides=strides,
1501
        )
1502

1503
    def _handle_numpy_empty_like(self, node, func_name):
        """Lower ``np.empty_like(prototype, dtype=..., order=...)``.

        The shape is copied from the prototype when it is a known tensor
        (otherwise the shape is empty).  The element type comes from an
        explicit dtype (second positional argument or ``dtype=`` keyword),
        else from the prototype's pointee type, else defaults to double.

        Returns whatever ``_create_array_temp`` returns for the new,
        uninitialized array.
        """
        source_name = self.visit(node.args[0])

        # Shape of the prototype, if it is registered as a tensor.
        shape = []
        if source_name in self.tensor_table:
            shape = self.tensor_table[source_name].shape

        # Positional dtype first; keywords are scanned afterwards so a
        # ``dtype=`` keyword takes precedence.
        dtype_node = node.args[1] if len(node.args) > 1 else None
        layout = "C"  # row-major unless a constant ``order=`` says otherwise
        for keyword in node.keywords:
            if keyword.arg == "dtype":
                dtype_node = keyword.value
            elif keyword.arg == "order" and isinstance(keyword.value, ast.Constant):
                layout = keyword.value.value

        elem = None
        if dtype_node:
            elem = element_type_from_ast_node(dtype_node, self.container_table)
        elif source_name in self.container_table:
            proto_type = self.container_table[source_name]
            if isinstance(proto_type, Pointer) and proto_type.has_pointee_type():
                elem = proto_type.pointee_type

        if elem is None:
            elem = Scalar(PrimitiveType.Double)

        layout_strides = self._compute_strides(shape, layout)
        return self._create_array_temp(
            shape,
            elem,
            zero_init=False,
            ones_init=False,
            strides=layout_strides,
        )
1540

1541
    def _handle_numpy_zeros_like(self, node, func_name):
        """Lower ``np.zeros_like(prototype, dtype=..., order=...)``.

        Same resolution rules as ``np.empty_like``: shape from the prototype
        tensor (empty if unknown), element type from an explicit dtype, else
        the prototype's pointee type, else double.  The temporary is created
        with ``zero_init=True``.
        """
        proto = self.visit(node.args[0])

        extents = []
        if proto in self.tensor_table:
            extents = self.tensor_table[proto].shape

        # Second positional argument is the dtype; a keyword overrides it.
        requested_dtype = None
        if len(node.args) > 1:
            requested_dtype = node.args[1]

        memory_order = "C"  # default to row-major
        for keyword in node.keywords:
            name, value = keyword.arg, keyword.value
            if name == "dtype":
                requested_dtype = value
            elif name == "order" and isinstance(value, ast.Constant):
                memory_order = value.value

        if requested_dtype:
            elem = element_type_from_ast_node(requested_dtype, self.container_table)
        else:
            elem = None
            if proto in self.container_table:
                declared = self.container_table[proto]
                if isinstance(declared, Pointer) and declared.has_pointee_type():
                    elem = declared.pointee_type

        if elem is None:
            elem = Scalar(PrimitiveType.Double)

        computed_strides = self._compute_strides(extents, memory_order)
        return self._create_array_temp(
            extents,
            elem,
            zero_init=True,
            ones_init=False,
            strides=computed_strides,
        )
1578

1579
    def _handle_numpy_eye(self, node, func_name):
        """Handle ``np.eye(N, M=None, k=0, dtype=...)``.

        Allocates a zero-initialized N x M temporary and emits a loop over
        the rows that writes a one at column ``i + k`` whenever that column
        lies inside ``[0, M)``.

        Arguments may be given positionally (in NumPy's order ``N, M, k,
        dtype``) or by keyword; keywords are processed after positionals and
        therefore win.  ``M=None`` means "same as N" in both forms.

        Returns the name of the temporary array.
        """
        N_arg = node.args[0]
        N_str = self.visit(N_arg)
        N_runtime = self._shape_to_runtime_expr(N_arg)

        M_str = N_str
        M_arg = N_arg  # Default M = N
        if len(node.args) > 1:
            M_arg = node.args[1]
            M_str = self.visit(M_arg)
            # np.eye(N, None) is the same as np.eye(N); mirror the handling
            # the keyword form gets below.
            if M_str == "None":
                M_str = N_str
                M_arg = N_arg

        k_str = "0"
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])

        # dtype is the fourth positional parameter of np.eye.
        dtype_arg = None
        if len(node.args) > 3:
            dtype_arg = node.args[3]

        for kw in node.keywords:
            if kw.arg == "M":
                M_arg = kw.value
                M_str = self.visit(M_arg)
                if M_str == "None":
                    M_str = N_str
                    M_arg = N_arg
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        M_runtime = self._shape_to_runtime_expr(M_arg)

        element_type = element_type_from_ast_node(dtype_arg, self.container_table)

        ptr_name = self._create_array_temp(
            [N_str, M_str],
            element_type,
            zero_init=True,
            shapes_runtime=[N_runtime, M_runtime],
        )

        # Row index of the diagonal loop.
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Only write when the shifted column i + k is a valid column index.
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        # Integer element types get an integer literal, everything else 1.0.
        val = "1.0"
        if element_type.primitive_type in (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ):
            val = "1"

        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        # Flat offset of element (i, i + k) with rows of length M.
        subset = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"

        t_task = self.builder.add_tasklet(
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", subset)

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
1660

1661
    def _handle_numpy_binary_op(self, node, func_name):
4✔
1662
        """Handle np.add, np.subtract, np.multiply, np.divide, etc."""
1663
        args = [self.visit(arg) for arg in node.args]
4✔
1664
        if len(args) != 2:
4✔
1665
            raise NotImplementedError(
×
1666
                f"Numpy function {func_name} requires 2 arguments"
1667
            )
1668

1669
        op_map = {
4✔
1670
            "add": "add",
1671
            "subtract": "sub",
1672
            "multiply": "mul",
1673
            "divide": "div",
1674
            "power": "pow",
1675
            "minimum": "min",
1676
            "maximum": "max",
1677
        }
1678
        return self.handle_array_binary_op(op_map[func_name], args[0], args[1])
4✔
1679

1680
    def _handle_numpy_unary_op(self, node, func_name):
4✔
1681
        """Handle np.exp, np.sqrt, np.abs, etc."""
1682
        args = [self.visit(arg) for arg in node.args]
4✔
1683
        if len(args) != 1:
4✔
1684
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
1685

1686
        op_name = func_name
4✔
1687
        if op_name == "absolute":
4✔
1688
            op_name = "abs"
×
1689

1690
        return self.handle_array_unary_op(op_name, args[0])
4✔
1691

1692
    def _handle_numpy_where(self, node, func_name):
4✔
1693
        """Handle np.where(condition, x, y) - elementwise ternary selection."""
1694
        if len(node.args) != 3:
4✔
1695
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")
×
1696

1697
        cond_name = self.visit(node.args[0])
4✔
1698
        x_name = self.visit(node.args[1])
4✔
1699
        y_name = self.visit(node.args[2])
4✔
1700

1701
        shape = []
4✔
1702
        dtype = Scalar(PrimitiveType.Double)
4✔
1703

1704
        if cond_name in self.tensor_table:
4✔
1705
            shape = self.tensor_table[cond_name].shape
4✔
1706

1707
        if not shape and y_name in self.tensor_table:
4✔
1708
            shape = self.tensor_table[y_name].shape
×
1709

1710
        if not shape and x_name in self.tensor_table:
4✔
1711
            shape = self.tensor_table[x_name].shape
×
1712

1713
        if not shape:
4✔
1714
            raise NotImplementedError("np.where requires at least one array argument")
×
1715

1716
        if y_name in self.container_table:
4✔
1717
            y_type = self.container_table[y_name]
4✔
1718
            if isinstance(y_type, Pointer) and y_type.has_pointee_type():
4✔
1719
                dtype = y_type.pointee_type
4✔
1720
            elif isinstance(y_type, Scalar):
×
1721
                dtype = y_type
×
1722

1723
        tmp_name = self._create_array_temp(shape, dtype)
4✔
1724
        tmp_tensor = self.tensor_table[tmp_name]
4✔
1725

1726
        loop_vars = []
4✔
1727
        for i, dim in enumerate(shape):
4✔
1728
            loop_var = f"_where_i{i}_{self._get_unique_id()}"
4✔
1729
            if not self.builder.exists(loop_var):
4✔
1730
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1731
                self.container_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1732
            loop_vars.append(loop_var)
4✔
1733
            self.builder.begin_for(loop_var, "0", str(dim), "1")
4✔
1734
        multi_dim_subset = ",".join(loop_vars)
4✔
1735

1736
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
4✔
1737
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
4✔
1738
        self.container_table[cond_tmp] = Scalar(PrimitiveType.Bool)
4✔
1739

1740
        block_cond = self.builder.add_block()
4✔
1741
        if cond_name in self.tensor_table:
4✔
1742
            cond_tensor = self.tensor_table[cond_name]
4✔
1743
            t_cond_arr = self.builder.add_access(block_cond, cond_name)
4✔
1744
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
4✔
1745
            t_cond_task = self.builder.add_tasklet(
4✔
1746
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
1747
            )
1748
            self.builder.add_memlet(
4✔
1749
                block_cond,
1750
                t_cond_arr,
1751
                "void",
1752
                t_cond_task,
1753
                "_in",
1754
                multi_dim_subset,
1755
                cond_tensor,
1756
            )
1757
            self.builder.add_memlet(
4✔
1758
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
1759
            )
1760
        else:
1761
            t_cond_src, cond_sub = self._add_read(block_cond, cond_name)
×
1762
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
×
1763
            t_cond_task = self.builder.add_tasklet(
×
1764
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
1765
            )
1766
            self.builder.add_memlet(
×
1767
                block_cond, t_cond_src, "void", t_cond_task, "_in", cond_sub
1768
            )
1769
            self.builder.add_memlet(
×
1770
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
1771
            )
1772

1773
        self.builder.begin_if(f"{cond_tmp} == true")
4✔
1774

1775
        block_true = self.builder.add_block()
4✔
1776
        t_out_true = self.builder.add_access(block_true, tmp_name)
4✔
1777
        if x_name in self.tensor_table:
4✔
1778
            x_tensor = self.tensor_table[x_name]
4✔
1779
            t_x = self.builder.add_access(block_true, x_name)
4✔
1780
            t_task_true = self.builder.add_tasklet(
4✔
1781
                block_true, TaskletCode.assign, ["_in"], ["_out"]
1782
            )
1783
            self.builder.add_memlet(
4✔
1784
                block_true, t_x, "void", t_task_true, "_in", multi_dim_subset, x_tensor
1785
            )
1786
        else:
1787
            t_x, x_sub = self._add_read(block_true, x_name)
4✔
1788
            t_task_true = self.builder.add_tasklet(
4✔
1789
                block_true, TaskletCode.assign, ["_in"], ["_out"]
1790
            )
1791
            self.builder.add_memlet(block_true, t_x, "void", t_task_true, "_in", x_sub)
4✔
1792
        self.builder.add_memlet(
4✔
1793
            block_true,
1794
            t_task_true,
1795
            "_out",
1796
            t_out_true,
1797
            "void",
1798
            multi_dim_subset,
1799
            tmp_tensor,
1800
        )
1801

1802
        self.builder.begin_else()
4✔
1803

1804
        # False branch: read from y, write to output
1805
        block_false = self.builder.add_block()
4✔
1806
        t_out_false = self.builder.add_access(block_false, tmp_name)
4✔
1807
        if y_name in self.tensor_table:
4✔
1808
            y_tensor = self.tensor_table[y_name]
4✔
1809
            t_y = self.builder.add_access(block_false, y_name)
4✔
1810
            t_task_false = self.builder.add_tasklet(
4✔
1811
                block_false, TaskletCode.assign, ["_in"], ["_out"]
1812
            )
1813
            self.builder.add_memlet(
4✔
1814
                block_false,
1815
                t_y,
1816
                "void",
1817
                t_task_false,
1818
                "_in",
1819
                multi_dim_subset,
1820
                y_tensor,
1821
            )
1822
        else:
1823
            t_y, y_sub = self._add_read(block_false, y_name)
4✔
1824
            t_task_false = self.builder.add_tasklet(
4✔
1825
                block_false, TaskletCode.assign, ["_in"], ["_out"]
1826
            )
1827
            self.builder.add_memlet(
4✔
1828
                block_false, t_y, "void", t_task_false, "_in", y_sub
1829
            )
1830
        self.builder.add_memlet(
4✔
1831
            block_false,
1832
            t_task_false,
1833
            "_out",
1834
            t_out_false,
1835
            "void",
1836
            multi_dim_subset,
1837
            tmp_tensor,
1838
        )
1839

1840
        self.builder.end_if()
4✔
1841

1842
        for _ in loop_vars:
4✔
1843
            self.builder.end_for()
4✔
1844

1845
        return tmp_name
4✔
1846

1847
    def _handle_numpy_clip(self, node, func_name):
4✔
1848
        """Handle np.clip(a, a_min, a_max) - elementwise clipping."""
1849
        if len(node.args) != 3:
4✔
1850
            raise NotImplementedError("np.clip requires 3 arguments (a, a_min, a_max)")
×
1851

1852
        arr_name = self.visit(node.args[0])
4✔
1853
        a_min = self.visit(node.args[1])
4✔
1854
        a_max = self.visit(node.args[2])
4✔
1855

1856
        tmp1 = self.handle_array_binary_op("max", arr_name, a_min)
4✔
1857
        result = self.handle_array_binary_op("min", tmp1, a_max)
4✔
1858

1859
        return result
4✔
1860

1861
    def _handle_numpy_matmul(self, node, func_name):
4✔
1862
        """Handle np.matmul, np.dot."""
1863
        if len(node.args) != 2:
4✔
1864
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
1865
        return self._handle_matmul_helper(node.args[0], node.args[1])
4✔
1866

1867
    def handle_numpy_matmul_op(self, left_node, right_node):
4✔
1868
        """Handle the @ operator for matrix multiplication."""
1869
        return self._handle_matmul_helper(left_node, right_node)
4✔
1870

1871
    def _handle_matmul_helper(self, left_node, right_node):
4✔
1872
        """Helper for matrix multiplication operations."""
1873
        res_a = self.parse_arg(left_node)
4✔
1874
        res_b = self.parse_arg(right_node)
4✔
1875

1876
        if not res_a[0]:
4✔
1877
            left_name = self.visit(left_node)
4✔
1878
            left_node = ast.Name(id=left_name)
4✔
1879
            res_a = self.parse_arg(left_node)
4✔
1880

1881
        if not res_b[0]:
4✔
1882
            right_name = self.visit(right_node)
×
1883
            right_node = ast.Name(id=right_name)
×
1884
            res_b = self.parse_arg(right_node)
×
1885

1886
        name_a, subset_a, shape_a, indices_a = res_a
4✔
1887
        name_b, subset_b, shape_b, indices_b = res_b
4✔
1888

1889
        if not name_a or not name_b:
4✔
1890
            raise NotImplementedError("Could not resolve matmul operands")
×
1891

1892
        real_shape_a = shape_a
4✔
1893
        real_shape_b = shape_b
4✔
1894

1895
        ndim_a = len(real_shape_a)
4✔
1896
        ndim_b = len(real_shape_b)
4✔
1897

1898
        output_shape = []
4✔
1899
        is_scalar = False
4✔
1900

1901
        if ndim_a == 1 and ndim_b == 1:
4✔
1902
            is_scalar = True
4✔
1903
            output_shape = []
4✔
1904
        elif ndim_a == 2 and ndim_b == 2:
4✔
1905
            output_shape = [real_shape_a[0], real_shape_b[1]]
4✔
1906
        elif ndim_a == 2 and ndim_b == 1:
4✔
1907
            output_shape = [real_shape_a[0]]
4✔
1908
        elif ndim_a == 1 and ndim_b == 2:
4✔
1909
            output_shape = [real_shape_b[1]]
×
1910
        elif ndim_a > 2 or ndim_b > 2:
4✔
1911
            if ndim_a == ndim_b:
4✔
1912
                output_shape = list(real_shape_a[:-2]) + [
4✔
1913
                    real_shape_a[-2],
1914
                    real_shape_b[-1],
1915
                ]
1916
            else:
1917
                raise NotImplementedError(
×
1918
                    "Broadcasting with different ranks not fully supported yet"
1919
                )
1920
        else:
1921
            raise NotImplementedError(
×
1922
                f"Matmul with ranks {ndim_a} and {ndim_b} not supported"
1923
            )
1924

1925
        dtype_a = self._ev._element_type(name_a)
4✔
1926
        dtype_b = self._ev._element_type(name_b)
4✔
1927
        dtype = promote_element_types(dtype_a, dtype_b)
4✔
1928

1929
        if is_scalar:
4✔
1930
            tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1931
            self.builder.add_container(tmp_name, dtype, False)
4✔
1932
            self.container_table[tmp_name] = dtype
4✔
1933
        else:
1934
            tmp_name = self._create_array_temp(output_shape, dtype)
4✔
1935

1936
        if ndim_a > 2 or ndim_b > 2:
4✔
1937
            batch_dims = ndim_a - 2
4✔
1938
            loop_vars = []
4✔
1939

1940
            for i in range(batch_dims):
4✔
1941
                loop_var = f"_i{self._get_unique_id()}"
4✔
1942
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1943
                loop_vars.append(loop_var)
4✔
1944
                dim_size = real_shape_a[i]
4✔
1945
                self.builder.begin_for(loop_var, "0", str(dim_size), "1")
4✔
1946

1947
            def make_slice(name, indices):
4✔
1948
                elts = []
4✔
1949
                for idx in indices:
4✔
1950
                    if idx == ":":
4✔
1951
                        elts.append(ast.Slice())
4✔
1952
                    else:
1953
                        elts.append(ast.Name(id=idx))
4✔
1954

1955
                return ast.Subscript(
4✔
1956
                    value=ast.Name(id=name), slice=ast.Tuple(elts=elts), ctx=ast.Load()
1957
                )
1958

1959
            indices = loop_vars + [":", ":"]
4✔
1960
            slice_a = make_slice(name_a, indices)
4✔
1961
            slice_b = make_slice(name_b, indices)
4✔
1962
            slice_c = make_slice(tmp_name, indices)
4✔
1963

1964
            self.handle_gemm(
4✔
1965
                slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
1966
            )
1967

1968
            for _ in range(batch_dims):
4✔
1969
                self.builder.end_for()
4✔
1970
        else:
1971
            if is_scalar:
4✔
1972
                self.handle_dot(
4✔
1973
                    tmp_name,
1974
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
1975
                )
1976
            else:
1977
                self.handle_gemm(
4✔
1978
                    tmp_name,
1979
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
1980
                )
1981

1982
        return tmp_name
4✔
1983

1984
    def _handle_numpy_outer(self, node, func_name):
4✔
1985
        """Handle np.outer."""
1986
        if len(node.args) != 2:
4✔
1987
            raise NotImplementedError("outer requires 2 arguments")
×
1988

1989
        arg0 = node.args[0]
4✔
1990
        arg1 = node.args[1]
4✔
1991

1992
        res_a = self.parse_arg(arg0)
4✔
1993
        res_b = self.parse_arg(arg1)
4✔
1994

1995
        if not res_a[0]:
4✔
1996
            left_name = self.visit(arg0)
×
1997
            arg0 = ast.Name(id=left_name)
×
1998
            res_a = self.parse_arg(arg0)
×
1999

2000
        if not res_b[0]:
4✔
2001
            right_name = self.visit(arg1)
×
2002
            arg1 = ast.Name(id=right_name)
×
2003
            res_b = self.parse_arg(arg1)
×
2004

2005
        name_a, subset_a, shape_a, indices_a = res_a
4✔
2006
        name_b, subset_b, shape_b, indices_b = res_b
4✔
2007

2008
        if not name_a or not name_b:
4✔
2009
            raise NotImplementedError("Could not resolve outer operands")
×
2010

2011
        def get_flattened_size_expr(name, indices, shapes):
4✔
2012
            size_expr = "1"
4✔
2013
            for s in shapes:
4✔
2014
                if size_expr == "1":
4✔
2015
                    size_expr = str(s)
4✔
2016
                else:
2017
                    size_expr = f"({size_expr} * {str(s)})"
×
2018
            return size_expr
4✔
2019

2020
        m_expr = get_flattened_size_expr(name_a, indices_a, shape_a)
4✔
2021
        n_expr = get_flattened_size_expr(name_b, indices_b, shape_b)
4✔
2022

2023
        dtype_a = self._ev._element_type(name_a)
4✔
2024
        dtype_b = self._ev._element_type(name_b)
4✔
2025
        dtype = promote_element_types(dtype_a, dtype_b)
4✔
2026

2027
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)
4✔
2028

2029
        new_call_node = ast.Call(
4✔
2030
            func=node.func, args=[arg0, arg1], keywords=node.keywords
2031
        )
2032

2033
        self.handle_outer(tmp_name, new_call_node)
4✔
2034

2035
        return tmp_name
4✔
2036

2037
    def handle_ufunc_outer(self, node, ufunc_name):
4✔
2038
        """Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc."""
2039
        if len(node.args) != 2:
4✔
2040
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")
×
2041

2042
        if ufunc_name == "multiply":
4✔
2043
            return self._handle_numpy_outer(node, "outer")
4✔
2044

2045
        op_map = {
4✔
2046
            "add": ("add", TaskletCode.fp_add, TaskletCode.int_add),
2047
            "subtract": ("sub", TaskletCode.fp_sub, TaskletCode.int_sub),
2048
            "divide": ("div", TaskletCode.fp_div, TaskletCode.int_sdiv),
2049
            "minimum": ("min", CMathFunction.fmin, TaskletCode.int_smin),
2050
            "maximum": ("max", CMathFunction.fmax, TaskletCode.int_smax),
2051
        }
2052

2053
        if ufunc_name not in op_map:
4✔
2054
            raise NotImplementedError(f"{ufunc_name}.outer not supported")
×
2055

2056
        op_name, fp_opcode, int_opcode = op_map[ufunc_name]
4✔
2057

2058
        arg0 = node.args[0]
4✔
2059
        arg1 = node.args[1]
4✔
2060

2061
        res_a = self.parse_arg(arg0)
4✔
2062
        res_b = self.parse_arg(arg1)
4✔
2063

2064
        if not res_a[0]:
4✔
2065
            left_name = self.visit(arg0)
×
2066
            arg0 = ast.Name(id=left_name)
×
2067
            res_a = self.parse_arg(arg0)
×
2068

2069
        if not res_b[0]:
4✔
2070
            right_name = self.visit(arg1)
×
2071
            arg1 = ast.Name(id=right_name)
×
2072
            res_b = self.parse_arg(arg1)
×
2073

2074
        name_a, subset_a, shape_a, indices_a = res_a
4✔
2075
        name_b, subset_b, shape_b, indices_b = res_b
4✔
2076

2077
        if not name_a or not name_b:
4✔
2078
            raise NotImplementedError("Could not resolve ufunc outer operands")
×
2079

2080
        def get_flattened_size_expr(shapes):
4✔
2081
            if not shapes:
4✔
2082
                return "1"
×
2083
            size_expr = str(shapes[0])
4✔
2084
            for s in shapes[1:]:
4✔
2085
                size_expr = f"({size_expr} * {str(s)})"
×
2086
            return size_expr
4✔
2087

2088
        m_expr = get_flattened_size_expr(shape_a)
4✔
2089
        n_expr = get_flattened_size_expr(shape_b)
4✔
2090

2091
        dtype_left = self._ev._element_type(name_a)
4✔
2092
        dtype_right = self._ev._element_type(name_b)
4✔
2093
        dtype = promote_element_types(dtype_left, dtype_right)
4✔
2094

2095
        is_int = dtype.primitive_type in [
4✔
2096
            PrimitiveType.Int64,
2097
            PrimitiveType.Int32,
2098
            PrimitiveType.Int8,
2099
            PrimitiveType.Int16,
2100
            PrimitiveType.UInt64,
2101
            PrimitiveType.UInt32,
2102
            PrimitiveType.UInt8,
2103
            PrimitiveType.UInt16,
2104
        ]
2105

2106
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)
4✔
2107

2108
        i_var = self.builder.find_new_name("_outer_i_")
4✔
2109
        j_var = self.builder.find_new_name("_outer_j_")
4✔
2110

2111
        if not self.builder.exists(i_var):
4✔
2112
            self.builder.add_container(i_var, Scalar(PrimitiveType.Int64), False)
4✔
2113
            self.container_table[i_var] = Scalar(PrimitiveType.Int64)
4✔
2114
        if not self.builder.exists(j_var):
4✔
2115
            self.builder.add_container(j_var, Scalar(PrimitiveType.Int64), False)
4✔
2116
            self.container_table[j_var] = Scalar(PrimitiveType.Int64)
4✔
2117

2118
        def compute_linear_index(name, subset, indices, loop_var):
4✔
2119
            if not indices:
4✔
2120
                return loop_var
4✔
2121

2122
            if name in self.tensor_table:
4✔
2123
                info = self.tensor_table[name]
4✔
2124
                shapes = info.shape
4✔
2125
                ndim = len(shapes)
4✔
2126
            else:
2127
                shapes = []
×
2128
                ndim = 0
×
2129

2130
            if ndim == 0:
4✔
2131
                return loop_var
×
2132

2133
            strides = []
4✔
2134
            current_stride = "1"
4✔
2135
            for i in range(ndim - 1, -1, -1):
4✔
2136
                strides.insert(0, current_stride)
4✔
2137
                if i > 0:
4✔
2138
                    dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
4✔
2139
                    if current_stride == "1":
4✔
2140
                        current_stride = str(dim_size)
4✔
2141
                    else:
2142
                        current_stride = f"({current_stride} * {dim_size})"
×
2143

2144
            terms = []
4✔
2145
            loop_var_used = False
4✔
2146

2147
            for i, idx in enumerate(indices):
4✔
2148
                stride = strides[i] if i < len(strides) else "1"
4✔
2149
                start = subset[i] if i < len(subset) else "0"
4✔
2150

2151
                if isinstance(idx, ast.Slice):
4✔
2152
                    if stride == "1":
4✔
2153
                        term = f"({start} + {loop_var})"
4✔
2154
                    else:
2155
                        term = f"(({start} + {loop_var}) * {stride})"
4✔
2156
                    loop_var_used = True
4✔
2157
                else:
2158
                    if stride == "1":
4✔
2159
                        term = start
4✔
2160
                    else:
2161
                        term = f"({start} * {stride})"
4✔
2162

2163
                terms.append(term)
4✔
2164

2165
            if not terms:
4✔
2166
                return loop_var
×
2167

2168
            result = terms[0]
4✔
2169
            for t in terms[1:]:
4✔
2170
                result = f"({result} + {t})"
4✔
2171

2172
            return result
4✔
2173

2174
        self.builder.begin_for(i_var, "0", m_expr, "1")
4✔
2175
        self.builder.begin_for(j_var, "0", n_expr, "1")
4✔
2176

2177
        block = self.builder.add_block()
4✔
2178

2179
        t_a = self.builder.add_access(block, name_a)
4✔
2180
        t_b = self.builder.add_access(block, name_b)
4✔
2181
        t_c = self.builder.add_access(block, tmp_name)
4✔
2182

2183
        if ufunc_name in ["minimum", "maximum"]:
4✔
2184
            if is_int:
4✔
2185
                t_task = self.builder.add_tasklet(
4✔
2186
                    block, int_opcode, ["_in1", "_in2"], ["_out"]
2187
                )
2188
            else:
2189
                t_task = self.builder.add_cmath(block, fp_opcode, dtype.primitive_type)
4✔
2190
        else:
2191
            tasklet_code = int_opcode if is_int else fp_opcode
4✔
2192
            t_task = self.builder.add_tasklet(
4✔
2193
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
2194
            )
2195

2196
        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)
4✔
2197
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)
4✔
2198

2199
        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)
4✔
2200
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)
4✔
2201

2202
        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
4✔
2203
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)
4✔
2204

2205
        self.builder.end_for()
4✔
2206
        self.builder.end_for()
4✔
2207

2208
        return tmp_name
4✔
2209

2210
    def _handle_numpy_reduce(self, node, func_name):
        """Handle np.sum, np.max, np.min, np.mean, np.std.

        Args:
            node: ``ast.Call`` of the reduction. The first positional argument is
                the array; ``axis`` is read from the second positional argument
                or the ``axis`` keyword; ``keepdims`` is read from keywords.
            func_name: Reduction name, forwarded to ``builder.add_reduce_op``.

        Returns:
            Name of the temporary container that receives the reduction result.

        Raises:
            ValueError: If the input is not in ``tensor_table`` or an axis lies
                outside ``[-ndim, ndim)``.
            NotImplementedError: If an axis cannot be resolved to an integer.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        array_name = self.visit(args[0])
        if array_name not in self.tensor_table:
            raise ValueError(f"Reduction input must be an array, got {array_name}")

        input_tensor = self.tensor_table[array_name]
        input_shape = input_tensor.shape
        ndim = len(input_shape)

        axis = args[1] if len(args) > 1 else keywords.get("axis")
        # An explicit ``axis=None`` literal means "reduce over every axis", the
        # same as omitting the argument; it must not reach the integer handling.
        if isinstance(axis, ast.Constant) and axis.value is None:
            axis = None

        keepdims = False
        keepdims_node = keywords.get("keepdims")
        # NOTE(review): a non-literal keepdims expression is ignored and treated
        # as False, as before.
        if isinstance(keepdims_node, ast.Constant):
            keepdims = bool(keepdims_node.value)

        def resolve_axis(axis_node):
            """Turn one axis AST node into a non-negative, in-range axis index."""
            if isinstance(axis_node, ast.Constant):
                val = axis_node.value
            elif (
                isinstance(axis_node, ast.UnaryOp)
                and isinstance(axis_node.op, ast.USub)
                and isinstance(axis_node.operand, ast.Constant)
            ):
                # Negative literals such as ``-1`` parse as USub applied to a Constant.
                val = -axis_node.operand.value
            else:
                try:
                    val = int(self.visit(axis_node))
                except Exception as err:
                    raise NotImplementedError("Dynamic axis not supported") from err

            normalized = val + ndim if val < 0 else val
            if not 0 <= normalized < ndim:
                raise ValueError(
                    f"axis {val} is out of bounds for array of dimension {ndim}"
                )
            return normalized

        if axis is None:
            axes = list(range(ndim))
        elif isinstance(axis, ast.Tuple):
            # Resolve every element; negative entries are no longer dropped.
            axes = [resolve_axis(elt) for elt in axis.elts]
        else:
            axes = [resolve_axis(axis)]

        # Reduced axes disappear from the output unless keepdims keeps them as "1".
        output_shape = []
        for i in range(ndim):
            if i not in axes:
                output_shape.append(input_shape[i])
            elif keepdims:
                output_shape.append("1")

        dtype = self._ev._element_type(array_name)

        # _create_array_temp registers a plain scalar container when the shape
        # is empty, so the full-reduction case needs no separate branch here.
        output_strides = self._compute_strides(output_shape, "C")
        tmp_name = self._create_array_temp(output_shape, dtype, strides=output_strides)

        output_tensor = self.tensor_table[tmp_name]
        self.builder.add_reduce_op(
            func_name, array_name, input_tensor, tmp_name, output_tensor, axes, keepdims
        )

        return tmp_name
4✔
2297

2298
    def handle_numpy_astype(self, node, array_name):
        """Handle numpy ``array.astype(dtype)`` method calls.

        Creates a temporary array with the target element type and emits a cast
        op from the input into it, so the result is always a copy.

        Args:
            node: ``ast.Call`` of the ``astype`` invocation. The dtype is taken
                from the first positional argument, or from a ``dtype=`` keyword
                when no positional argument is given.
            array_name: Name of the array the method is called on.

        Returns:
            Name of the temporary array holding the converted values.

        Raises:
            ValueError: If no dtype is supplied or ``array_name`` is not in
                ``tensor_table``.
            NotImplementedError: If ``copy=False`` is passed as a literal.
        """
        dtype_arg = node.args[0] if node.args else None
        if dtype_arg is None:
            # ``arr.astype(dtype=np.float32)`` is valid NumPy, so accept the keyword.
            dtype_arg = next(
                (kw.value for kw in node.keywords if kw.arg == "dtype"), None
            )
        if dtype_arg is None:
            raise ValueError("astype requires at least one argument (dtype)")

        # Check for copy=False which we don't support (we always copy)
        for kw in node.keywords:
            if (
                kw.arg == "copy"
                and isinstance(kw.value, ast.Constant)
                and kw.value.value is False
            ):
                raise NotImplementedError("astype with copy=False is not supported")

        target_dtype = element_type_from_ast_node(dtype_arg, self.container_table)

        if array_name not in self.tensor_table:
            raise ValueError(f"Array {array_name} not found in tensor_table")

        input_tensor = self.tensor_table[array_name]
        input_shape = input_tensor.shape
        input_strides = getattr(input_tensor, "strides", None)

        # Keep F-order when the input strides match the F-contiguous pattern,
        # otherwise allocate C-order; the shared helper implements this rule.
        output_strides = self._get_contiguous_output_strides(input_shape, input_strides)
        tmp_name = self._create_array_temp(
            input_shape, target_dtype, strides=output_strides
        )

        output_tensor = self.tensor_table[tmp_name]
        self.builder.add_cast_op(array_name, input_tensor, tmp_name, output_tensor)

        return tmp_name
4✔
2336

2337
    def handle_numpy_copy(self, node, array_name):
        """Handle numpy ``array.copy()`` method calls.

        Allocates a temporary with the input's shape and element type and fills
        it through ``builder.add_cast_op``, which stands in for an assign op.

        Args:
            node: ``ast.Call`` of the ``copy`` invocation (not inspected).
            array_name: Name of the array being copied.

        Returns:
            Name of the temporary array holding the copy.

        Raises:
            ValueError: If ``array_name`` is not in ``tensor_table``.
        """
        if array_name not in self.tensor_table:
            raise ValueError(f"Array {array_name} not found in tensor_table")

        source = self.tensor_table[array_name]
        source_shape = source.shape
        source_strides = getattr(source, "strides", None)

        # Default to double; use the pointee type when the container is a typed pointer.
        element_type = Scalar(PrimitiveType.Double)
        if array_name in self.container_table:
            container_type = self.container_table[array_name]
            if isinstance(container_type, Pointer) and container_type.has_pointee_type():
                element_type = container_type.pointee_type

        # The copy is F-ordered only when the input strides equal the F-contiguous
        # strides of its shape; every other layout yields a C-ordered copy.
        copy_strides = self._get_contiguous_output_strides(source_shape, source_strides)
        tmp_name = self._create_array_temp(
            source_shape, element_type, strides=copy_strides
        )

        # Workaround: "assign-op"
        self.builder.add_cast_op(
            array_name, source, tmp_name, self.tensor_table[tmp_name]
        )

        return tmp_name
4✔
2369

2370
    def _get_contiguous_output_strides(self, shape, input_strides):
4✔
2371
        """Get contiguous output strides, preserving C or F order if input is contiguous.
2372

2373
        For non-contiguous input strides (e.g., from slices), returns C-order strides.
2374
        This ensures output allocation matches the stride pattern.
2375

2376
        Args:
2377
            shape: Output shape
2378
            input_strides: Strides from input tensor
2379

2380
        Returns:
2381
            List of stride expressions for a contiguous output array
2382
        """
2383
        if not shape or not input_strides:
4✔
2384
            return self._compute_strides(shape, "C")
4✔
2385

2386
        # Preserve order if contiguous, otherwise default to C-order
2387
        c_strides = self._compute_strides(shape, "C")
4✔
2388
        if input_strides == c_strides:
4✔
2389
            return c_strides
4✔
2390
        f_strides = self._compute_strides(shape, "F")
4✔
2391
        if input_strides == f_strides:
4✔
2392
            return f_strides
×
2393
        return c_strides
4✔
2394

2395
    def _compute_strides(self, shape, order="C"):
4✔
2396
        """Compute strides for a given shape and memory order.
2397

2398
        Args:
2399
            shape: List of dimension sizes
2400
            order: "C" for row-major (default), "F" for column-major
2401

2402
        Returns:
2403
            List of stride expressions as strings
2404
        """
2405
        if not shape:
4✔
2406
            return []
×
2407

2408
        ndim = len(shape)
4✔
2409
        strides = []
4✔
2410

2411
        if order == "F":
4✔
2412
            # Column-major (Fortran order): stride[i] = product of shape[:i]
2413
            for dim_idx in range(ndim):
4✔
2414
                if dim_idx == 0:
4✔
2415
                    strides.append("1")
4✔
2416
                else:
2417
                    # Wrap each shape in parens to ensure correct precedence
2418
                    prefix_shapes = [f"({s})" for s in shape[:dim_idx]]
4✔
2419
                    if len(prefix_shapes) == 1:
4✔
2420
                        strides.append(prefix_shapes[0])
4✔
2421
                    else:
2422
                        strides.append("(" + " * ".join(prefix_shapes) + ")")
×
2423
        else:
2424
            # Row-major (C order): stride[i] = product of shape[i+1:]
2425
            for dim_idx in range(ndim):
4✔
2426
                if dim_idx == ndim - 1:
4✔
2427
                    strides.append("1")
4✔
2428
                else:
2429
                    # Wrap each shape in parens to ensure correct precedence
2430
                    suffix_shapes = [f"({s})" for s in shape[dim_idx + 1 :]]
4✔
2431
                    if len(suffix_shapes) == 1:
4✔
2432
                        strides.append(suffix_shapes[0])
4✔
2433
                    else:
2434
                        strides.append("(" + " * ".join(suffix_shapes) + ")")
4✔
2435

2436
        return strides
4✔
2437

2438
    def _is_contiguous(self, shape, strides):
4✔
2439
        """Check if strides represent a contiguous (C or F order) layout."""
2440
        if not shape or not strides:
4✔
2441
            return True
×
2442

2443
        def normalize(s):
4✔
2444
            # Normalize stride expression by removing spaces and outer parens
2445
            s = s.replace(" ", "")
4✔
2446
            while s.startswith("(") and s.endswith(")"):
4✔
2447
                # Only strip if balanced parens
2448
                inner = s[1:-1]
4✔
2449
                depth = 0
4✔
2450
                balanced = True
4✔
2451
                for c in inner:
4✔
2452
                    if c == "(":
4✔
2453
                        depth += 1
×
2454
                    elif c == ")":
4✔
2455
                        depth -= 1
×
2456
                        if depth < 0:
×
2457
                            balanced = False
×
2458
                            break
×
2459
                if balanced and depth == 0:
4✔
2460
                    s = inner
4✔
2461
                else:
2462
                    break
×
2463
            return s
4✔
2464

2465
        c_strides = self._compute_strides(shape, "C")
4✔
2466
        if all(
4✔
2467
            normalize(str(a)) == normalize(str(b)) for a, b in zip(strides, c_strides)
2468
        ):
2469
            return True
4✔
2470
        f_strides = self._compute_strides(shape, "F")
×
2471
        return all(
×
2472
            normalize(str(a)) == normalize(str(b)) for a, b in zip(strides, f_strides)
2473
        )
2474

2475
    def _create_array_temp(
        self,
        shape,
        dtype,
        zero_init=False,
        ones_init=False,
        shapes_runtime=None,
        strides=None,
    ):
        """Create a temporary container and register it in the symbol tables.

        Args:
            shape: Dimension sizes; an empty shape produces a scalar container.
            dtype: Element type of the temporary.
            zero_init: Initialize the temporary with zeros.
            ones_init: Initialize the temporary with ones (checked after
                ``zero_init``).
            shapes_runtime: Optional value stored in ``shapes_runtime_info`` for
                array temporaries.
            strides: Stride expressions; C-order strides are computed when omitted.

        Returns:
            Name of the new temporary.
        """
        tmp_name = f"_tmp_{self._get_unique_id()}"

        # A 0-dimensional array is represented as a plain scalar container.
        if not shape:
            self.builder.add_container(tmp_name, dtype, False)
            self.container_table[tmp_name] = dtype
            self.tensor_table[tmp_name] = Tensor(dtype, [])

            if zero_init or ones_init:
                digit = "0" if zero_init else "1"
                is_double = dtype.primitive_type == PrimitiveType.Double
                self.builder.add_assignment(
                    tmp_name, f"{digit}.0" if is_double else digit
                )

            return tmp_name

        # Element count: every dimension is parenthesized so that expression
        # dimensions such as "-2 + _s0" are parsed correctly.
        size_str = "1"
        for extent in shape:
            size_str = f"({size_str} * ({extent}))"

        element_size = self.builder.get_sizeof(dtype)
        total_size = f"({size_str} * {element_size})"

        if strides is None:
            strides = self._compute_strides(shape, "C")

        # Register the pointer container and its tensor description.
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.container_table[tmp_name] = ptr_type
        tensor_entry = Tensor(dtype, shape, strides, "0")
        if shapes_runtime is not None:
            self.shapes_runtime_info[tmp_name] = shapes_runtime
        self.tensor_table[tmp_name] = tensor_entry

        # Try to hoist the allocation to function entry. Ones-initialization is
        # never offered to the memory handler and is always emitted in place.
        init_type = (
            ManagedMemoryHandler.INIT_ZERO
            if zero_init
            else ManagedMemoryHandler.INIT_NONE
        )
        hoisted = not ones_init and self.memory_handler.allocate(
            tmp_name, ptr_type, total_size, init=init_type
        )
        if not hoisted:
            # Emit allocation immediately (size depends on loop variables or needs loop init)
            self._emit_malloc(
                tmp_name, total_size, ptr_type, zero_init, ones_init, size_str, dtype
            )

        return tmp_name
4✔
2545

2546
    def _emit_malloc(
        self, tmp_name, total_size, ptr_type, zero_init, ones_init, size_str, dtype
    ):
        """Emit a malloc for a temporary array, plus optional initialization.

        Args:
            tmp_name: Name of the pointer container receiving the allocation.
            total_size: Allocation size expression, also used for the memset.
            ptr_type: Pointer type of the container.
            zero_init: Follow the malloc with a memset to zero.
            ones_init: Fill the array with ones in a loop (only when
                ``zero_init`` is false).
            size_str: Element-count expression used as the loop bound.
            dtype: Element type, used to pick the literal for "one".
        """
        alloc_block = self.builder.add_block()
        malloc_node = self.builder.add_malloc(alloc_block, total_size)
        alloc_target = self.builder.add_access(alloc_block, tmp_name)
        self.builder.add_memlet(
            alloc_block, malloc_node, "_ret", alloc_target, "void", "", ptr_type
        )

        if zero_init:
            memset_block = self.builder.add_block()
            memset_node = self.builder.add_memset(memset_block, "0", total_size)
            memset_target = self.builder.add_access(memset_block, tmp_name)
            self.builder.add_memlet(
                memset_block, memset_node, "_ptr", memset_target, "void", "", ptr_type
            )
            return

        if not ones_init:
            return

        # Ones are written element by element in an explicit loop.
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.container_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", size_str, "1")

        integer_types = (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        )
        # Integer element types get "1"; every other type gets "1.0".
        one_literal = "1" if dtype.primitive_type in integer_types else "1.0"

        assign_block = self.builder.add_block()
        const_node = self.builder.add_constant(assign_block, one_literal, dtype)
        array_node = self.builder.add_access(assign_block, tmp_name)

        assign_task = self.builder.add_tasklet(
            assign_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            assign_block, const_node, "void", assign_task, "_in", "", dtype
        )
        self.builder.add_memlet(
            assign_block, assign_task, "_out", array_node, "void", loop_var
        )

        self.builder.end_for()
4✔
2598

2599
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
2600
        """Compute linear index from multi-dimensional indices.
2601

2602
        Uses strides from tensor_table if available (supporting F-order arrays),
2603
        otherwise falls back to computing strides assuming C-order.
2604
        """
2605
        if ndim == 0:
×
2606
            return "0"
×
2607

2608
        # Try to get strides from tensor_table
2609
        strides = None
×
2610
        if array_name in self.tensor_table:
×
2611
            tensor_info = self.tensor_table[array_name]
×
2612
            if hasattr(tensor_info, "strides") and tensor_info.strides:
×
2613
                strides = tensor_info.strides
×
2614

2615
        if strides and len(strides) == ndim:
×
2616
            # Use explicit strides from tensor_table
2617
            linear_index = ""
×
2618
            for i in range(ndim):
×
2619
                stride = strides[i]
×
2620
                if stride == "1":
×
2621
                    term = str(indices[i])
×
2622
                else:
2623
                    term = f"(({indices[i]}) * ({stride}))"
×
2624

2625
                if i == 0:
×
2626
                    linear_index = term
×
2627
                else:
2628
                    linear_index = f"({linear_index} + {term})"
×
2629
            return linear_index
×
2630
        else:
2631
            # Fall back to C-order (row-major) stride computation
2632
            linear_index = ""
×
2633
            for i in range(ndim):
×
2634
                term = str(indices[i])
×
2635
                for j in range(i + 1, ndim):
×
2636
                    shape_val = (
×
2637
                        shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
2638
                    )
2639
                    term = f"(({term}) * {shape_val})"
×
2640

2641
                if i == 0:
×
2642
                    linear_index = term
×
2643
                else:
2644
                    linear_index = f"({linear_index} + {term})"
×
2645

2646
            return linear_index
×
2647

2648
    def _compute_broadcast_shape(self, shape_a, shape_b):
4✔
2649
        """Compute the broadcast output shape following NumPy broadcasting rules."""
2650
        if not shape_a:
4✔
2651
            return shape_b
4✔
2652
        if not shape_b:
4✔
2653
            return shape_a
4✔
2654

2655
        max_ndim = max(len(shape_a), len(shape_b))
4✔
2656
        padded_a = ["1"] * (max_ndim - len(shape_a)) + [str(s) for s in shape_a]
4✔
2657
        padded_b = ["1"] * (max_ndim - len(shape_b)) + [str(s) for s in shape_b]
4✔
2658

2659
        result = []
4✔
2660
        for a, b in zip(padded_a, padded_b):
4✔
2661
            if a == "1":
4✔
2662
                result.append(b)
4✔
2663
            elif b == "1":
4✔
2664
                result.append(a)
4✔
2665
            elif a == b:
4✔
2666
                result.append(a)
4✔
2667
            else:
2668
                result.append(a)
4✔
2669

2670
        return result
4✔
2671

2672
    def _needs_broadcast(self, input_shape, output_shape):
4✔
2673
        """Check if input shape needs broadcasting to match output shape."""
2674
        if len(input_shape) != len(output_shape):
4✔
2675
            return True
4✔
2676
        for in_dim, out_dim in zip(input_shape, output_shape):
4✔
2677
            if str(in_dim) != str(out_dim):
4✔
2678
                return True
4✔
2679
        return False
4✔
2680

2681
    def _compute_broadcast_strides(self, input_shape, input_strides, output_shape):
4✔
2682
        """Compute strides for broadcasting input to output shape.
2683

2684
        For broadcast dimensions (size 1), stride is set to 0 so the same
2685
        value is repeated. This enables stride-based broadcasting without copying.
2686
        """
2687
        # Pad input shape and strides on the left to match output ndim
2688
        ndim_diff = len(output_shape) - len(input_shape)
4✔
2689
        padded_shape = ["1"] * ndim_diff + [str(s) for s in input_shape]
4✔
2690
        padded_strides = ["0"] * ndim_diff + [str(s) for s in input_strides]
4✔
2691

2692
        broadcast_strides = []
4✔
2693
        for in_dim, in_stride, out_dim in zip(
4✔
2694
            padded_shape, padded_strides, output_shape
2695
        ):
2696
            # Only use stride 0 when input dimension is exactly "1" (broadcast case).
2697
            # For other cases (including symbolic dimensions that may be equal at runtime),
2698
            # keep the original stride.
2699
            if str(in_dim) == "1" and str(out_dim) != "1":
4✔
2700
                # Broadcast dimension: use stride 0
2701
                broadcast_strides.append("0")
4✔
2702
            else:
2703
                # Non-broadcast dimension or potentially equal symbolic dimensions:
2704
                # keep original stride
2705
                broadcast_strides.append(in_stride)
4✔
2706

2707
        return broadcast_strides
4✔
2708

2709
    def _shape_to_runtime_expr(self, shape_node):
4✔
2710
        """Convert a shape expression AST node to a runtime-evaluable string."""
2711
        if isinstance(shape_node, ast.Constant):
4✔
2712
            return str(shape_node.value)
4✔
2713
        elif isinstance(shape_node, ast.Name):
4✔
2714
            return shape_node.id
4✔
2715
        elif isinstance(shape_node, ast.BinOp):
4✔
2716
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2717
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2718
            op = self.visit(shape_node.op)
4✔
2719
            return f"({left} {op} {right})"
4✔
2720
        elif isinstance(shape_node, ast.UnaryOp):
4✔
2721
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
2722
            if isinstance(shape_node.op, ast.USub):
×
2723
                return f"(-{operand})"
×
2724
            elif isinstance(shape_node.op, ast.UAdd):
×
2725
                return operand
×
2726
            else:
2727
                return self.visit(shape_node)
×
2728
        elif isinstance(shape_node, ast.Subscript):
4✔
2729
            val = shape_node.value
4✔
2730
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2731
                if isinstance(val.value, ast.Name):
4✔
2732
                    arr_name = val.value.id
4✔
2733
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2734
                        idx = shape_node.slice.value
4✔
2735
                        if arr_name in self.tensor_table:
4✔
2736
                            shapes = self.tensor_table[arr_name].shape
4✔
2737
                            if idx < len(shapes):
4✔
2738
                                return shapes[idx]
4✔
2739
                        return f"{arr_name}.shape[{idx}]"
×
2740
            return self.visit(shape_node)
×
2741
        elif isinstance(shape_node, ast.Tuple):
×
2742
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2743
        elif isinstance(shape_node, ast.List):
×
2744
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2745
        else:
2746
            return self.visit(shape_node)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc