• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.24
/python/docc/python/functions/numpy.py
1
import ast
4✔
2
from docc.sdfg import (
4✔
3
    Scalar,
4
    PrimitiveType,
5
    Pointer,
6
    DebugInfo,
7
    TaskletCode,
8
    CMathFunction,
9
)
10
from docc.python.types import (
4✔
11
    element_type_from_ast_node,
12
    promote_element_types,
13
)
14
from docc.python.ast_utils import get_debug_info
4✔
15

16

17
class NumPyHandler:
4✔
18
    """
19
    Unified handler for NumPy operations including:
20
    - Array creation (empty, zeros, ones, eye, etc.)
21
    - Elementwise operations (add, subtract, multiply, etc.)
22
    - Linear algebra (matmul, dot, outer, gemm)
23
    - Array manipulation (transpose)
24
    - Reductions (sum, max, min, mean, std)
25
    """
26

27
    def __init__(self, expression_visitor):
        """Attach the handler to its expression visitor and build the dispatch table.

        ``function_handlers`` maps the name of a NumPy function to the bound
        method of this handler that is registered for it.
        """
        self._ev = expression_visitor
        self._unique_counter = 0

        # Handlers that serve several NumPy functions are bound once.
        alloc = self._handle_numpy_alloc
        binary_op = self._handle_numpy_binary_op
        unary_op = self._handle_numpy_unary_op
        reduce_op = self._handle_numpy_reduce
        matmul = self._handle_numpy_matmul

        self.function_handlers = {
            # Array creation
            "empty": alloc,
            "empty_like": self._handle_numpy_empty_like,
            "zeros": alloc,
            "zeros_like": self._handle_numpy_zeros_like,
            "ones": alloc,
            "ndarray": alloc,
            "eye": self._handle_numpy_eye,
            # Elementwise binary operations
            "add": binary_op,
            "subtract": binary_op,
            "multiply": binary_op,
            "divide": binary_op,
            "power": binary_op,
            # Elementwise unary operations
            "exp": unary_op,
            "abs": unary_op,
            "absolute": unary_op,
            "sqrt": unary_op,
            "tanh": unary_op,
            # Reductions
            "sum": reduce_op,
            "max": reduce_op,
            "min": reduce_op,
            "mean": reduce_op,
            "std": reduce_op,
            # Linear algebra
            "matmul": matmul,
            "dot": matmul,
            "matvec": matmul,
            "outer": self._handle_numpy_outer,
            # Remaining elementwise / selection / manipulation functions
            "minimum": binary_op,
            "maximum": binary_op,
            "where": self._handle_numpy_where,
            "clip": self._handle_numpy_clip,
            "transpose": self._handle_numpy_transpose,
        }
63

64
    # Expose parent properties for convenience
65
    @property
4✔
66
    def array_info(self):
4✔
67
        return self._ev.array_info
4✔
68

69
    @property
4✔
70
    def builder(self):
4✔
71
        return self._ev.builder
4✔
72

73
    @property
4✔
74
    def symbol_table(self):
4✔
75
        return self._ev.symbol_table
4✔
76

77
    @property
4✔
78
    def globals_dict(self):
4✔
NEW
79
        return self._ev.globals_dict
×
80

81
    def _get_unique_id(self):
4✔
82
        return self._ev._get_unique_id()
4✔
83

84
    def _add_read(self, block, expr_str, debug_info=None):
4✔
85
        return self._ev._add_read(block, expr_str, debug_info)
4✔
86

87
    def _is_int(self, operand):
4✔
88
        return self._ev._is_int(operand)
4✔
89

90
    def visit(self, node):
4✔
91
        return self._ev.visit(node)
4✔
92

93
    # ========== Linear Algebra Helper Methods (from LinearAlgebraHandler) ==========
94

95
    def parse_arg(self, node):
4✔
96
        """Parse an array argument, returning (name, start_indices, slice_shape, indices)."""
97
        if isinstance(node, ast.Name):
4✔
98
            if node.id in self.array_info:
4✔
99
                return node.id, [], self.array_info[node.id]["shapes"], []
4✔
100
        elif isinstance(node, ast.Subscript):
4✔
101
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
4✔
102
                name = node.value.id
4✔
103
                indices = []
4✔
104
                if isinstance(node.slice, ast.Tuple):
4✔
105
                    indices = node.slice.elts
4✔
106
                else:
107
                    indices = [node.slice]
4✔
108

109
                start_indices = []
4✔
110
                slice_shape = []
4✔
111

112
                for i, idx in enumerate(indices):
4✔
113
                    if isinstance(idx, ast.Slice):
4✔
114
                        start = "0"
4✔
115
                        if idx.lower:
4✔
116
                            start = self._ev.visit(idx.lower)
4✔
117
                        start_indices.append(start)
4✔
118

119
                        shapes = self.array_info[name]["shapes"]
4✔
120
                        dim_size = (
4✔
121
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
122
                        )
123
                        stop = dim_size
4✔
124
                        if idx.upper:
4✔
125
                            stop = self._ev.visit(idx.upper)
4✔
126

127
                        size = f"({stop} - {start})"
4✔
128
                        slice_shape.append(size)
4✔
129
                    else:
130
                        if isinstance(idx, ast.Name) and idx.id in self.array_info:
4✔
131
                            # This is an array index (gather operation)
132
                            return None, None, None, None
4✔
133
                        val = self._ev.visit(idx)
4✔
134
                        start_indices.append(val)
4✔
135

136
                return name, start_indices, slice_shape, indices
4✔
137

138
        return None, None, None, None
4✔
139

140
    def flatten_subset(self, name, start_indices):
4✔
141
        """Convert multi-dimensional start indices to a flattened linear offset."""
142
        if not start_indices:
4✔
143
            return []
4✔
144
        info = self.array_info[name]
4✔
145
        shapes = info["shapes"]
4✔
146
        ndim = info["ndim"]
4✔
147

148
        if len(start_indices) != ndim:
4✔
149
            return start_indices
4✔
150

151
        strides = []
4✔
152
        current_stride = "1"
4✔
153
        strides.append(current_stride)
4✔
154
        for i in range(ndim - 1, 0, -1):
4✔
155
            dim_size = shapes[i]
4✔
156
            if current_stride == "1":
4✔
157
                current_stride = str(dim_size)
4✔
158
            else:
159
                current_stride = f"({current_stride} * {dim_size})"
4✔
160
            strides.append(current_stride)
4✔
161
        strides = list(reversed(strides))
4✔
162

163
        offset = "0"
4✔
164
        for i in range(ndim):
4✔
165
            idx = start_indices[i]
4✔
166
            stride = strides[i]
4✔
167
            term = f"({idx} * {stride})" if stride != "1" else idx
4✔
168
            if offset == "0":
4✔
169
                offset = term
4✔
170
            else:
171
                offset = f"({offset} + {term})"
4✔
172

173
        return [offset]
4✔
174

175
    def is_gemm(self, node):
4✔
176
        """Check if a node represents a GEMM operation (matrix multiplication)."""
177
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
178
            return True
4✔
179
        if isinstance(node, ast.Call):
4✔
180
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
4✔
NEW
181
                return True
×
182
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
4✔
NEW
183
                return True
×
184
            if isinstance(node.func, ast.Attribute) and node.func.attr == "matmul":
4✔
NEW
185
                return True
×
186
            if isinstance(node.func, ast.Name) and node.func.id == "matmul":
4✔
NEW
187
                return True
×
188
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
189
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
190
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
4✔
191
            return self.is_gemm(node.left) or self.is_gemm(node.right)
4✔
192
        return False
4✔
193

194
    def _is_stride_1(self, name, indices):
4✔
195
        """Check if the sliced dimension has stride 1 (contiguous access)."""
196
        if name not in self.array_info:
4✔
NEW
197
            return True
×
198
        info = self.array_info[name]
4✔
199
        ndim = info["ndim"]
4✔
200

201
        if not indices:
4✔
202
            return True
4✔
203

NEW
204
        sliced_dim = -1
×
NEW
205
        for i, idx in enumerate(indices):
×
NEW
206
            if isinstance(idx, ast.Slice):
×
NEW
207
                sliced_dim = i
×
NEW
208
                break
×
209

NEW
210
        if sliced_dim == -1:
×
NEW
211
            if len(indices) < ndim:
×
NEW
212
                sliced_dim = ndim - 1
×
213
            else:
NEW
214
                return True
×
215

NEW
216
        return sliced_dim == ndim - 1
×
217

218
    def _is_target(self, node, target_name):
4✔
219
        """Check if node refers to the target."""
220
        if isinstance(target_name, ast.AST):
4✔
221
            return self._ev.visit(node) == self._ev.visit(target_name)
4✔
222

223
        if isinstance(node, ast.Name) and node.id == target_name:
4✔
NEW
224
            return True
×
225
        if isinstance(node, ast.Subscript):
4✔
226
            if isinstance(node.value, ast.Name) and node.value.id == target_name:
4✔
227
                return True
4✔
228
        return False
4✔
229

230
    def _is_dot_call(self, node):
4✔
231
        """Check if node is a dot product call."""
232
        if isinstance(node, ast.Call):
4✔
NEW
233
            if isinstance(node.func, ast.Attribute) and node.func.attr == "dot":
×
NEW
234
                return True
×
NEW
235
            if isinstance(node.func, ast.Name) and node.func.id == "dot":
×
NEW
236
                return True
×
237
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
238
            return True
4✔
239
        return False
4✔
240

241
    def handle_gemm(self, target, value_node):
4✔
242
        """Handle GEMM (General Matrix Multiply) operations: C = alpha * A @ B + beta * C."""
243
        target_name = None
4✔
244
        target_subset = []
4✔
245

246
        if isinstance(target, str):
4✔
247
            target_name = target
4✔
248
        elif isinstance(target, ast.Name):
4✔
249
            target_name = target.id
4✔
250
        elif isinstance(target, ast.Subscript):
4✔
251
            if isinstance(target.value, ast.Name):
4✔
252
                res = self.parse_arg(target)
4✔
253
                if res[0]:
4✔
254
                    target_name = res[0]
4✔
255
                    target_subset = self.flatten_subset(target_name, res[1])
4✔
256
                else:
NEW
257
                    target_name = target.value.id
×
258

259
        if not target_name or target_name not in self.array_info:
4✔
260
            return False
4✔
261

262
        alpha = "1.0"
4✔
263
        beta = "0.0"
4✔
264
        A = None
4✔
265
        B = None
4✔
266

267
        def extract_factor(node):
4✔
268
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
4✔
NEW
269
                if self.is_gemm(node.left):
×
NEW
270
                    return node.left, self._ev.visit(node.right)
×
NEW
271
                if self.is_gemm(node.right):
×
NEW
272
                    return node.right, self._ev.visit(node.left)
×
273

NEW
274
                res = self.parse_arg(node.left)
×
NEW
275
                if res[0]:
×
NEW
276
                    return node.left, self._ev.visit(node.right)
×
NEW
277
                res = self.parse_arg(node.right)
×
NEW
278
                if res[0]:
×
NEW
279
                    return node.right, self._ev.visit(node.left)
×
280
            return node, "1.0"
4✔
281

282
        def parse_term(node):
4✔
283
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
4✔
284
                l, l_f = extract_factor(node.left)
4✔
285
                r, r_f = extract_factor(node.right)
4✔
286
                f = "1.0"
4✔
287
                if l_f != "1.0":
4✔
NEW
288
                    f = l_f
×
289
                if r_f != "1.0":
4✔
NEW
290
                    if f == "1.0":
×
NEW
291
                        f = r_f
×
292
                    else:
NEW
293
                        f = f"({f} * {r_f})"
×
294
                return l, r, f
4✔
295

NEW
296
            if isinstance(node, ast.Call):
×
NEW
297
                is_gemm_call = False
×
NEW
298
                if isinstance(node.func, ast.Attribute) and node.func.attr in [
×
299
                    "dot",
300
                    "matmul",
301
                ]:
NEW
302
                    is_gemm_call = True
×
NEW
303
                if isinstance(node.func, ast.Name) and node.func.id in [
×
304
                    "dot",
305
                    "matmul",
306
                ]:
NEW
307
                    is_gemm_call = True
×
308

NEW
309
                if is_gemm_call and len(node.args) == 2:
×
NEW
310
                    return node.args[0], node.args[1], "1.0"
×
311

NEW
312
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
×
NEW
313
                l, r, a = parse_term(node.left)
×
NEW
314
                if l:
×
NEW
315
                    return l, r, self._ev.visit(node.right)
×
NEW
316
                l, r, a = parse_term(node.right)
×
NEW
317
                if l:
×
NEW
318
                    return l, r, self._ev.visit(node.left)
×
319

NEW
320
            return None, None, None
×
321

322
        if isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
4✔
NEW
323
            l, r, a = parse_term(value_node.left)
×
NEW
324
            if l:
×
NEW
325
                A = l
×
NEW
326
                B = r
×
NEW
327
                alpha = a
×
NEW
328
                if isinstance(value_node.right, ast.BinOp) and isinstance(
×
329
                    value_node.right.op, ast.Mult
330
                ):
NEW
331
                    if self._is_target(value_node.right.left, target_name):
×
NEW
332
                        beta = self._ev.visit(value_node.right.right)
×
NEW
333
                    elif self._is_target(value_node.right.right, target_name):
×
NEW
334
                        beta = self._ev.visit(value_node.right.left)
×
NEW
335
                elif self._is_target(value_node.right, target_name):
×
NEW
336
                    beta = "1.0"
×
337
            else:
NEW
338
                l, r, a = parse_term(value_node.right)
×
NEW
339
                if l:
×
NEW
340
                    A = l
×
NEW
341
                    B = r
×
NEW
342
                    alpha = a
×
NEW
343
                    if isinstance(value_node.left, ast.BinOp) and isinstance(
×
344
                        value_node.left.op, ast.Mult
345
                    ):
NEW
346
                        if self._is_target(value_node.left.left, target_name):
×
NEW
347
                            beta = self._ev.visit(value_node.left.right)
×
NEW
348
                        elif self._is_target(value_node.left.right, target_name):
×
NEW
349
                            beta = self._ev.visit(value_node.left.left)
×
NEW
350
                    elif self._is_target(value_node.left, target_name):
×
NEW
351
                        beta = "1.0"
×
352
        else:
353
            l, r, a = parse_term(value_node)
4✔
354
            if l:
4✔
355
                A = l
4✔
356
                B = r
4✔
357
                alpha = a
4✔
358

359
        if A is None or B is None:
4✔
NEW
360
            return False
×
361

362
        def get_name_and_trans(node):
4✔
363
            if isinstance(node, ast.Attribute) and node.attr == "T":
4✔
NEW
364
                return node.value, True
×
365
            return node, False
4✔
366

367
        A_node, trans_a = get_name_and_trans(A)
4✔
368
        B_node, trans_b = get_name_and_trans(B)
4✔
369

370
        if self.is_gemm(A_node):
4✔
NEW
371
            tmp_name = self._ev.visit(A_node)
×
NEW
372
            A_node = ast.Name(id=tmp_name)
×
373

374
        if self.is_gemm(B_node):
4✔
NEW
375
            tmp_name = self._ev.visit(B_node)
×
NEW
376
            B_node = ast.Name(id=tmp_name)
×
377

378
        res_a = self.parse_arg(A_node)
4✔
379
        res_b = self.parse_arg(B_node)
4✔
380

381
        if not res_a[0] or not res_b[0]:
4✔
382
            return False
4✔
383

384
        A_name, subset_a, shape_a, indices_a = res_a
4✔
385
        B_name, subset_b, shape_b, indices_b = res_b
4✔
386

387
        flat_subset_a = self.flatten_subset(A_name, subset_a)
4✔
388
        flat_subset_b = self.flatten_subset(B_name, subset_b)
4✔
389

390
        def get_ndim(name):
4✔
391
            if name not in self.array_info:
4✔
NEW
392
                return 1
×
393
            return self.array_info[name]["ndim"]
4✔
394

395
        if len(shape_a) == 2:
4✔
396
            if not trans_a:
4✔
397
                m = shape_a[0]
4✔
398
                k = shape_a[1]
4✔
399
            else:
NEW
400
                m = shape_a[1]
×
NEW
401
                k = shape_a[0]
×
402
        else:
NEW
403
            m = "1"
×
NEW
404
            k = shape_a[0]
×
NEW
405
            if self._is_stride_1(A_name, indices_a):
×
NEW
406
                if get_ndim(A_name) == 1:
×
NEW
407
                    trans_a = True
×
408
                else:
NEW
409
                    trans_a = False
×
410
            else:
NEW
411
                trans_a = True
×
412

413
        if len(shape_b) == 2:
4✔
414
            if not trans_b:
4✔
415
                n = shape_b[1]
4✔
416
            else:
NEW
417
                n = shape_b[0]
×
418
        else:
419
            n = "1"
4✔
420
            if self._is_stride_1(B_name, indices_b):
4✔
421
                if get_ndim(B_name) == 1:
4✔
422
                    trans_b = False
4✔
423
                else:
NEW
424
                    trans_b = True
×
425
            else:
NEW
426
                trans_b = False
×
427

428
        def get_ld(name):
4✔
429
            if name not in self.array_info:
4✔
NEW
430
                return ""
×
431
            shapes = self.array_info[name]["shapes"]
4✔
432
            if len(shapes) >= 2:
4✔
433
                return str(shapes[1])
4✔
434
            return "1"
4✔
435

436
        lda = get_ld(A_name)
4✔
437
        ldb = get_ld(B_name)
4✔
438

439
        ldc = ""
4✔
440
        if target_name:
4✔
441
            if get_ndim(target_name) == 1 and m == "1":
4✔
NEW
442
                ldc = n
×
443
            else:
444
                ldc = get_ld(target_name)
4✔
445

446
        self.builder.add_gemm(
4✔
447
            A_name,
448
            B_name,
449
            target_name,
450
            alpha,
451
            beta,
452
            m,
453
            n,
454
            k,
455
            trans_a,
456
            trans_b,
457
            flat_subset_a,
458
            flat_subset_b,
459
            target_subset,
460
            lda,
461
            ldb,
462
            ldc,
463
        )
464
        return True
4✔
465

466
    def handle_dot(self, target, value_node):
4✔
467
        """Handle dot product operations for 1D vectors."""
468
        dot_node = None
4✔
469
        is_accumulate = False
4✔
470

471
        if self._is_dot_call(value_node):
4✔
472
            dot_node = value_node
4✔
473
        elif isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
4✔
474
            if self._is_dot_call(value_node.left):
4✔
475
                dot_node = value_node.left
4✔
476
                if self._is_target(value_node.right, target):
4✔
NEW
477
                    is_accumulate = True
×
NEW
478
            elif self._is_dot_call(value_node.right):
×
NEW
479
                dot_node = value_node.right
×
NEW
480
                if self._is_target(value_node.left, target):
×
NEW
481
                    is_accumulate = True
×
482

483
        if not dot_node:
4✔
NEW
484
            return False
×
485

486
        arg0 = None
4✔
487
        arg1 = None
4✔
488

489
        if isinstance(dot_node, ast.Call):
4✔
NEW
490
            args = dot_node.args
×
NEW
491
            if len(args) != 2:
×
NEW
492
                return False
×
NEW
493
            arg0 = args[0]
×
NEW
494
            arg1 = args[1]
×
495
        elif isinstance(dot_node, ast.BinOp) and isinstance(dot_node.op, ast.MatMult):
4✔
496
            arg0 = dot_node.left
4✔
497
            arg1 = dot_node.right
4✔
498

499
        res_a = self.parse_arg(arg0)
4✔
500
        res_b = self.parse_arg(arg1)
4✔
501

502
        if not res_a[0] or not res_b[0]:
4✔
503
            return False
4✔
504

505
        name_a, subset_a, shape_a, indices_a = res_a
4✔
506
        name_b, subset_b, shape_b, indices_b = res_b
4✔
507

508
        if len(shape_a) != 1 or len(shape_b) != 1:
4✔
509
            return False
4✔
510

511
        n = shape_a[0]
4✔
512

513
        def get_stride(name, indices):
4✔
514
            if not indices:
4✔
515
                return "1"
4✔
516
            info = self.array_info[name]
4✔
517
            shapes = info["shapes"]
4✔
518
            ndim = info["ndim"]
4✔
519

520
            sliced_dim = -1
4✔
521
            for i, idx in enumerate(indices):
4✔
522
                if isinstance(idx, ast.Slice):
4✔
523
                    sliced_dim = i
4✔
524
                    break
4✔
525

526
            if sliced_dim == -1:
4✔
NEW
527
                return "1"
×
528

529
            stride = "1"
4✔
530
            for i in range(sliced_dim + 1, ndim):
4✔
NEW
531
                dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
×
NEW
532
                if stride == "1":
×
NEW
533
                    stride = str(dim_size)
×
534
                else:
NEW
535
                    stride = f"({stride} * {dim_size})"
×
536
            return stride
4✔
537

538
        incx = get_stride(name_a, indices_a)
4✔
539
        incy = get_stride(name_b, indices_b)
4✔
540

541
        flat_subset_a = self.flatten_subset(name_a, subset_a)
4✔
542
        flat_subset_b = self.flatten_subset(name_b, subset_b)
4✔
543

544
        tmp_res = f"_dot_res_{self._get_unique_id()}"
4✔
545
        self.builder.add_container(tmp_res, Scalar(PrimitiveType.Double), False)
4✔
546
        block = self.builder.add_block()
4✔
547
        constant = self.builder.add_constant(block, "0.0", Scalar(PrimitiveType.Double))
4✔
548
        tasklet = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
4✔
549
        self.builder.add_memlet(
4✔
550
            block, constant, "", tasklet, "_in", "", Scalar(PrimitiveType.Double)
551
        )
552
        access = self.builder.add_access(block, tmp_res)
4✔
553
        self.builder.add_memlet(
4✔
554
            block, tasklet, "_out", access, "", "", Scalar(PrimitiveType.Double)
555
        )
556

557
        self.symbol_table[tmp_res] = Scalar(PrimitiveType.Double)
4✔
558

559
        self.builder.add_dot(
4✔
560
            name_a, name_b, tmp_res, n, incx, incy, flat_subset_a, flat_subset_b
561
        )
562

563
        target_str = target if isinstance(target, str) else self._ev.visit(target)
4✔
564

565
        if not self.builder.exists(target_str):
4✔
NEW
566
            self.builder.add_container(target_str, Scalar(PrimitiveType.Double), False)
×
NEW
567
            self.symbol_table[target_str] = Scalar(PrimitiveType.Double)
×
568

569
        if is_accumulate:
4✔
NEW
570
            self.builder.add_assignment(target_str, f"{target_str} + {tmp_res}")
×
571
        else:
572
            self.builder.add_assignment(target_str, tmp_res)
4✔
573

574
        return True
4✔
575

576
    def is_outer(self, node):
4✔
577
        """Check if a node represents an outer product operation."""
578
        if isinstance(node, ast.Call):
4✔
579
            if isinstance(node.func, ast.Attribute) and node.func.attr == "outer":
4✔
580
                return True
4✔
581
            if isinstance(node.func, ast.Name) and node.func.id == "outer":
4✔
NEW
582
                return True
×
583
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
584
            return self.is_outer(node.left) or self.is_outer(node.right)
4✔
585
        return False
4✔
586

587
    def handle_outer(self, target, value_node):
4✔
588
        """Handle outer product operations."""
589
        target_name = None
4✔
590
        target_subset = []
4✔
591

592
        if isinstance(target, str):
4✔
593
            target_name = target
4✔
594
        elif isinstance(target, ast.Name):
4✔
595
            target_name = target.id
4✔
596
        elif isinstance(target, ast.Subscript):
4✔
597
            res = self.parse_arg(target)
4✔
598
            if res[0]:
4✔
599
                target_name = res[0]
4✔
600
                target_subset = self.flatten_subset(target_name, res[1])
4✔
601
            else:
NEW
602
                if isinstance(target.value, ast.Name):
×
NEW
603
                    target_name = target.value.id
×
604

605
        if not target_name:
4✔
NEW
606
            return False
×
607

608
        outer_calls = []
4✔
609
        target_found = False
4✔
610
        terms = []
4✔
611

612
        def collect_terms(node):
4✔
613
            if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
4✔
614
                collect_terms(node.left)
4✔
615
                collect_terms(node.right)
4✔
616
            else:
617
                terms.append(node)
4✔
618

619
        collect_terms(value_node)
4✔
620

621
        for term in terms:
4✔
622
            if self._is_target(term, target_name):
4✔
623
                target_found = True
4✔
624
            elif isinstance(term, ast.Call) and (
4✔
625
                (isinstance(term.func, ast.Attribute) and term.func.attr == "outer")
626
                or (isinstance(term.func, ast.Name) and term.func.id == "outer")
627
            ):
628
                if len(term.args) != 2:
4✔
NEW
629
                    return False
×
630
                outer_calls.append(term)
4✔
631
            else:
NEW
632
                return False
×
633

634
        if not outer_calls:
4✔
NEW
635
            return False
×
636

637
        parsed_outers = []
4✔
638
        for outer_node in outer_calls:
4✔
639
            arg0 = outer_node.args[0]
4✔
640
            arg1 = outer_node.args[1]
4✔
641

642
            res_a = self.parse_arg(arg0)
4✔
643
            res_b = self.parse_arg(arg1)
4✔
644

645
            if not res_a[0] or not res_b[0]:
4✔
NEW
646
                return False
×
647

648
            parsed_outers.append((res_a, res_b))
4✔
649

650
        alpha = "1.0"
4✔
651
        beta = "1.0" if target_found else "0.0"
4✔
652

653
        def get_flattened_size(name, indices, shapes):
4✔
654
            size_expr = "1"
4✔
655
            for s in shapes:
4✔
656
                if size_expr == "1":
4✔
657
                    size_expr = str(s)
4✔
658
                else:
NEW
659
                    size_expr = f"({size_expr} * {str(s)})"
×
660
            return size_expr
4✔
661

662
        def get_ld_2d(name):
4✔
663
            if name in self.array_info:
4✔
664
                shapes = self.array_info[name]["shapes"]
4✔
665
                if len(shapes) >= 2:
4✔
666
                    return str(shapes[1])
4✔
667
            return "1"
4✔
668

669
        ldc = get_ld_2d(target_name)
4✔
670

671
        for res_a, res_b in parsed_outers:
4✔
672
            name_a, subset_a, shape_a, indices_a = res_a
4✔
673
            name_b, subset_b, shape_b, indices_b = res_b
4✔
674

675
            m = get_flattened_size(name_a, indices_a, shape_a)
4✔
676
            n = get_flattened_size(name_b, indices_b, shape_b)
4✔
677
            k = "1"
4✔
678

679
            trans_a = False
4✔
680
            trans_b = True
4✔
681

682
            flat_subset_a = self.flatten_subset(name_a, subset_a)
4✔
683
            flat_subset_b = self.flatten_subset(name_b, subset_b)
4✔
684

685
            lda = "1"
4✔
686
            ldb = "1"
4✔
687

688
            self.builder.add_gemm(
4✔
689
                name_a,
690
                name_b,
691
                target_name,
692
                alpha,
693
                beta,
694
                m,
695
                n,
696
                k,
697
                trans_a,
698
                trans_b,
699
                flat_subset_a,
700
                flat_subset_b,
701
                target_subset,
702
                lda,
703
                ldb,
704
                ldc,
705
            )
706
            beta = "1.0"
4✔
707

708
        return True
4✔
709

710
    # ========== Transpose Operations ==========
711

712
    def _parse_perm(self, node):
4✔
713
        """Parse a permutation list or tuple from an AST node."""
714
        if isinstance(node, (ast.List, ast.Tuple)):
4✔
715
            res = []
4✔
716
            for elt in node.elts:
4✔
717
                val = self._ev.visit(elt)
4✔
718
                res.append(int(val))
4✔
719
            return res
4✔
NEW
720
        return []
×
721

722
    def is_transpose(self, node):
        """Tell whether ``node`` is a ``.T`` attribute access or a ``transpose`` call."""
        # arr.T
        if isinstance(node, ast.Attribute):
            return node.attr == "T"

        # np.transpose(arr, ...), arr.transpose(...) or a bare transpose(...)
        if isinstance(node, ast.Call):
            callee = node.func
            if isinstance(callee, ast.Attribute):
                return callee.attr == "transpose"
            if isinstance(callee, ast.Name):
                return callee.id == "transpose"

        return False
736

737
    def handle_transpose(self, target, value_node):
4✔
738
        """Handle transpose operations including .T and np.transpose()."""
739
        if not self.is_transpose(value_node):
4✔
NEW
740
            return False
×
741

742
        input_node = None
4✔
743
        perm = []
4✔
744

745
        if isinstance(value_node, ast.Attribute) and value_node.attr == "T":
4✔
746
            input_node = value_node.value
4✔
747
            perm = []  # Empty means reverse
4✔
748

749
        elif isinstance(value_node, ast.Call):
4✔
750
            args = value_node.args
4✔
751
            keywords = value_node.keywords
4✔
752

753
            is_numpy_func = False
4✔
754
            if isinstance(value_node.func, ast.Attribute):
4✔
755
                caller = ""
4✔
756
                if isinstance(value_node.func.value, ast.Name):
4✔
757
                    caller = value_node.func.value.id
4✔
758
                if caller in ["np", "numpy"]:
4✔
759
                    is_numpy_func = True
4✔
NEW
760
            elif isinstance(value_node.func, ast.Name):
×
NEW
761
                is_numpy_func = True
×
762

763
            if is_numpy_func:
4✔
764
                if len(args) < 1:
4✔
NEW
765
                    return False
×
766
                input_node = args[0]
4✔
767
                if len(args) > 1:
4✔
NEW
768
                    perm = self._parse_perm(args[1])
×
769
                for kw in keywords:
4✔
770
                    if kw.arg == "axes":
4✔
771
                        perm = self._parse_perm(kw.value)
4✔
772
            else:
NEW
773
                if isinstance(value_node.func, ast.Attribute):
×
NEW
774
                    input_node = value_node.func.value
×
775
                else:
NEW
776
                    return False
×
NEW
777
                if len(args) > 0:
×
NEW
778
                    perm = self._parse_perm(args[0])
×
NEW
779
                for kw in keywords:
×
NEW
780
                    if kw.arg == "axes":
×
NEW
781
                        perm = self._parse_perm(kw.value)
×
782

783
        input_name = self._ev.visit(input_node)
4✔
784
        if input_name not in self.array_info:
4✔
NEW
785
            return False
×
786

787
        in_info = self.array_info[input_name]
4✔
788
        in_shape = in_info["shapes"]
4✔
789
        in_strings = [str(s) for s in in_shape]
4✔
790

791
        if not perm:
4✔
792
            perm = list(range(len(in_shape)))[::-1]
4✔
793

794
        out_shape = [in_strings[p] for p in perm]
4✔
795

796
        target_name = ""
4✔
797
        if isinstance(target, ast.Name):
4✔
798
            target_name = target.id
4✔
NEW
799
        elif isinstance(target, str):
×
NEW
800
            target_name = target
×
801

802
        dtype = Scalar(PrimitiveType.Double)
4✔
803
        if input_name in self.symbol_table:
4✔
804
            input_type = self.symbol_table[input_name]
4✔
805
            if isinstance(input_type, Pointer):
4✔
806
                dtype = input_type.pointee_type
4✔
807
            else:
NEW
808
                dtype = input_type
×
809

810
        ptr_type = Pointer(dtype)
4✔
811

812
        if not self.builder.exists(target_name):
4✔
813
            self.builder.add_container(target_name, ptr_type, False)
4✔
814
            self.symbol_table[target_name] = ptr_type
4✔
815
            self.array_info[target_name] = {"ndim": len(out_shape), "shapes": out_shape}
4✔
816

817
            block_alloc = self.builder.add_block()
4✔
818
            size_expr = "1"
4✔
819
            for dim in out_shape:
4✔
820
                size_expr = f"({size_expr} * {dim})"
4✔
821
            element_size = self.builder.get_sizeof(dtype)
4✔
822
            total_size = f"({size_expr} * {element_size})"
4✔
823

824
            t_malloc = self.builder.add_malloc(block_alloc, total_size)
4✔
825
            t_ptr = self.builder.add_access(block_alloc, target_name)
4✔
826
            self.builder.add_memlet(
4✔
827
                block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type
828
            )
829

830
        debug_info = get_debug_info(
4✔
831
            value_node, getattr(self.builder, "filename", ""), ""
832
        )
833

834
        self.builder.add_transpose(
4✔
835
            input_name, target_name, in_strings, perm, debug_info
836
        )
837
        return True
4✔
838

839
    def handle_transpose_expr(self, node):
        """Materialise ``<expr>.T`` used inside a larger expression.

        Returns the name of a fresh temporary array holding the transposed
        operand, or None when ``node`` is not a ``.T`` attribute access or its
        operand is not a known array.
        """
        is_dot_t = isinstance(node, ast.Attribute) and node.attr == "T"
        if not is_dot_t:
            return None

        src = self._ev.visit(node.value)
        if src not in self.array_info:
            return None

        src_dims = [str(extent) for extent in self.array_info[src]["shapes"]]
        # .T reverses the order of every axis.
        axes = list(reversed(range(len(src_dims))))
        dst_dims = [src_dims[axis] for axis in axes]

        # Element type: pointee of the recorded pointer type, the recorded
        # type itself otherwise, double when the operand has no recorded type.
        if src in self.symbol_table:
            declared = self.symbol_table[src]
            if isinstance(declared, Pointer):
                elem_type = declared.pointee_type
            else:
                elem_type = declared
        else:
            elem_type = Scalar(PrimitiveType.Double)

        result = self._create_array_temp(dst_dims, elem_type)

        filename = getattr(self.builder, "filename", "")
        location = get_debug_info(node, filename, "")
        self.builder.add_transpose(src, result, src_dims, axes, location)

        return result
×
868

869
    def _handle_numpy_transpose(self, node, func_name):
4✔
870
        """Handle np.transpose(arr, axes=...) function call."""
NEW
871
        if len(node.args) < 1:
×
NEW
872
            raise ValueError("np.transpose requires at least one argument")
×
873

NEW
874
        input_node = node.args[0]
×
NEW
875
        input_name = self.visit(input_node)
×
876

NEW
877
        if input_name not in self.array_info:
×
NEW
878
            raise ValueError(f"Array {input_name} not found in array_info")
×
879

NEW
880
        in_info = self.array_info[input_name]
×
NEW
881
        in_shape = in_info["shapes"]
×
NEW
882
        in_strings = [str(s) for s in in_shape]
×
883

NEW
884
        perm = []
×
NEW
885
        if len(node.args) > 1:
×
NEW
886
            perm = self._parse_perm(node.args[1])
×
NEW
887
        for kw in node.keywords:
×
NEW
888
            if kw.arg == "axes":
×
NEW
889
                perm = self._parse_perm(kw.value)
×
890

NEW
891
        if not perm:
×
NEW
892
            perm = list(range(len(in_shape)))[::-1]
×
893

NEW
894
        out_shape = [in_strings[p] for p in perm]
×
895

NEW
896
        dtype = Scalar(PrimitiveType.Double)
×
NEW
897
        if input_name in self.symbol_table:
×
NEW
898
            input_type = self.symbol_table[input_name]
×
NEW
899
            if isinstance(input_type, Pointer):
×
NEW
900
                dtype = input_type.pointee_type
×
901
            else:
NEW
902
                dtype = input_type
×
903

NEW
904
        tmp_name = self._create_array_temp(out_shape, dtype)
×
905

NEW
906
        debug_info = get_debug_info(node, getattr(self.builder, "filename", ""), "")
×
NEW
907
        self.builder.add_transpose(input_name, tmp_name, in_strings, perm, debug_info)
×
908

NEW
909
        return tmp_name
×
910

911
    def handle_numpy_call(self, node, func_name):
4✔
912
        if func_name in self.function_handlers:
4✔
913
            return self.function_handlers[func_name](node, func_name)
4✔
NEW
914
        raise NotImplementedError(f"NumPy function {func_name} not supported")
×
915

916
    def has_handler(self, func_name):
4✔
917
        return func_name in self.function_handlers
4✔
918

919
    def handle_array_unary_op(self, op_type, operand):
        """Apply the unary operation ``op_type`` to ``operand`` into a fresh temporary.

        An operand with a non-empty recorded shape is lowered through the
        builder's elementwise unary op. Otherwise a single cmath node is
        emitted; that path only knows sqrt, abs/absolute, exp and tanh, and
        any other ``op_type`` fails with KeyError at the lookup.

        Returns the name of the temporary holding the result.
        """
        if operand in self.array_info:
            dims = self.array_info[operand]["shapes"]
        else:
            dims = []

        elem_type = self._ev._element_type(operand)
        result = self._create_array_temp(dims, elem_type)

        if dims:
            self.builder.add_elementwise_unary_op(op_type, operand, result, dims)
            return result

        # No shape: route the operation through one cmath call.
        cmath_by_op = {
            "sqrt": CMathFunction.sqrt,
            "abs": CMathFunction.fabs,
            "absolute": CMathFunction.fabs,
            "exp": CMathFunction.exp,
            "tanh": CMathFunction.tanh,
        }

        blk = self.builder.add_block()
        reader = self.builder.add_access(blk, operand)
        writer = self.builder.add_access(blk, result)
        call = self.builder.add_cmath(blk, cmath_by_op[op_type])

        self.builder.add_memlet(blk, reader, "void", call, "_in1", "", elem_type)
        self.builder.add_memlet(blk, call, "_out", writer, "void", "", elem_type)

        return result
4✔
951

952
    def handle_array_binary_op(self, op_type, left, right):
        """Apply the elementwise binary operation ``op_type`` to ``left`` and ``right``.

        The result shape is the broadcast of both operand shapes and the
        result element type is the promotion of both element types. An operand
        missing from ``array_info`` is treated as a scalar and, when its
        primitive type differs from the result type, is first copied into a
        temporary of the result type. Array operands are broadcast when
        ``_needs_broadcast`` says so.

        Returns the name of the temporary holding the result.
        """
        left_dims = self.array_info[left]["shapes"] if left in self.array_info else []
        right_dims = (
            self.array_info[right]["shapes"] if right in self.array_info else []
        )
        out_dims = self._compute_broadcast_shape(left_dims, right_dims)

        left_type = self._ev._element_type(left)
        right_type = self._ev._element_type(right)
        out_type = promote_element_types(left_type, right_type)

        left_is_scalar = left not in self.array_info
        right_is_scalar = right not in self.array_info

        def cast_scalar(name):
            """Copy scalar ``name`` into a new temporary of the result type."""
            widened = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(widened, out_type, False)
            self.symbol_table[widened] = out_type

            blk = self.builder.add_block()
            reader, reader_subset = self._add_read(blk, name)
            writer = self.builder.add_access(blk, widened)
            copy = self.builder.add_tasklet(
                blk, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(blk, reader, "void", copy, "_in", reader_subset)
            self.builder.add_memlet(blk, copy, "_out", writer, "void", "")
            return widened

        # Scalar casts come first (left, then right), broadcasts afterwards.
        lhs = left
        if left_is_scalar and left_type.primitive_type != out_type.primitive_type:
            lhs = cast_scalar(left)

        rhs = right
        if right_is_scalar and right_type.primitive_type != out_type.primitive_type:
            rhs = cast_scalar(right)

        if not left_is_scalar and self._needs_broadcast(left_dims, out_dims):
            lhs = self._broadcast_array(lhs, left_dims, out_dims, out_type)

        if not right_is_scalar and self._needs_broadcast(right_dims, out_dims):
            rhs = self._broadcast_array(rhs, right_dims, out_dims, out_type)

        result = self._create_array_temp(out_dims, out_type)
        self.builder.add_elementwise_op(op_type, lhs, rhs, result, out_dims)

        return result
4✔
1017

1018
    def handle_array_negate(self, operand):
        """Negate array ``operand`` by computing ``0 - operand`` elementwise.

        A scalar temporary is set to zero ("0.0" for double, "0" for every
        other element type) and used as the left operand of an elementwise
        "sub". Returns the name of the temporary holding the result.
        """
        dims = self.array_info[operand]["shapes"]
        elem_type = self._ev._element_type(operand)

        result = self._create_array_temp(dims, elem_type)

        zero = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(zero, elem_type, False)
        self.symbol_table[zero] = elem_type

        if elem_type.primitive_type == PrimitiveType.Double:
            zero_literal = "0.0"
        else:
            zero_literal = "0"

        # Block that writes the zero literal into the scalar temporary.
        blk = self.builder.add_block()
        literal = self.builder.add_constant(blk, zero_literal, elem_type)
        writer = self.builder.add_access(blk, zero)
        copy = self.builder.add_tasklet(blk, TaskletCode.assign, ["_in"], ["_out"])
        self.builder.add_memlet(blk, literal, "void", copy, "_in", "")
        self.builder.add_memlet(blk, copy, "_out", writer, "void", "")

        self.builder.add_elementwise_op("sub", zero, operand, result, dims)

        return result
4✔
1044

1045
    def handle_array_compare(self, left, op, right, left_is_array, right_is_array):
        """Emit an elementwise comparison and return the boolean result array's name.

        The array operand (the left one when ``left_is_array`` is set) provides
        the shape and selects the opcode family: integer compares for Int32 and
        Int64 element types, floating-point compares otherwise. In the
        floating-point case a scalar operand for which ``_is_int`` holds is
        first written into a Double temporary as the constant ``<scalar>.0``.

        Raises:
            NotImplementedError: if ``op`` is not one of > >= < <= == !=.
        """
        array_operand = left if left_is_array else right
        dims = self.array_info[array_operand]["shapes"]

        operand_type = self._ev._element_type(array_operand)
        integer_compare = operand_type.primitive_type in (
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )

        result = self._create_array_temp(dims, Scalar(PrimitiveType.Bool))

        if integer_compare:
            opcode_by_symbol = {
                ">": TaskletCode.int_sgt,
                ">=": TaskletCode.int_sge,
                "<": TaskletCode.int_slt,
                "<=": TaskletCode.int_sle,
                "==": TaskletCode.int_eq,
                "!=": TaskletCode.int_ne,
            }
        else:
            opcode_by_symbol = {
                ">": TaskletCode.fp_ogt,
                ">=": TaskletCode.fp_oge,
                "<": TaskletCode.fp_olt,
                "<=": TaskletCode.fp_ole,
                "==": TaskletCode.fp_oeq,
                "!=": TaskletCode.fp_one,
            }

        if op not in opcode_by_symbol:
            raise NotImplementedError(
                f"Comparison operator {op} not supported for arrays"
            )
        opcode = opcode_by_symbol[op]

        if not left_is_array:
            scalar_operand = left
        elif not right_is_array:
            scalar_operand = right
        else:
            scalar_operand = None

        needs_widening = (
            scalar_operand is not None
            and not integer_compare
            and self._is_int(scalar_operand)
        )
        if needs_widening:
            widened = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(widened, Scalar(PrimitiveType.Double), False)
            self.symbol_table[widened] = Scalar(PrimitiveType.Double)

            conv = self.builder.add_block()
            literal = self.builder.add_constant(
                conv, f"{scalar_operand}.0", Scalar(PrimitiveType.Double)
            )
            writer = self.builder.add_access(conv, widened)
            copy = self.builder.add_tasklet(
                conv, TaskletCode.assign, ["_in"], ["_out"]
            )
            self.builder.add_memlet(conv, literal, "void", copy, "_in", "")
            self.builder.add_memlet(conv, copy, "_out", writer, "void", "")

            # From here on the widened copy stands in for the scalar operand.
            if left_is_array:
                right = widened
            else:
                left = widened

        # One nested for-loop per dimension of the array operand.
        indices = []
        for axis, extent in enumerate(dims):
            index = f"_cmp_i{axis}_{self._get_unique_id()}"
            if not self.builder.exists(index):
                self.builder.add_container(index, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[index] = Scalar(PrimitiveType.Int64)
            indices.append(index)
            self.builder.begin_for(index, "0", str(extent), "1")

        flat = self._compute_linear_index(indices, dims, result, len(dims))

        blk = self.builder.add_block()

        if left_is_array:
            lhs, lhs_subset = self.builder.add_access(blk, left), flat
        else:
            lhs, lhs_subset = self._add_read(blk, left)

        if right_is_array:
            rhs, rhs_subset = self.builder.add_access(blk, right), flat
        else:
            rhs, rhs_subset = self._add_read(blk, right)

        out = self.builder.add_access(blk, result)
        compare = self.builder.add_tasklet(blk, opcode, ["_in1", "_in2"], ["_out"])

        self.builder.add_memlet(blk, lhs, "void", compare, "_in1", lhs_subset)
        self.builder.add_memlet(blk, rhs, "void", compare, "_in2", rhs_subset)
        self.builder.add_memlet(blk, compare, "_out", out, "void", flat)

        for _ in indices:
            self.builder.end_for()

        return result
4✔
1161

1162
    # ========== NumPy Function Handlers ==========
1163

1164
    def _handle_numpy_alloc(self, node, func_name):
4✔
1165
        """Handle np.empty, np.zeros, np.ones, np.ndarray."""
1166
        shape_arg = node.args[0]
4✔
1167
        dims = []
4✔
1168
        dims_runtime = []
4✔
1169
        if isinstance(shape_arg, ast.Tuple):
4✔
1170
            dims = [self.visit(elt) for elt in shape_arg.elts]
4✔
1171
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
4✔
1172
        elif isinstance(shape_arg, ast.List):
4✔
NEW
1173
            dims = [self.visit(elt) for elt in shape_arg.elts]
×
NEW
1174
            dims_runtime = [self._shape_to_runtime_expr(elt) for elt in shape_arg.elts]
×
1175
        else:
1176
            val = self.visit(shape_arg)
4✔
1177
            runtime_val = self._shape_to_runtime_expr(shape_arg)
4✔
1178
            if val.startswith("_shape_proxy_"):
4✔
NEW
1179
                array_name = val[len("_shape_proxy_") :]
×
NEW
1180
                if array_name in self.array_info:
×
NEW
1181
                    dims = self.array_info[array_name]["shapes"]
×
NEW
1182
                    dims_runtime = self.array_info[array_name].get(
×
1183
                        "shapes_runtime", dims
1184
                    )
1185
                else:
NEW
1186
                    dims = [val]
×
NEW
1187
                    dims_runtime = [runtime_val]
×
1188
            else:
1189
                dims = [val]
4✔
1190
                dims_runtime = [runtime_val]
4✔
1191

1192
        dtype_arg = None
4✔
1193
        if len(node.args) > 1:
4✔
NEW
1194
            dtype_arg = node.args[1]
×
1195

1196
        for kw in node.keywords:
4✔
1197
            if kw.arg == "dtype":
4✔
1198
                dtype_arg = kw.value
4✔
1199
                break
4✔
1200

1201
        element_type = element_type_from_ast_node(dtype_arg, self.symbol_table)
4✔
1202

1203
        return self._create_array_temp(
4✔
1204
            dims,
1205
            element_type,
1206
            zero_init=(func_name == "zeros"),
1207
            ones_init=(func_name == "ones"),
1208
            shapes_runtime=dims_runtime,
1209
        )
1210

1211
    def _handle_numpy_empty_like(self, node, func_name):
4✔
1212
        """Handle np.empty_like."""
1213
        prototype_arg = node.args[0]
4✔
1214
        prototype_name = self.visit(prototype_arg)
4✔
1215

1216
        dims = []
4✔
1217
        if prototype_name in self.array_info:
4✔
1218
            dims = self.array_info[prototype_name]["shapes"]
4✔
1219

1220
        dtype_arg = None
4✔
1221
        if len(node.args) > 1:
4✔
NEW
1222
            dtype_arg = node.args[1]
×
1223

1224
        for kw in node.keywords:
4✔
1225
            if kw.arg == "dtype":
4✔
1226
                dtype_arg = kw.value
4✔
1227
                break
4✔
1228

1229
        element_type = None
4✔
1230
        if dtype_arg:
4✔
1231
            element_type = element_type_from_ast_node(dtype_arg, self.symbol_table)
4✔
1232
        else:
1233
            if prototype_name in self.symbol_table:
4✔
1234
                sym_type = self.symbol_table[prototype_name]
4✔
1235
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
4✔
1236
                    element_type = sym_type.pointee_type
4✔
1237

1238
        if element_type is None:
4✔
NEW
1239
            element_type = Scalar(PrimitiveType.Double)
×
1240

1241
        return self._create_array_temp(
4✔
1242
            dims, element_type, zero_init=False, ones_init=False
1243
        )
1244

1245
    def _handle_numpy_zeros_like(self, node, func_name):
4✔
1246
        """Handle np.zeros_like."""
1247
        prototype_arg = node.args[0]
4✔
1248
        prototype_name = self.visit(prototype_arg)
4✔
1249

1250
        dims = []
4✔
1251
        if prototype_name in self.array_info:
4✔
1252
            dims = self.array_info[prototype_name]["shapes"]
4✔
1253

1254
        dtype_arg = None
4✔
1255
        if len(node.args) > 1:
4✔
NEW
1256
            dtype_arg = node.args[1]
×
1257

1258
        for kw in node.keywords:
4✔
1259
            if kw.arg == "dtype":
4✔
1260
                dtype_arg = kw.value
4✔
1261
                break
4✔
1262

1263
        element_type = None
4✔
1264
        if dtype_arg:
4✔
1265
            element_type = element_type_from_ast_node(dtype_arg, self.symbol_table)
4✔
1266
        else:
1267
            if prototype_name in self.symbol_table:
4✔
1268
                sym_type = self.symbol_table[prototype_name]
4✔
1269
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
4✔
1270
                    element_type = sym_type.pointee_type
4✔
1271

1272
        if element_type is None:
4✔
NEW
1273
            element_type = Scalar(PrimitiveType.Double)
×
1274

1275
        return self._create_array_temp(
4✔
1276
            dims, element_type, zero_init=True, ones_init=False
1277
        )
1278

1279
    def _handle_numpy_eye(self, node, func_name):
4✔
1280
        """Handle np.eye."""
1281
        N_arg = node.args[0]
4✔
1282
        N_str = self.visit(N_arg)
4✔
1283

1284
        M_str = N_str
4✔
1285
        if len(node.args) > 1:
4✔
NEW
1286
            M_str = self.visit(node.args[1])
×
1287

1288
        k_str = "0"
4✔
1289
        if len(node.args) > 2:
4✔
NEW
1290
            k_str = self.visit(node.args[2])
×
1291

1292
        dtype_arg = None
4✔
1293
        for kw in node.keywords:
4✔
1294
            if kw.arg == "M":
4✔
1295
                M_str = self.visit(kw.value)
4✔
1296
                if M_str == "None":
4✔
1297
                    M_str = N_str
4✔
1298
            elif kw.arg == "k":
4✔
1299
                k_str = self.visit(kw.value)
4✔
1300
            elif kw.arg == "dtype":
4✔
1301
                dtype_arg = kw.value
4✔
1302

1303
        element_type = element_type_from_ast_node(dtype_arg, self.symbol_table)
4✔
1304

1305
        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)
4✔
1306

1307
        loop_var = f"_i_{self._get_unique_id()}"
4✔
1308
        if not self.builder.exists(loop_var):
4✔
1309
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1310
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1311

1312
        self.builder.begin_for(loop_var, "0", N_str, "1")
4✔
1313

1314
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
4✔
1315
        self.builder.begin_if(cond)
4✔
1316

1317
        val = "1.0"
4✔
1318
        if element_type.primitive_type in [
4✔
1319
            PrimitiveType.Int64,
1320
            PrimitiveType.Int32,
1321
            PrimitiveType.Int8,
1322
            PrimitiveType.Int16,
1323
            PrimitiveType.UInt64,
1324
            PrimitiveType.UInt32,
1325
            PrimitiveType.UInt8,
1326
            PrimitiveType.UInt16,
1327
        ]:
NEW
1328
            val = "1"
×
1329

1330
        block_assign = self.builder.add_block()
4✔
1331
        t_const = self.builder.add_constant(block_assign, val, element_type)
4✔
1332
        t_arr = self.builder.add_access(block_assign, ptr_name)
4✔
1333
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"
4✔
1334
        subset = flat_index
4✔
1335

1336
        t_task = self.builder.add_tasklet(
4✔
1337
            block_assign, TaskletCode.assign, ["_in"], ["_out"]
1338
        )
1339
        self.builder.add_memlet(
4✔
1340
            block_assign, t_const, "void", t_task, "_in", "", element_type
1341
        )
1342
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", subset)
4✔
1343

1344
        self.builder.end_if()
4✔
1345
        self.builder.end_for()
4✔
1346

1347
        return ptr_name
4✔
1348

1349
    def _handle_numpy_binary_op(self, node, func_name):
4✔
1350
        """Handle np.add, np.subtract, np.multiply, np.divide, etc."""
1351
        args = [self.visit(arg) for arg in node.args]
4✔
1352
        if len(args) != 2:
4✔
NEW
1353
            raise NotImplementedError(
×
1354
                f"Numpy function {func_name} requires 2 arguments"
1355
            )
1356

1357
        op_map = {
4✔
1358
            "add": "add",
1359
            "subtract": "sub",
1360
            "multiply": "mul",
1361
            "divide": "div",
1362
            "power": "pow",
1363
            "minimum": "min",
1364
            "maximum": "max",
1365
        }
1366
        return self.handle_array_binary_op(op_map[func_name], args[0], args[1])
4✔
1367

1368
    def _handle_numpy_unary_op(self, node, func_name):
4✔
1369
        """Handle np.exp, np.sqrt, np.abs, etc."""
1370
        args = [self.visit(arg) for arg in node.args]
4✔
1371
        if len(args) != 1:
4✔
NEW
1372
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
1373

1374
        op_name = func_name
4✔
1375
        if op_name == "absolute":
4✔
NEW
1376
            op_name = "abs"
×
1377

1378
        return self.handle_array_unary_op(op_name, args[0])
4✔
1379

1380
    def _handle_numpy_where(self, node, func_name):
4✔
1381
        """Handle np.where(condition, x, y) - elementwise ternary selection."""
1382
        if len(node.args) != 3:
4✔
NEW
1383
            raise NotImplementedError("np.where requires 3 arguments (condition, x, y)")
×
1384

1385
        cond_name = self.visit(node.args[0])
4✔
1386
        x_name = self.visit(node.args[1])
4✔
1387
        y_name = self.visit(node.args[2])
4✔
1388

1389
        shape = []
4✔
1390
        dtype = Scalar(PrimitiveType.Double)
4✔
1391

1392
        if cond_name in self.array_info:
4✔
1393
            shape = self.array_info[cond_name]["shapes"]
4✔
1394

1395
        if not shape and y_name in self.array_info:
4✔
NEW
1396
            shape = self.array_info[y_name]["shapes"]
×
1397

1398
        if not shape and x_name in self.array_info:
4✔
NEW
1399
            shape = self.array_info[x_name]["shapes"]
×
1400

1401
        if not shape:
4✔
NEW
1402
            raise NotImplementedError("np.where requires at least one array argument")
×
1403

1404
        if y_name in self.symbol_table:
4✔
1405
            y_type = self.symbol_table[y_name]
4✔
1406
            if isinstance(y_type, Pointer) and y_type.has_pointee_type():
4✔
1407
                dtype = y_type.pointee_type
4✔
NEW
1408
            elif isinstance(y_type, Scalar):
×
NEW
1409
                dtype = y_type
×
1410

1411
        tmp_name = self._create_array_temp(shape, dtype)
4✔
1412

1413
        loop_vars = []
4✔
1414
        for i, dim in enumerate(shape):
4✔
1415
            loop_var = f"_where_i{i}_{self._get_unique_id()}"
4✔
1416
            if not self.builder.exists(loop_var):
4✔
1417
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
4✔
1418
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
4✔
1419
            loop_vars.append(loop_var)
4✔
1420
            self.builder.begin_for(loop_var, "0", str(dim), "1")
4✔
1421

1422
        linear_idx = self._compute_linear_index(loop_vars, shape, tmp_name, len(shape))
4✔
1423

1424
        cond_tmp = f"_where_cond_{self._get_unique_id()}"
4✔
1425
        self.builder.add_container(cond_tmp, Scalar(PrimitiveType.Bool), False)
4✔
1426
        self.symbol_table[cond_tmp] = Scalar(PrimitiveType.Bool)
4✔
1427

1428
        block_cond = self.builder.add_block()
4✔
1429
        if cond_name in self.array_info:
4✔
1430
            t_cond_arr = self.builder.add_access(block_cond, cond_name)
4✔
1431
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
4✔
1432
            t_cond_task = self.builder.add_tasklet(
4✔
1433
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
1434
            )
1435
            self.builder.add_memlet(
4✔
1436
                block_cond, t_cond_arr, "void", t_cond_task, "_in", linear_idx
1437
            )
1438
            self.builder.add_memlet(
4✔
1439
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
1440
            )
1441
        else:
NEW
1442
            t_cond_src, cond_sub = self._add_read(block_cond, cond_name)
×
NEW
1443
            t_cond_out = self.builder.add_access(block_cond, cond_tmp)
×
NEW
1444
            t_cond_task = self.builder.add_tasklet(
×
1445
                block_cond, TaskletCode.assign, ["_in"], ["_out"]
1446
            )
NEW
1447
            self.builder.add_memlet(
×
1448
                block_cond, t_cond_src, "void", t_cond_task, "_in", cond_sub
1449
            )
NEW
1450
            self.builder.add_memlet(
×
1451
                block_cond, t_cond_task, "_out", t_cond_out, "void", ""
1452
            )
1453

1454
        self.builder.begin_if(f"{cond_tmp} == true")
4✔
1455

1456
        block_true = self.builder.add_block()
4✔
1457
        t_out_true = self.builder.add_access(block_true, tmp_name)
4✔
1458
        if x_name in self.array_info:
4✔
1459
            t_x = self.builder.add_access(block_true, x_name)
4✔
1460
            t_task_true = self.builder.add_tasklet(
4✔
1461
                block_true, TaskletCode.assign, ["_in"], ["_out"]
1462
            )
1463
            self.builder.add_memlet(
4✔
1464
                block_true, t_x, "void", t_task_true, "_in", linear_idx
1465
            )
1466
        else:
1467
            t_x, x_sub = self._add_read(block_true, x_name)
4✔
1468
            t_task_true = self.builder.add_tasklet(
4✔
1469
                block_true, TaskletCode.assign, ["_in"], ["_out"]
1470
            )
1471
            self.builder.add_memlet(block_true, t_x, "void", t_task_true, "_in", x_sub)
4✔
1472
        self.builder.add_memlet(
4✔
1473
            block_true, t_task_true, "_out", t_out_true, "void", linear_idx
1474
        )
1475

1476
        self.builder.begin_else()
4✔
1477

1478
        block_false = self.builder.add_block()
4✔
1479
        t_out_false = self.builder.add_access(block_false, tmp_name)
4✔
1480
        if y_name in self.array_info:
4✔
1481
            t_y = self.builder.add_access(block_false, y_name)
4✔
1482
            t_task_false = self.builder.add_tasklet(
4✔
1483
                block_false, TaskletCode.assign, ["_in"], ["_out"]
1484
            )
1485
            self.builder.add_memlet(
4✔
1486
                block_false, t_y, "void", t_task_false, "_in", linear_idx
1487
            )
1488
        else:
1489
            t_y, y_sub = self._add_read(block_false, y_name)
4✔
1490
            t_task_false = self.builder.add_tasklet(
4✔
1491
                block_false, TaskletCode.assign, ["_in"], ["_out"]
1492
            )
1493
            self.builder.add_memlet(
4✔
1494
                block_false, t_y, "void", t_task_false, "_in", y_sub
1495
            )
1496
        self.builder.add_memlet(
4✔
1497
            block_false, t_task_false, "_out", t_out_false, "void", linear_idx
1498
        )
1499

1500
        self.builder.end_if()
4✔
1501

1502
        for _ in loop_vars:
4✔
1503
            self.builder.end_for()
4✔
1504

1505
        return tmp_name
4✔
1506

1507
    def _handle_numpy_clip(self, node, func_name):
4✔
1508
        """Handle np.clip(a, a_min, a_max) - elementwise clipping."""
1509
        if len(node.args) != 3:
4✔
NEW
1510
            raise NotImplementedError("np.clip requires 3 arguments (a, a_min, a_max)")
×
1511

1512
        arr_name = self.visit(node.args[0])
4✔
1513
        a_min = self.visit(node.args[1])
4✔
1514
        a_max = self.visit(node.args[2])
4✔
1515

1516
        tmp1 = self.handle_array_binary_op("max", arr_name, a_min)
4✔
1517
        result = self.handle_array_binary_op("min", tmp1, a_max)
4✔
1518

1519
        return result
4✔
1520

1521
    def _handle_numpy_matmul(self, node, func_name):
4✔
1522
        """Handle np.matmul, np.dot."""
1523
        if len(node.args) != 2:
4✔
NEW
1524
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
1525
        return self._handle_matmul_helper(node.args[0], node.args[1])
4✔
1526

1527
    def handle_numpy_matmul_op(self, left_node, right_node):
        """Handle the @ operator for matrix multiplication.

        Args:
            left_node: AST node of the left operand.
            right_node: AST node of the right operand.

        Returns:
            Whatever ``_handle_matmul_helper`` returns for the operand pair.
        """
        product_name = self._handle_matmul_helper(left_node, right_node)
        return product_name
4✔
1530

1531
    def _handle_matmul_helper(self, left_node, right_node):
        """Helper for matrix multiplication operations.

        Resolves both operands, derives the output shape from their ranks,
        allocates a temporary for the result and emits either a dot product
        (1-D x 1-D), a single GEMM (ranks 1/2), or a loop nest of GEMMs over
        the leading dimensions (equal ranks above 2).

        Args:
            left_node: AST node of the left operand.
            right_node: AST node of the right operand.

        Returns:
            Name of the temporary container that receives the product.

        Raises:
            NotImplementedError: If an operand cannot be resolved, if the
                operands have different ranks while one exceeds 2, or if the
                rank combination is otherwise unsupported.
        """
        parsed_left = self.parse_arg(left_node)
        parsed_right = self.parse_arg(right_node)

        # An operand that parse_arg cannot resolve is evaluated first and then
        # re-parsed through the name it evaluated to.
        if not parsed_left[0]:
            left_node = ast.Name(id=self.visit(left_node))
            parsed_left = self.parse_arg(left_node)

        if not parsed_right[0]:
            right_node = ast.Name(id=self.visit(right_node))
            parsed_right = self.parse_arg(right_node)

        name_a, _, shape_a, _ = parsed_left
        name_b, _, shape_b, _ = parsed_right

        if not (name_a and name_b):
            raise NotImplementedError("Could not resolve matmul operands")

        rank_a = len(shape_a)
        rank_b = len(shape_b)
        ranks = (rank_a, rank_b)
        batched = rank_a > 2 or rank_b > 2

        is_scalar = False
        output_shape = []

        if batched:
            if rank_a != rank_b:
                raise NotImplementedError(
                    "Broadcasting with different ranks not fully supported yet"
                )
            # Leading dimensions are carried over from the left operand.
            output_shape = list(shape_a[:-2]) + [shape_a[-2], shape_b[-1]]
        elif ranks == (1, 1):
            is_scalar = True
        elif ranks == (2, 2):
            output_shape = [shape_a[0], shape_b[1]]
        elif ranks == (2, 1):
            output_shape = [shape_a[0]]
        elif ranks == (1, 2):
            output_shape = [shape_b[1]]
        else:
            raise NotImplementedError(
                f"Matmul with ranks {rank_a} and {rank_b} not supported"
            )

        dtype = promote_element_types(
            self._ev._element_type(name_a), self._ev._element_type(name_b)
        )

        if is_scalar:
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
        else:
            tmp_name = self._create_array_temp(output_shape, dtype)

        if not batched:
            product = ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node)
            emit = self.handle_dot if is_scalar else self.handle_gemm
            emit(tmp_name, product)
            return tmp_name

        # One loop per leading dimension; the GEMM is emitted on 2-D slices.
        batch_rank = rank_a - 2
        batch_vars = []
        for axis in range(batch_rank):
            counter = f"_i{self._get_unique_id()}"
            self.builder.add_container(counter, Scalar(PrimitiveType.Int64), False)
            batch_vars.append(counter)
            self.builder.begin_for(counter, "0", str(shape_a[axis]), "1")

        def batch_slice(array_name):
            """Build ``array_name[<batch vars>, :, :]`` as an AST subscript."""
            selectors = [ast.Name(id=var) for var in batch_vars]
            selectors.extend([ast.Slice(), ast.Slice()])
            return ast.Subscript(
                value=ast.Name(id=array_name),
                slice=ast.Tuple(elts=selectors),
                ctx=ast.Load(),
            )

        lhs_slice = batch_slice(name_a)
        rhs_slice = batch_slice(name_b)
        out_slice = batch_slice(tmp_name)

        self.handle_gemm(
            out_slice, ast.BinOp(left=lhs_slice, op=ast.MatMult(), right=rhs_slice)
        )

        for _ in range(batch_rank):
            self.builder.end_for()

        return tmp_name
4✔
1643

1644
    def _handle_numpy_outer(self, node, func_name):
4✔
1645
        """Handle np.outer."""
1646
        if len(node.args) != 2:
4✔
NEW
1647
            raise NotImplementedError("outer requires 2 arguments")
×
1648

1649
        arg0 = node.args[0]
4✔
1650
        arg1 = node.args[1]
4✔
1651

1652
        res_a = self.parse_arg(arg0)
4✔
1653
        res_b = self.parse_arg(arg1)
4✔
1654

1655
        if not res_a[0]:
4✔
NEW
1656
            left_name = self.visit(arg0)
×
NEW
1657
            arg0 = ast.Name(id=left_name)
×
NEW
1658
            res_a = self.parse_arg(arg0)
×
1659

1660
        if not res_b[0]:
4✔
NEW
1661
            right_name = self.visit(arg1)
×
NEW
1662
            arg1 = ast.Name(id=right_name)
×
NEW
1663
            res_b = self.parse_arg(arg1)
×
1664

1665
        name_a, subset_a, shape_a, indices_a = res_a
4✔
1666
        name_b, subset_b, shape_b, indices_b = res_b
4✔
1667

1668
        if not name_a or not name_b:
4✔
NEW
1669
            raise NotImplementedError("Could not resolve outer operands")
×
1670

1671
        def get_flattened_size_expr(name, indices, shapes):
4✔
1672
            size_expr = "1"
4✔
1673
            for s in shapes:
4✔
1674
                if size_expr == "1":
4✔
1675
                    size_expr = str(s)
4✔
1676
                else:
NEW
1677
                    size_expr = f"({size_expr} * {str(s)})"
×
1678
            return size_expr
4✔
1679

1680
        m_expr = get_flattened_size_expr(name_a, indices_a, shape_a)
4✔
1681
        n_expr = get_flattened_size_expr(name_b, indices_b, shape_b)
4✔
1682

1683
        dtype_a = self._ev._element_type(name_a)
4✔
1684
        dtype_b = self._ev._element_type(name_b)
4✔
1685
        dtype = promote_element_types(dtype_a, dtype_b)
4✔
1686

1687
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)
4✔
1688

1689
        new_call_node = ast.Call(
4✔
1690
            func=node.func, args=[arg0, arg1], keywords=node.keywords
1691
        )
1692

1693
        self.handle_outer(tmp_name, new_call_node)
4✔
1694

1695
        return tmp_name
4✔
1696

1697
    def handle_ufunc_outer(self, node, ufunc_name):
4✔
1698
        """Handle np.add.outer, np.subtract.outer, np.multiply.outer, etc."""
1699
        if len(node.args) != 2:
4✔
NEW
1700
            raise NotImplementedError(f"{ufunc_name}.outer requires 2 arguments")
×
1701

1702
        if ufunc_name == "multiply":
4✔
1703
            return self._handle_numpy_outer(node, "outer")
4✔
1704

1705
        op_map = {
4✔
1706
            "add": ("add", TaskletCode.fp_add, TaskletCode.int_add),
1707
            "subtract": ("sub", TaskletCode.fp_sub, TaskletCode.int_sub),
1708
            "divide": ("div", TaskletCode.fp_div, TaskletCode.int_sdiv),
1709
            "minimum": ("min", CMathFunction.fmin, TaskletCode.int_smin),
1710
            "maximum": ("max", CMathFunction.fmax, TaskletCode.int_smax),
1711
        }
1712

1713
        if ufunc_name not in op_map:
4✔
NEW
1714
            raise NotImplementedError(f"{ufunc_name}.outer not supported")
×
1715

1716
        op_name, fp_opcode, int_opcode = op_map[ufunc_name]
4✔
1717

1718
        arg0 = node.args[0]
4✔
1719
        arg1 = node.args[1]
4✔
1720

1721
        res_a = self.parse_arg(arg0)
4✔
1722
        res_b = self.parse_arg(arg1)
4✔
1723

1724
        if not res_a[0]:
4✔
NEW
1725
            left_name = self.visit(arg0)
×
NEW
1726
            arg0 = ast.Name(id=left_name)
×
NEW
1727
            res_a = self.parse_arg(arg0)
×
1728

1729
        if not res_b[0]:
4✔
NEW
1730
            right_name = self.visit(arg1)
×
NEW
1731
            arg1 = ast.Name(id=right_name)
×
NEW
1732
            res_b = self.parse_arg(arg1)
×
1733

1734
        name_a, subset_a, shape_a, indices_a = res_a
4✔
1735
        name_b, subset_b, shape_b, indices_b = res_b
4✔
1736

1737
        if not name_a or not name_b:
4✔
NEW
1738
            raise NotImplementedError("Could not resolve ufunc outer operands")
×
1739

1740
        def get_flattened_size_expr(shapes):
4✔
1741
            if not shapes:
4✔
NEW
1742
                return "1"
×
1743
            size_expr = str(shapes[0])
4✔
1744
            for s in shapes[1:]:
4✔
NEW
1745
                size_expr = f"({size_expr} * {str(s)})"
×
1746
            return size_expr
4✔
1747

1748
        m_expr = get_flattened_size_expr(shape_a)
4✔
1749
        n_expr = get_flattened_size_expr(shape_b)
4✔
1750

1751
        dtype_left = self._ev._element_type(name_a)
4✔
1752
        dtype_right = self._ev._element_type(name_b)
4✔
1753
        dtype = promote_element_types(dtype_left, dtype_right)
4✔
1754

1755
        is_int = dtype.primitive_type in [
4✔
1756
            PrimitiveType.Int64,
1757
            PrimitiveType.Int32,
1758
            PrimitiveType.Int8,
1759
            PrimitiveType.Int16,
1760
            PrimitiveType.UInt64,
1761
            PrimitiveType.UInt32,
1762
            PrimitiveType.UInt8,
1763
            PrimitiveType.UInt16,
1764
        ]
1765

1766
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)
4✔
1767

1768
        i_var = self.builder.find_new_name("_outer_i_")
4✔
1769
        j_var = self.builder.find_new_name("_outer_j_")
4✔
1770

1771
        if not self.builder.exists(i_var):
4✔
1772
            self.builder.add_container(i_var, Scalar(PrimitiveType.Int64), False)
4✔
1773
            self.symbol_table[i_var] = Scalar(PrimitiveType.Int64)
4✔
1774
        if not self.builder.exists(j_var):
4✔
1775
            self.builder.add_container(j_var, Scalar(PrimitiveType.Int64), False)
4✔
1776
            self.symbol_table[j_var] = Scalar(PrimitiveType.Int64)
4✔
1777

1778
        def compute_linear_index(name, subset, indices, loop_var):
4✔
1779
            if not indices:
4✔
1780
                return loop_var
4✔
1781

1782
            info = self.array_info.get(name, {})
4✔
1783
            shapes = info.get("shapes", [])
4✔
1784
            ndim = info.get("ndim", len(shapes))
4✔
1785

1786
            if ndim == 0:
4✔
NEW
1787
                return loop_var
×
1788

1789
            strides = []
4✔
1790
            current_stride = "1"
4✔
1791
            for i in range(ndim - 1, -1, -1):
4✔
1792
                strides.insert(0, current_stride)
4✔
1793
                if i > 0:
4✔
1794
                    dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
4✔
1795
                    if current_stride == "1":
4✔
1796
                        current_stride = str(dim_size)
4✔
1797
                    else:
NEW
1798
                        current_stride = f"({current_stride} * {dim_size})"
×
1799

1800
            terms = []
4✔
1801
            loop_var_used = False
4✔
1802

1803
            for i, idx in enumerate(indices):
4✔
1804
                stride = strides[i] if i < len(strides) else "1"
4✔
1805
                start = subset[i] if i < len(subset) else "0"
4✔
1806

1807
                if isinstance(idx, ast.Slice):
4✔
1808
                    if stride == "1":
4✔
1809
                        term = f"({start} + {loop_var})"
4✔
1810
                    else:
1811
                        term = f"(({start} + {loop_var}) * {stride})"
4✔
1812
                    loop_var_used = True
4✔
1813
                else:
1814
                    if stride == "1":
4✔
1815
                        term = start
4✔
1816
                    else:
1817
                        term = f"({start} * {stride})"
4✔
1818

1819
                terms.append(term)
4✔
1820

1821
            if not terms:
4✔
NEW
1822
                return loop_var
×
1823

1824
            result = terms[0]
4✔
1825
            for t in terms[1:]:
4✔
1826
                result = f"({result} + {t})"
4✔
1827

1828
            return result
4✔
1829

1830
        self.builder.begin_for(i_var, "0", m_expr, "1")
4✔
1831
        self.builder.begin_for(j_var, "0", n_expr, "1")
4✔
1832

1833
        block = self.builder.add_block()
4✔
1834

1835
        t_a = self.builder.add_access(block, name_a)
4✔
1836
        t_b = self.builder.add_access(block, name_b)
4✔
1837
        t_c = self.builder.add_access(block, tmp_name)
4✔
1838

1839
        if ufunc_name in ["minimum", "maximum"]:
4✔
1840
            if is_int:
4✔
1841
                t_task = self.builder.add_tasklet(
4✔
1842
                    block, int_opcode, ["_in1", "_in2"], ["_out"]
1843
                )
1844
            else:
1845
                t_task = self.builder.add_cmath(block, fp_opcode)
4✔
1846
        else:
1847
            tasklet_code = int_opcode if is_int else fp_opcode
4✔
1848
            t_task = self.builder.add_tasklet(
4✔
1849
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
1850
            )
1851

1852
        a_index = compute_linear_index(name_a, subset_a, indices_a, i_var)
4✔
1853
        b_index = compute_linear_index(name_b, subset_b, indices_b, j_var)
4✔
1854

1855
        self.builder.add_memlet(block, t_a, "void", t_task, "_in1", a_index)
4✔
1856
        self.builder.add_memlet(block, t_b, "void", t_task, "_in2", b_index)
4✔
1857

1858
        flat_index = f"(({i_var}) * ({n_expr}) + ({j_var}))"
4✔
1859
        self.builder.add_memlet(block, t_task, "_out", t_c, "void", flat_index)
4✔
1860

1861
        self.builder.end_for()
4✔
1862
        self.builder.end_for()
4✔
1863

1864
        return tmp_name
4✔
1865

1866
    def _handle_numpy_reduce(self, node, func_name):
4✔
1867
        """Handle np.sum, np.max, np.min, np.mean, np.std."""
1868
        args = node.args
4✔
1869
        keywords = {kw.arg: kw.value for kw in node.keywords}
4✔
1870

1871
        array_node = args[0]
4✔
1872
        array_name = self.visit(array_node)
4✔
1873

1874
        if array_name not in self.array_info:
4✔
NEW
1875
            raise ValueError(f"Reduction input must be an array, got {array_name}")
×
1876

1877
        input_shape = self.array_info[array_name]["shapes"]
4✔
1878
        ndim = len(input_shape)
4✔
1879

1880
        axis = None
4✔
1881
        if len(args) > 1:
4✔
NEW
1882
            axis = args[1]
×
1883
        elif "axis" in keywords:
4✔
1884
            axis = keywords["axis"]
4✔
1885

1886
        keepdims = False
4✔
1887
        if "keepdims" in keywords:
4✔
1888
            keepdims_node = keywords["keepdims"]
4✔
1889
            if isinstance(keepdims_node, ast.Constant):
4✔
1890
                keepdims = bool(keepdims_node.value)
4✔
1891

1892
        axes = []
4✔
1893
        if axis is None:
4✔
1894
            axes = list(range(ndim))
4✔
1895
        elif isinstance(axis, ast.Constant):
4✔
1896
            val = axis.value
4✔
1897
            if val < 0:
4✔
NEW
1898
                val += ndim
×
1899
            axes = [val]
4✔
NEW
1900
        elif isinstance(axis, ast.Tuple):
×
NEW
1901
            for elt in axis.elts:
×
NEW
1902
                if isinstance(elt, ast.Constant):
×
NEW
1903
                    val = elt.value
×
NEW
1904
                    if val < 0:
×
NEW
1905
                        val += ndim
×
NEW
1906
                    axes.append(val)
×
NEW
1907
        elif (
×
1908
            isinstance(axis, ast.UnaryOp)
1909
            and isinstance(axis.op, ast.USub)
1910
            and isinstance(axis.operand, ast.Constant)
1911
        ):
NEW
1912
            val = -axis.operand.value
×
NEW
1913
            if val < 0:
×
NEW
1914
                val += ndim
×
NEW
1915
            axes = [val]
×
1916
        else:
NEW
1917
            try:
×
NEW
1918
                val = int(self.visit(axis))
×
NEW
1919
                if val < 0:
×
NEW
1920
                    val += ndim
×
NEW
1921
                axes = [val]
×
NEW
1922
            except:
×
NEW
1923
                raise NotImplementedError("Dynamic axis not supported")
×
1924

1925
        output_shape = []
4✔
1926
        for i in range(ndim):
4✔
1927
            if i in axes:
4✔
1928
                if keepdims:
4✔
1929
                    output_shape.append("1")
4✔
1930
            else:
1931
                output_shape.append(input_shape[i])
4✔
1932

1933
        dtype = self._ev._element_type(array_name)
4✔
1934

1935
        if not output_shape:
4✔
1936
            tmp_name = f"_tmp_{self._get_unique_id()}"
4✔
1937
            self.builder.add_container(tmp_name, dtype, False)
4✔
1938
            self.symbol_table[tmp_name] = dtype
4✔
1939
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
4✔
1940
        else:
1941
            tmp_name = self._create_array_temp(output_shape, dtype)
4✔
1942

1943
        self.builder.add_reduce_op(
4✔
1944
            func_name, array_name, tmp_name, input_shape, axes, keepdims
1945
        )
1946

1947
        return tmp_name
4✔
1948

1949
    def handle_numpy_astype(self, node, array_name):
4✔
1950
        """Handle numpy array.astype(dtype) method calls."""
1951
        if len(node.args) < 1:
4✔
NEW
1952
            raise ValueError("astype requires at least one argument (dtype)")
×
1953

1954
        dtype_arg = node.args[0]
4✔
1955
        target_dtype = element_type_from_ast_node(dtype_arg, self.symbol_table)
4✔
1956

1957
        if array_name not in self.array_info:
4✔
NEW
1958
            raise ValueError(f"Array {array_name} not found in array_info")
×
1959

1960
        input_shape = self.array_info[array_name]["shapes"]
4✔
1961

1962
        tmp_name = self._create_array_temp(input_shape, target_dtype)
4✔
1963

1964
        self.builder.add_cast_op(
4✔
1965
            array_name, tmp_name, input_shape, target_dtype.primitive_type
1966
        )
1967

1968
        return tmp_name
4✔
1969

1970
    def handle_numpy_copy(self, node, array_name):
        """Handle numpy array.copy() method calls using memcpy.

        Args:
            node: The ``ast.Call`` of the ``copy`` invocation (not inspected).
            array_name: Name of the array to duplicate.

        Returns:
            Name of a new temporary array with the same shape, populated by a
            memcpy node from the source.

        Raises:
            ValueError: If ``array_name`` is not present in ``array_info``.
        """
        if array_name not in self.array_info:
            raise ValueError(f"Array {array_name} not found in array_info")

        source_shape = self.array_info[array_name]["shapes"]

        # Element type: the pointee type when the symbol is a typed pointer,
        # otherwise double.
        element_type = Scalar(PrimitiveType.Double)
        if array_name in self.symbol_table:
            declared = self.symbol_table[array_name]
            if isinstance(declared, Pointer) and declared.has_pointee_type():
                element_type = declared.pointee_type

        copy_name = self._create_array_temp(source_shape, element_type)

        # Byte count = product of the extents times the element size.
        element_count = " * ".join(f"({extent})" for extent in source_shape)
        bytes_per_element = self.builder.get_sizeof(element_type)
        byte_count = f"({element_count}) * ({bytes_per_element})"

        pointer_type = Pointer(element_type)

        block = self.builder.add_block()
        source_access = self.builder.add_access(block, array_name)
        target_access = self.builder.add_access(block, copy_name)
        memcpy_node = self.builder.add_memcpy(block, byte_count)

        self.builder.add_memlet(
            block, source_access, "void", memcpy_node, "_src", "", pointer_type
        )
        self.builder.add_memlet(
            block, memcpy_node, "_dst", target_access, "void", "", pointer_type
        )

        return copy_name
4✔
2000

2001
    def _create_array_temp(
        self, shape, dtype, zero_init=False, ones_init=False, shapes_runtime=None
    ):
        """Create a temporary array with the given shape and dtype.

        Args:
            shape: Sequence of extents. An empty (or falsy) shape produces a
                scalar container instead of a heap allocation.
            dtype: Element type of the temporary.
            zero_init: Fill the temporary with zeros. Takes precedence over
                ``ones_init``.
            ones_init: Fill the temporary with ones.
            shapes_runtime: Optional value stored under ``"shapes_runtime"``
                in the temporary's ``array_info`` entry.

        Returns:
            Name of the newly registered temporary.
        """
        tmp_name = f"_tmp_{self._get_unique_id()}"

        # Zero-dimensional case: a plain scalar container, no malloc.
        if not shape:
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}

            if zero_init or ones_init:
                digit = "0" if zero_init else "1"
                is_double = dtype.primitive_type == PrimitiveType.Double
                self.builder.add_assignment(
                    tmp_name, f"{digit}.0" if is_double else digit
                )

            return tmp_name

        # Element count and byte size as string expressions.
        count_expr = "1"
        for extent in shape:
            count_expr = f"({count_expr} * {extent})"

        bytes_per_element = self.builder.get_sizeof(dtype)
        byte_size = f"({count_expr} * {bytes_per_element})"

        # Register the pointer container and its metadata.
        pointer_type = Pointer(dtype)
        self.builder.add_container(tmp_name, pointer_type, False)
        self.symbol_table[tmp_name] = pointer_type
        metadata = {"ndim": len(shape), "shapes": shape}
        if shapes_runtime is not None:
            metadata["shapes_runtime"] = shapes_runtime
        self.array_info[tmp_name] = metadata

        # Allocation block.
        alloc_block = self.builder.add_block()
        malloc_node = self.builder.add_malloc(alloc_block, byte_size)
        alloc_access = self.builder.add_access(alloc_block, tmp_name)
        self.builder.add_memlet(
            alloc_block, malloc_node, "_ret", alloc_access, "void", "", pointer_type
        )

        if zero_init:
            clear_block = self.builder.add_block()
            memset_node = self.builder.add_memset(clear_block, "0", byte_size)
            clear_access = self.builder.add_access(clear_block, tmp_name)
            self.builder.add_memlet(
                clear_block, memset_node, "_ptr", clear_access, "void", "", pointer_type
            )
            return tmp_name

        if not ones_init:
            return tmp_name

        # Ones are written element by element in a flat loop.
        counter = f"_i_{self._get_unique_id()}"
        if not self.builder.exists(counter):
            self.builder.add_container(counter, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[counter] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(counter, "0", count_expr, "1")

        integer_kinds = (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        )
        one_literal = "1" if dtype.primitive_type in integer_kinds else "1.0"

        fill_block = self.builder.add_block()
        const_node = self.builder.add_constant(fill_block, one_literal, dtype)
        fill_access = self.builder.add_access(fill_block, tmp_name)

        assign_node = self.builder.add_tasklet(
            fill_block, TaskletCode.assign, ["_in"], ["_out"]
        )
        self.builder.add_memlet(
            fill_block, const_node, "void", assign_node, "_in", "", dtype
        )
        self.builder.add_memlet(
            fill_block, assign_node, "_out", fill_access, "void", counter
        )

        self.builder.end_for()

        return tmp_name
4✔
2094

2095
    def _compute_linear_index(self, indices, shapes, array_name, ndim):
4✔
2096
        """Compute linear index from multi-dimensional indices."""
2097
        if ndim == 0:
4✔
NEW
2098
            return "0"
×
2099

2100
        linear_index = ""
4✔
2101
        for i in range(ndim):
4✔
2102
            term = str(indices[i])
4✔
2103
            for j in range(i + 1, ndim):
4✔
2104
                shape_val = shapes[j] if j < len(shapes) else f"_{array_name}_shape_{j}"
4✔
2105
                term = f"(({term}) * {shape_val})"
4✔
2106

2107
            if i == 0:
4✔
2108
                linear_index = term
4✔
2109
            else:
2110
                linear_index = f"({linear_index} + {term})"
4✔
2111

2112
        return linear_index
4✔
2113

2114
    def _compute_broadcast_shape(self, shape_a, shape_b):
4✔
2115
        """Compute the broadcast output shape following NumPy broadcasting rules."""
2116
        if not shape_a:
4✔
2117
            return shape_b
4✔
2118
        if not shape_b:
4✔
2119
            return shape_a
4✔
2120

2121
        max_ndim = max(len(shape_a), len(shape_b))
4✔
2122
        padded_a = ["1"] * (max_ndim - len(shape_a)) + [str(s) for s in shape_a]
4✔
2123
        padded_b = ["1"] * (max_ndim - len(shape_b)) + [str(s) for s in shape_b]
4✔
2124

2125
        result = []
4✔
2126
        for a, b in zip(padded_a, padded_b):
4✔
2127
            if a == "1":
4✔
NEW
2128
                result.append(b)
×
2129
            elif b == "1":
4✔
2130
                result.append(a)
4✔
2131
            elif a == b:
4✔
2132
                result.append(a)
4✔
2133
            else:
2134
                result.append(a)
4✔
2135

2136
        return result
4✔
2137

2138
    def _needs_broadcast(self, input_shape, output_shape):
4✔
2139
        """Check if input shape needs broadcasting to match output shape."""
2140
        if len(input_shape) != len(output_shape):
4✔
2141
            return True
4✔
2142
        for in_dim, out_dim in zip(input_shape, output_shape):
4✔
2143
            if str(in_dim) != str(out_dim):
4✔
2144
                return True
4✔
2145
        return False
4✔
2146

2147
    def _broadcast_array(self, arr_name, input_shape, output_shape, dtype):
        """Broadcast an array from input_shape to output_shape.

        A temporary of ``output_shape`` is allocated and
        ``builder.add_broadcast`` is asked to fill it from ``arr_name``.

        Args:
            arr_name: Name of the source array.
            input_shape: Extents of the source array.
            output_shape: Extents of the broadcast result.
            dtype: Element type of the result temporary.

        Returns:
            Name of the temporary holding the broadcast result.
        """
        target = self._create_array_temp(output_shape, dtype)

        # Left-pad the source extents with "1" up to the output rank.
        missing = len(output_shape) - len(input_shape)
        source_dims = ["1"] * missing
        source_dims.extend(str(extent) for extent in input_shape)
        target_dims = list(map(str, output_shape))

        self.builder.add_broadcast(arr_name, target, source_dims, target_dims)

        return target
4✔
2163

2164
    def _shape_to_runtime_expr(self, shape_node):
4✔
2165
        """Convert a shape expression AST node to a runtime-evaluable string."""
2166
        if isinstance(shape_node, ast.Constant):
4✔
2167
            return str(shape_node.value)
4✔
2168
        elif isinstance(shape_node, ast.Name):
4✔
2169
            return shape_node.id
4✔
2170
        elif isinstance(shape_node, ast.BinOp):
4✔
2171
            left = self._shape_to_runtime_expr(shape_node.left)
4✔
2172
            right = self._shape_to_runtime_expr(shape_node.right)
4✔
2173
            op = self.visit(shape_node.op)
4✔
2174
            return f"({left} {op} {right})"
4✔
2175
        elif isinstance(shape_node, ast.UnaryOp):
4✔
NEW
2176
            operand = self._shape_to_runtime_expr(shape_node.operand)
×
NEW
2177
            if isinstance(shape_node.op, ast.USub):
×
NEW
2178
                return f"(-{operand})"
×
NEW
2179
            elif isinstance(shape_node.op, ast.UAdd):
×
NEW
2180
                return operand
×
2181
            else:
NEW
2182
                return self.visit(shape_node)
×
2183
        elif isinstance(shape_node, ast.Subscript):
4✔
2184
            val = shape_node.value
4✔
2185
            if isinstance(val, ast.Attribute) and val.attr == "shape":
4✔
2186
                if isinstance(val.value, ast.Name):
4✔
2187
                    arr_name = val.value.id
4✔
2188
                    if isinstance(shape_node.slice, ast.Constant):
4✔
2189
                        idx = shape_node.slice.value
4✔
2190
                        if arr_name in self.array_info:
4✔
2191
                            shapes = self.array_info[arr_name].get("shapes", [])
4✔
2192
                            if idx < len(shapes):
4✔
2193
                                return shapes[idx]
4✔
NEW
2194
                        return f"{arr_name}.shape[{idx}]"
×
NEW
2195
            return self.visit(shape_node)
×
NEW
2196
        elif isinstance(shape_node, ast.Tuple):
×
NEW
2197
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
NEW
2198
        elif isinstance(shape_node, ast.List):
×
NEW
2199
            return [self._shape_to_runtime_expr(elt) for elt in shape_node.elts]
×
2200
        else:
NEW
2201
            return self.visit(shape_node)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc