• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21112672771

18 Jan 2026 01:40PM UTC coverage: 64.355% (+0.2%) from 64.154%
21112672771

Pull #462

github

web-flow
Merge 6c6ce34f9 into 92e9cbdc3
Pull Request #462: adds syntax support for multi-assignments and np.empty_like

45 of 52 new or added lines in 3 files covered. (86.54%)

1 existing line in 1 file now uncovered.

19555 of 30386 relevant lines covered (64.36%)

387.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.14
/python/docc/expression_visitor.py
1
import ast
3✔
2
import inspect
3✔
3
import textwrap
3✔
4
from ._sdfg import Scalar, PrimitiveType, Pointer, Type, DebugInfo, Structure
3✔
5

6

7
class ExpressionVisitor(ast.NodeVisitor):
3✔
8
    def __init__(
3✔
9
        self,
10
        array_info=None,
11
        builder=None,
12
        symbol_table=None,
13
        globals_dict=None,
14
        inliner=None,
15
        unique_counter_ref=None,
16
        structure_member_info=None,
17
    ):
18
        self.array_info = array_info if array_info is not None else {}
3✔
19
        self.builder = builder
3✔
20
        self.symbol_table = symbol_table if symbol_table is not None else {}
3✔
21
        self.globals_dict = globals_dict if globals_dict is not None else {}
3✔
22
        self.inliner = inliner
3✔
23
        self._unique_counter_ref = (
3✔
24
            unique_counter_ref if unique_counter_ref is not None else [0]
25
        )
26
        self._access_cache = {}
3✔
27
        self.la_handler = None
3✔
28
        self.structure_member_info = (
3✔
29
            structure_member_info if structure_member_info is not None else {}
30
        )
31
        self._init_numpy_handlers()
3✔
32

33
    def _get_unique_id(self):
3✔
34
        self._unique_counter_ref[0] += 1
3✔
35
        return self._unique_counter_ref[0]
3✔
36

37
    def _get_temp_name(self, prefix="_tmp_"):
3✔
38
        if hasattr(self.builder, "find_new_name"):
3✔
39
            return self.builder.find_new_name(prefix)
×
40
        return f"{prefix}{self._get_unique_id()}"
3✔
41

42
    def _init_numpy_handlers(self):
3✔
43
        self.numpy_handlers = {
3✔
44
            "empty": self._handle_numpy_alloc,
45
            "empty_like": self._handle_numpy_empty_like,
46
            "zeros": self._handle_numpy_alloc,
47
            "ones": self._handle_numpy_alloc,
48
            "eye": self._handle_numpy_eye,
49
            "add": self._handle_numpy_binary_op,
50
            "subtract": self._handle_numpy_binary_op,
51
            "multiply": self._handle_numpy_binary_op,
52
            "divide": self._handle_numpy_binary_op,
53
            "power": self._handle_numpy_binary_op,
54
            "exp": self._handle_numpy_unary_op,
55
            "abs": self._handle_numpy_unary_op,
56
            "absolute": self._handle_numpy_unary_op,
57
            "sqrt": self._handle_numpy_unary_op,
58
            "tanh": self._handle_numpy_unary_op,
59
            "sum": self._handle_numpy_reduce,
60
            "max": self._handle_numpy_reduce,
61
            "min": self._handle_numpy_reduce,
62
            "mean": self._handle_numpy_reduce,
63
            "std": self._handle_numpy_reduce,
64
            "matmul": self._handle_numpy_matmul,
65
            "dot": self._handle_numpy_matmul,
66
            "matvec": self._handle_numpy_matmul,
67
            "minimum": self._handle_numpy_binary_op,
68
            "maximum": self._handle_numpy_binary_op,
69
        }
70

71
    def generic_visit(self, node):
3✔
72
        return super().generic_visit(node)
×
73

74
    def visit_Constant(self, node):
3✔
75
        if isinstance(node.value, bool):
3✔
76
            return "true" if node.value else "false"
×
77
        return str(node.value)
3✔
78

79
    def visit_Name(self, node):
3✔
80
        return node.id
3✔
81

82
    def _map_numpy_dtype(self, dtype_node):
3✔
83
        # Default to double
84
        if dtype_node is None:
3✔
85
            return Scalar(PrimitiveType.Double)
×
86

87
        if isinstance(dtype_node, ast.Name):
3✔
88
            if dtype_node.id == "float":
3✔
89
                return Scalar(PrimitiveType.Double)
3✔
90
            if dtype_node.id == "int":
3✔
91
                return Scalar(PrimitiveType.Int64)
3✔
92
            if dtype_node.id == "bool":
×
93
                return Scalar(PrimitiveType.Bool)
×
94

95
        if isinstance(dtype_node, ast.Attribute):
3✔
96
            # Handle array.dtype
97
            if (
3✔
98
                isinstance(dtype_node.value, ast.Name)
99
                and dtype_node.value.id in self.symbol_table
100
                and dtype_node.attr == "dtype"
101
            ):
102
                sym_type = self.symbol_table[dtype_node.value.id]
3✔
103
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
3✔
104
                    return sym_type.pointee_type
3✔
105

106
            if isinstance(dtype_node.value, ast.Name) and dtype_node.value.id in [
3✔
107
                "numpy",
108
                "np",
109
            ]:
110
                if dtype_node.attr == "float64":
3✔
111
                    return Scalar(PrimitiveType.Double)
3✔
112
                if dtype_node.attr == "float32":
3✔
113
                    return Scalar(PrimitiveType.Float)
3✔
114
                if dtype_node.attr == "int64":
3✔
115
                    return Scalar(PrimitiveType.Int64)
3✔
116
                if dtype_node.attr == "int32":
3✔
117
                    return Scalar(PrimitiveType.Int32)
3✔
118
                if dtype_node.attr == "bool_":
×
119
                    return Scalar(PrimitiveType.Bool)
×
120

121
        # Fallback
122
        return Scalar(PrimitiveType.Double)
×
123

124
    def _is_int(self, operand):
3✔
125
        try:
3✔
126
            if operand.lstrip("-").isdigit():
3✔
127
                return True
3✔
128
        except ValueError:
×
129
            pass
×
130

131
        name = operand
3✔
132
        if "(" in operand and operand.endswith(")"):
3✔
133
            name = operand.split("(")[0]
×
134

135
        if name in self.symbol_table:
3✔
136
            t = self.symbol_table[name]
3✔
137

138
            def is_int_ptype(pt):
3✔
139
                return pt in [
3✔
140
                    PrimitiveType.Int64,
141
                    PrimitiveType.Int32,
142
                    PrimitiveType.Int8,
143
                    PrimitiveType.Int16,
144
                    PrimitiveType.UInt64,
145
                    PrimitiveType.UInt32,
146
                    PrimitiveType.UInt8,
147
                    PrimitiveType.UInt16,
148
                ]
149

150
            if isinstance(t, Scalar):
3✔
151
                return is_int_ptype(t.primitive_type)
3✔
152

153
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
×
154
                et = t.element_type
×
155
                if callable(et):
×
156
                    et = et()
×
157
                if isinstance(et, Scalar):
×
158
                    return is_int_ptype(et.primitive_type)
×
159

160
            if type(t).__name__ == "Pointer":
×
161
                if hasattr(t, "pointee_type"):
×
162
                    et = t.pointee_type
×
163
                    if callable(et):
×
164
                        et = et()
×
165
                    if isinstance(et, Scalar):
×
166
                        return is_int_ptype(et.primitive_type)
×
167
                # Fallback: check if it has element_type (maybe alias?)
168
                if hasattr(t, "element_type"):
×
169
                    et = t.element_type
×
170
                    if callable(et):
×
171
                        et = et()
×
172
                    if isinstance(et, Scalar):
×
173
                        return is_int_ptype(et.primitive_type)
×
174

175
        return False
3✔
176

177
    def _add_read(self, block, expr_str, debug_info=None):
3✔
178
        # Try to reuse access node
179
        try:
3✔
180
            if (block, expr_str) in self._access_cache:
3✔
181
                return self._access_cache[(block, expr_str)]
3✔
182
        except TypeError:
×
183
            # block might not be hashable
184
            pass
×
185

186
        if debug_info is None:
3✔
187
            debug_info = DebugInfo()
3✔
188

189
        if "(" in expr_str and expr_str.endswith(")"):
3✔
190
            name = expr_str.split("(")[0]
3✔
191
            subset = expr_str[expr_str.find("(") + 1 : -1]
3✔
192
            access = self.builder.add_access(block, name, debug_info)
3✔
193
            try:
3✔
194
                self._access_cache[(block, expr_str)] = (access, subset)
3✔
195
            except TypeError:
×
196
                pass
×
197
            return access, subset
3✔
198

199
        if self.builder.has_container(expr_str):
3✔
200
            access = self.builder.add_access(block, expr_str, debug_info)
3✔
201
            try:
3✔
202
                self._access_cache[(block, expr_str)] = (access, "")
3✔
203
            except TypeError:
×
204
                pass
×
205
            return access, ""
3✔
206

207
        dtype = Scalar(PrimitiveType.Double)
3✔
208
        if self._is_int(expr_str):
3✔
209
            dtype = Scalar(PrimitiveType.Int64)
3✔
210
        elif expr_str == "true" or expr_str == "false":
3✔
211
            dtype = Scalar(PrimitiveType.Bool)
×
212

213
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
3✔
214
        try:
3✔
215
            self._access_cache[(block, expr_str)] = (const_node, "")
3✔
216
        except TypeError:
×
217
            pass
×
218
        return const_node, ""
3✔
219

220
    def _handle_min_max(self, node, func_name):
3✔
221
        args = [self.visit(arg) for arg in node.args]
3✔
222
        if len(args) != 2:
3✔
223
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")
×
224

225
        # Check types
226
        is_float = False
3✔
227
        arg_types = []
3✔
228

229
        for arg in args:
3✔
230
            name = arg
3✔
231
            if "(" in arg and arg.endswith(")"):
3✔
232
                name = arg.split("(")[0]
×
233

234
            if name in self.symbol_table:
3✔
235
                t = self.symbol_table[name]
3✔
236
                if isinstance(t, Pointer):
3✔
237
                    t = t.base_type
×
238

239
                if t.primitive_type == PrimitiveType.Double:
3✔
240
                    is_float = True
3✔
241
                    arg_types.append(PrimitiveType.Double)
3✔
242
                else:
243
                    arg_types.append(PrimitiveType.Int64)
3✔
244
            elif self._is_int(arg):
×
245
                arg_types.append(PrimitiveType.Int64)
×
246
            else:
247
                # Assume float constant
248
                is_float = True
×
249
                arg_types.append(PrimitiveType.Double)
×
250

251
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)
3✔
252

253
        tmp_name = self._get_temp_name("_tmp_")
3✔
254
        self.builder.add_container(tmp_name, dtype, False)
3✔
255
        self.symbol_table[tmp_name] = dtype
3✔
256

257
        if is_float:
3✔
258
            # Cast args if necessary
259
            casted_args = []
3✔
260
            for i, arg in enumerate(args):
3✔
261
                if arg_types[i] != PrimitiveType.Double:
3✔
262
                    # Create temp double
263
                    tmp_cast = self._get_temp_name("_cast_")
3✔
264
                    self.builder.add_container(
3✔
265
                        tmp_cast, Scalar(PrimitiveType.Double), False
266
                    )
267
                    self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
3✔
268

269
                    # Assign int to double (implicit cast)
270
                    self.builder.add_assignment(tmp_cast, arg)
3✔
271
                    casted_args.append(tmp_cast)
3✔
272
                else:
273
                    casted_args.append(arg)
3✔
274

275
            block = self.builder.add_block()
3✔
276
            t_out = self.builder.add_access(block, tmp_name)
3✔
277

278
            intrinsic_name = "fmax" if func_name == "max" else "fmin"
3✔
279
            t_task = self.builder.add_intrinsic(block, intrinsic_name)
3✔
280

281
            for i, arg in enumerate(casted_args):
3✔
282
                t_arg, arg_sub = self._add_read(block, arg)
3✔
283
                self.builder.add_memlet(
3✔
284
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
285
                )
286
        else:
287
            block = self.builder.add_block()
3✔
288
            t_out = self.builder.add_access(block, tmp_name)
3✔
289

290
            # Use int_smax/int_smin tasklet
291
            opcode = "int_smax" if func_name == "max" else "int_smin"
3✔
292
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])
3✔
293

294
            for i, arg in enumerate(args):
3✔
295
                t_arg, arg_sub = self._add_read(block, arg)
3✔
296
                self.builder.add_memlet(
3✔
297
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
298
                )
299

300
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
301
        return tmp_name
3✔
302

303
    def _handle_python_cast(self, node, func_name):
3✔
304
        """Handle Python type casts: int(), float(), bool()"""
305
        if len(node.args) != 1:
3✔
306
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")
×
307

308
        arg = self.visit(node.args[0])
3✔
309

310
        # Determine target type based on cast function
311
        if func_name == "int":
3✔
312
            target_dtype = Scalar(PrimitiveType.Int64)
3✔
313
        elif func_name == "float":
3✔
314
            target_dtype = Scalar(PrimitiveType.Double)
3✔
315
        elif func_name == "bool":
3✔
316
            target_dtype = Scalar(PrimitiveType.Bool)
3✔
317
        else:
318
            raise NotImplementedError(f"Cast to {func_name} not supported")
×
319

320
        # Determine source type
321
        source_dtype = None
3✔
322
        name = arg
3✔
323
        if "(" in arg and arg.endswith(")"):
3✔
324
            name = arg.split("(")[0]
×
325

326
        if name in self.symbol_table:
3✔
327
            source_dtype = self.symbol_table[name]
3✔
328
            if isinstance(source_dtype, Pointer):
3✔
329
                source_dtype = source_dtype.base_type
×
330
        elif self._is_int(arg):
×
331
            source_dtype = Scalar(PrimitiveType.Int64)
×
332
        elif arg == "true" or arg == "false":
×
333
            source_dtype = Scalar(PrimitiveType.Bool)
×
334
        else:
335
            # Assume float constant
336
            source_dtype = Scalar(PrimitiveType.Double)
×
337

338
        # Create temporary variable for result
339
        tmp_name = self._get_temp_name("_tmp_")
3✔
340
        self.builder.add_container(tmp_name, target_dtype, False)
3✔
341
        self.symbol_table[tmp_name] = target_dtype
3✔
342

343
        # Use tasklet assign opcode for casting (as specified in problem statement)
344
        block = self.builder.add_block()
3✔
345
        t_src, src_sub = self._add_read(block, arg)
3✔
346
        t_dst = self.builder.add_access(block, tmp_name)
3✔
347
        t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
3✔
348
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
3✔
349
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
350

351
        return tmp_name
3✔
352

353
    def visit_Call(self, node):
3✔
354
        func_name = ""
3✔
355
        module_name = ""
3✔
356
        if isinstance(node.func, ast.Attribute):
3✔
357
            if isinstance(node.func.value, ast.Name):
3✔
358
                if node.func.value.id == "math":
3✔
359
                    module_name = "math"
3✔
360
                    func_name = node.func.attr
3✔
361
                elif node.func.value.id in ["numpy", "np"]:
3✔
362
                    module_name = "numpy"
3✔
363
                    func_name = node.func.attr
3✔
364
                else:
365
                    # Check if it's a method call on an array (e.g., arr.astype(...))
366
                    array_name = node.func.value.id
3✔
367
                    method_name = node.func.attr
3✔
368
                    if array_name in self.array_info and method_name == "astype":
3✔
369
                        return self._handle_numpy_astype(node, array_name)
3✔
370
            elif isinstance(node.func.value, ast.Attribute):
3✔
371
                if (
3✔
372
                    isinstance(node.func.value.value, ast.Name)
373
                    and node.func.value.value.id == "scipy"
374
                    and node.func.value.attr == "special"
375
                ):
376
                    if node.func.attr == "softmax":
3✔
377
                        return self._handle_scipy_softmax(node, "softmax")
3✔
378

379
        elif isinstance(node.func, ast.Name):
3✔
380
            func_name = node.func.id
3✔
381

382
        if module_name == "numpy":
3✔
383
            if func_name in self.numpy_handlers:
3✔
384
                return self.numpy_handlers[func_name](node, func_name)
3✔
385

386
        if func_name in ["max", "min"]:
3✔
387
            return self._handle_min_max(node, func_name)
3✔
388

389
        # Handle Python type casts (int, float, bool)
390
        if func_name in ["int", "float", "bool"]:
3✔
391
            return self._handle_python_cast(node, func_name)
3✔
392

393
        math_funcs = [
3✔
394
            "sin",
395
            "cos",
396
            "tan",
397
            "exp",
398
            "log",
399
            "sqrt",
400
            "pow",
401
            "abs",
402
            "ceil",
403
            "floor",
404
            "asin",
405
            "acos",
406
            "atan",
407
            "sinh",
408
            "cosh",
409
            "tanh",
410
        ]
411

412
        if func_name in math_funcs:
3✔
413
            args = [self.visit(arg) for arg in node.args]
3✔
414

415
            tmp_name = self._get_temp_name("_tmp_")
3✔
416
            dtype = Scalar(PrimitiveType.Double)
3✔
417
            self.builder.add_container(tmp_name, dtype, False)
3✔
418
            self.symbol_table[tmp_name] = dtype
3✔
419

420
            block = self.builder.add_block()
3✔
421
            t_out = self.builder.add_access(block, tmp_name)
3✔
422

423
            t_task = self.builder.add_intrinsic(block, func_name)
3✔
424

425
            for i, arg in enumerate(args):
3✔
426
                t_arg, arg_sub = self._add_read(block, arg)
3✔
427
                self.builder.add_memlet(
3✔
428
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
429
                )
430

431
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
432
            return tmp_name
3✔
433

434
        if func_name in self.globals_dict:
3✔
435
            obj = self.globals_dict[func_name]
3✔
436
            if inspect.isfunction(obj):
3✔
437
                return self._handle_inline_call(node, obj)
3✔
438

439
        raise NotImplementedError(f"Function call {func_name} not supported")
×
440

441
    def _handle_inline_call(self, node, func_obj):
3✔
442
        # 1. Parse function source
443
        try:
3✔
444
            source_lines, start_line = inspect.getsourcelines(func_obj)
3✔
445
            source = textwrap.dedent("".join(source_lines))
3✔
446
            tree = ast.parse(source)
3✔
447
            func_def = tree.body[0]
3✔
448
        except Exception as e:
×
449
            raise NotImplementedError(
×
450
                f"Could not parse function {func_obj.__name__}: {e}"
451
            )
452

453
        # 2. Evaluate arguments
454
        arg_vars = [self.visit(arg) for arg in node.args]
3✔
455

456
        if len(arg_vars) != len(func_def.args.args):
3✔
457
            raise NotImplementedError(
×
458
                f"Argument count mismatch for {func_obj.__name__}"
459
            )
460

461
        # 3. Generate unique suffix
462
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
3✔
463
        res_name = f"_res{suffix}"
3✔
464

465
        # Assume Int64 for now as match returns 0/1
466
        dtype = Scalar(PrimitiveType.Int64)
3✔
467
        self.builder.add_container(res_name, dtype, False)
3✔
468
        self.symbol_table[res_name] = dtype
3✔
469

470
        # 4. Rename variables
471
        class VariableRenamer(ast.NodeTransformer):
3✔
472
            def __init__(self, suffix, globals_dict):
3✔
473
                self.suffix = suffix
3✔
474
                self.globals_dict = globals_dict
3✔
475

476
            def visit_Name(self, node):
3✔
477
                if node.id in self.globals_dict:
3✔
478
                    return node
3✔
479
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
3✔
480

481
            def visit_Return(self, node):
3✔
482
                if node.value:
3✔
483
                    val = self.visit(node.value)
3✔
484
                    return ast.Assign(
3✔
485
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
486
                        value=val,
487
                    )
488
                return node
×
489

490
        renamer = VariableRenamer(suffix, self.globals_dict)
3✔
491
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
3✔
492

493
        # 5. Assign arguments to parameters
494
        param_assignments = []
3✔
495
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
3✔
496
            param_name = f"{arg_def.arg}{suffix}"
3✔
497

498
            # Infer type and create container
499
            if arg_val in self.symbol_table:
3✔
500
                self.symbol_table[param_name] = self.symbol_table[arg_val]
3✔
501
                self.builder.add_container(
3✔
502
                    param_name, self.symbol_table[arg_val], False
503
                )
504
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
3✔
505
            elif self._is_int(arg_val):
×
506
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
507
                self.builder.add_container(
×
508
                    param_name, Scalar(PrimitiveType.Int64), False
509
                )
510
                val_node = ast.Constant(value=int(arg_val))
×
511
            else:
512
                # Assume float constant
513
                try:
×
514
                    val = float(arg_val)
×
515
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
516
                    self.builder.add_container(
×
517
                        param_name, Scalar(PrimitiveType.Double), False
518
                    )
519
                    val_node = ast.Constant(value=val)
×
520
                except ValueError:
×
521
                    # Fallback to Name, might fail later if not in symbol table
522
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
523

524
            assign = ast.Assign(
3✔
525
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
526
            )
527
            param_assignments.append(assign)
3✔
528

529
        final_body = param_assignments + new_body
3✔
530

531
        # 6. Visit new body using ASTParser
532
        from .ast_parser import ASTParser
3✔
533

534
        parser = ASTParser(
3✔
535
            self.builder,
536
            self.array_info,
537
            self.symbol_table,
538
            globals_dict=self.globals_dict,
539
            unique_counter_ref=self._unique_counter_ref,
540
        )
541

542
        for stmt in final_body:
3✔
543
            parser.visit(stmt)
3✔
544

545
        return res_name
3✔
546

547
    def visit_BinOp(self, node):
3✔
548
        if isinstance(node.op, ast.MatMult):
3✔
549
            return self._handle_numpy_matmul_op(node.left, node.right)
3✔
550

551
        left = self.visit(node.left)
3✔
552
        right = self.visit(node.right)
3✔
553
        op = self.visit(node.op)
3✔
554

555
        # Check if left or right are arrays
556
        left_is_array = left in self.array_info
3✔
557
        right_is_array = right in self.array_info
3✔
558

559
        if left_is_array or right_is_array:
3✔
560
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
3✔
561
            if op in op_map:
3✔
562
                return self._handle_array_binary_op(op_map[op], left, right)
3✔
563
            else:
564
                raise NotImplementedError(f"Array operation {op} not supported")
×
565

566
        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
567

568
        dtype = Scalar(PrimitiveType.Double)  # Default
3✔
569

570
        left_is_int = self._is_int(left)
3✔
571
        right_is_int = self._is_int(right)
3✔
572

573
        if left_is_int and right_is_int and op not in ["/", "**"]:
3✔
574
            dtype = Scalar(PrimitiveType.Int64)
3✔
575

576
        self.builder.add_container(tmp_name, dtype, False)
3✔
577
        self.symbol_table[tmp_name] = dtype
3✔
578

579
        real_left = left
3✔
580
        real_right = right
3✔
581

582
        if dtype.primitive_type == PrimitiveType.Double:
3✔
583
            if left_is_int:
3✔
584
                left_cast = f"_tmp_{self._get_unique_id()}"
3✔
585
                self.builder.add_container(
3✔
586
                    left_cast, Scalar(PrimitiveType.Double), False
587
                )
588
                self.symbol_table[left_cast] = Scalar(PrimitiveType.Double)
3✔
589

590
                c_block = self.builder.add_block()
3✔
591
                t_src, src_sub = self._add_read(c_block, left)
3✔
592
                t_dst = self.builder.add_access(c_block, left_cast)
3✔
593
                t_task = self.builder.add_tasklet(c_block, "assign", ["_in"], ["_out"])
3✔
594
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
3✔
595
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
3✔
596

597
                real_left = left_cast
3✔
598

599
            if right_is_int:
3✔
600
                right_cast = f"_tmp_{self._get_unique_id()}"
3✔
601
                self.builder.add_container(
3✔
602
                    right_cast, Scalar(PrimitiveType.Double), False
603
                )
604
                self.symbol_table[right_cast] = Scalar(PrimitiveType.Double)
3✔
605

606
                c_block = self.builder.add_block()
3✔
607
                t_src, src_sub = self._add_read(c_block, right)
3✔
608
                t_dst = self.builder.add_access(c_block, right_cast)
3✔
609
                t_task = self.builder.add_tasklet(c_block, "assign", ["_in"], ["_out"])
3✔
610
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
3✔
611
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
3✔
612

613
                real_right = right_cast
3✔
614

615
        # Special cases
616
        if op == "**":
3✔
617
            block = self.builder.add_block()
3✔
618
            t_left, left_sub = self._add_read(block, real_left)
3✔
619
            t_right, right_sub = self._add_read(block, real_right)
3✔
620
            t_out = self.builder.add_access(block, tmp_name)
3✔
621

622
            t_task = self.builder.add_intrinsic(block, "pow")
3✔
623
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
3✔
624
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
3✔
625
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
626

627
            return tmp_name
3✔
628
        elif op == "%":
3✔
629
            block = self.builder.add_block()
3✔
630
            t_left, left_sub = self._add_read(block, real_left)
3✔
631
            t_right, right_sub = self._add_read(block, real_right)
3✔
632
            t_out = self.builder.add_access(block, tmp_name)
3✔
633

634
            if dtype.primitive_type == PrimitiveType.Int64:
3✔
635
                # Implement ((a % b) + b) % b to match Python's modulo behavior
636

637
                # 1. rem1 = a % b
638
                t_rem1 = self.builder.add_tasklet(
3✔
639
                    block, "int_rem", ["_in1", "_in2"], ["_out"]
640
                )
641
                self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
3✔
642
                self.builder.add_memlet(
3✔
643
                    block, t_right, "void", t_rem1, "_in2", right_sub
644
                )
645

646
                rem1_name = f"_tmp_{self._get_unique_id()}"
3✔
647
                self.builder.add_container(rem1_name, dtype, False)
3✔
648
                t_rem1_out = self.builder.add_access(block, rem1_name)
3✔
649
                self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")
3✔
650

651
                # 2. add = rem1 + b
652
                t_add = self.builder.add_tasklet(
3✔
653
                    block, "int_add", ["_in1", "_in2"], ["_out"]
654
                )
655
                self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
3✔
656
                self.builder.add_memlet(
3✔
657
                    block, t_right, "void", t_add, "_in2", right_sub
658
                )
659

660
                add_name = f"_tmp_{self._get_unique_id()}"
3✔
661
                self.builder.add_container(add_name, dtype, False)
3✔
662
                t_add_out = self.builder.add_access(block, add_name)
3✔
663
                self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")
3✔
664

665
                # 3. res = add % b
666
                t_rem2 = self.builder.add_tasklet(
3✔
667
                    block, "int_rem", ["_in1", "_in2"], ["_out"]
668
                )
669
                self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
3✔
670
                self.builder.add_memlet(
3✔
671
                    block, t_right, "void", t_rem2, "_in2", right_sub
672
                )
673
                self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")
3✔
674

675
                return tmp_name
3✔
676
            else:
677
                t_task = self.builder.add_intrinsic(block, "fmod")
3✔
678
                self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
3✔
679
                self.builder.add_memlet(
3✔
680
                    block, t_right, "void", t_task, "_in2", right_sub
681
                )
682
                self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
683
                return tmp_name
3✔
684

685
        prefix = "int" if dtype.primitive_type == PrimitiveType.Int64 else "fp"
3✔
686
        op_name = ""
3✔
687
        if op == "+":
3✔
688
            op_name = "add"
3✔
689
        elif op == "-":
3✔
690
            op_name = "sub"
3✔
691
        elif op == "*":
3✔
692
            op_name = "mul"
3✔
693
        elif op == "/":
3✔
694
            op_name = "div"
3✔
695
        elif op == "//":
3✔
696
            op_name = "div"
3✔
697
        elif op == "|":
3✔
698
            op_name = "or"
3✔
699
        elif op == "^":
3✔
700
            op_name = "xor"
3✔
701

702
        block = self.builder.add_block()
3✔
703
        t_left, left_sub = self._add_read(block, real_left)
3✔
704
        t_right, right_sub = self._add_read(block, real_right)
3✔
705
        t_out = self.builder.add_access(block, tmp_name)
3✔
706

707
        tasklet_code = f"{prefix}_{op_name}"
3✔
708
        t_task = self.builder.add_tasklet(
3✔
709
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
710
        )
711

712
        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
3✔
713
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
3✔
714
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
715

716
        return tmp_name
3✔
717

718
    def _add_assign_constant(self, target_name, value_str, dtype):
3✔
719
        block = self.builder.add_block()
3✔
720
        t_const = self.builder.add_constant(block, value_str, dtype)
3✔
721
        t_dst = self.builder.add_access(block, target_name)
3✔
722
        t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
3✔
723
        self.builder.add_memlet(block, t_const, "void", t_task, "_in", "")
3✔
724
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
725

726
    def visit_BoolOp(self, node):
3✔
727
        op = self.visit(node.op)
3✔
728
        values = [f"({self.visit(v)} != 0)" for v in node.values]
3✔
729
        expr_str = f"{f' {op} '.join(values)}"
3✔
730

731
        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
732
        dtype = Scalar(PrimitiveType.Bool)
3✔
733
        self.builder.add_container(tmp_name, dtype, False)
3✔
734

735
        # Use control flow to assign boolean value
736
        self.builder.begin_if(expr_str)
3✔
737
        self._add_assign_constant(tmp_name, "true", dtype)
3✔
738
        self.builder.begin_else()
3✔
739
        self._add_assign_constant(tmp_name, "false", dtype)
3✔
740
        self.builder.end_if()
3✔
741

742
        self.symbol_table[tmp_name] = dtype
3✔
743
        return tmp_name
3✔
744

745
    def visit_Compare(self, node):
3✔
746
        left = self.visit(node.left)
3✔
747
        if len(node.ops) > 1:
3✔
748
            raise NotImplementedError("Chained comparisons not supported yet")
×
749

750
        op = self.visit(node.ops[0])
3✔
751
        right = self.visit(node.comparators[0])
3✔
752
        expr_str = f"{left} {op} {right}"
3✔
753

754
        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
755
        dtype = Scalar(PrimitiveType.Bool)
3✔
756
        self.builder.add_container(tmp_name, dtype, False)
3✔
757

758
        # Use control flow to assign boolean value
759
        self.builder.begin_if(expr_str)
3✔
760
        self.builder.add_assignment(tmp_name, "true")
3✔
761
        self.builder.begin_else()
3✔
762
        self.builder.add_assignment(tmp_name, "false")
3✔
763
        self.builder.end_if()
3✔
764

765
        self.symbol_table[tmp_name] = dtype
3✔
766
        return tmp_name
3✔
767

768
    def visit_UnaryOp(self, node):
3✔
769
        op = self.visit(node.op)
3✔
770
        operand = self.visit(node.operand)
3✔
771

772
        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
773
        dtype = Scalar(PrimitiveType.Double)
3✔
774
        if operand in self.symbol_table:
3✔
775
            dtype = self.symbol_table[operand]
3✔
776
        elif self._is_int(operand):
3✔
777
            dtype = Scalar(PrimitiveType.Int64)
3✔
778
        elif isinstance(node.op, ast.Not):
3✔
779
            dtype = Scalar(PrimitiveType.Bool)
×
780

781
        self.builder.add_container(tmp_name, dtype, False)
3✔
782
        self.symbol_table[tmp_name] = dtype
3✔
783

784
        block = self.builder.add_block()
3✔
785
        t_src, src_sub = self._add_read(block, operand)
3✔
786
        t_dst = self.builder.add_access(block, tmp_name)
3✔
787

788
        if isinstance(node.op, ast.Not):
3✔
789
            t_const = self.builder.add_constant(
3✔
790
                block, "true", Scalar(PrimitiveType.Bool)
791
            )
792
            t_task = self.builder.add_tasklet(
3✔
793
                block, "int_xor", ["_in1", "_in2"], ["_out"]
794
            )
795
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
3✔
796
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
3✔
797
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
798

799
        elif op == "-":
3✔
800
            if dtype.primitive_type == PrimitiveType.Int64:
3✔
801
                t_const = self.builder.add_constant(block, "0", dtype)
3✔
802
                t_task = self.builder.add_tasklet(
3✔
803
                    block, "int_sub", ["_in1", "_in2"], ["_out"]
804
                )
805
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
3✔
806
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
3✔
807
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
808
            else:
809
                t_task = self.builder.add_tasklet(block, "fp_neg", ["_in"], ["_out"])
3✔
810
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
3✔
811
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
812
        else:
813
            t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
×
814
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
×
815
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
×
816

817
        return tmp_name
3✔
818

819
    def _parse_array_arg(self, node, simple_visitor):
3✔
820
        if isinstance(node, ast.Name):
×
821
            if node.id in self.array_info:
×
822
                return node.id, [], self.array_info[node.id]["shapes"]
×
823
        elif isinstance(node, ast.Subscript):
×
824
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
×
825
                name = node.value.id
×
826
                ndim = self.array_info[name]["ndim"]
×
827

828
                indices = []
×
829
                if isinstance(node.slice, ast.Tuple):
×
830
                    indices = list(node.slice.elts)
×
831
                else:
832
                    indices = [node.slice]
×
833

834
                while len(indices) < ndim:
×
835
                    indices.append(ast.Slice(lower=None, upper=None, step=None))
×
836

837
                start_indices = []
×
838
                slice_shape = []
×
839

840
                for i, idx in enumerate(indices):
×
841
                    if isinstance(idx, ast.Slice):
×
842
                        start = "0"
×
843
                        if idx.lower:
×
844
                            start = simple_visitor.visit(idx.lower)
×
845
                        start_indices.append(start)
×
846

847
                        shapes = self.array_info[name]["shapes"]
×
848
                        dim_size = (
×
849
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
850
                        )
851
                        stop = dim_size
×
852
                        if idx.upper:
×
853
                            stop = simple_visitor.visit(idx.upper)
×
854

855
                        size = f"({stop} - {start})"
×
856
                        slice_shape.append(size)
×
857
                    else:
858
                        val = simple_visitor.visit(idx)
×
859
                        start_indices.append(val)
×
860

861
                shapes = self.array_info[name]["shapes"]
×
862
                linear_index = ""
×
863
                for i in range(ndim):
×
864
                    term = start_indices[i]
×
865
                    for j in range(i + 1, ndim):
×
866
                        shape_val = shapes[j] if j < len(shapes) else None
×
867
                        shape_sym = (
×
868
                            shape_val if shape_val is not None else f"_{name}_shape_{j}"
869
                        )
870
                        term = f"({term} * {shape_sym})"
×
871

872
                    if i == 0:
×
873
                        linear_index = term
×
874
                    else:
875
                        linear_index = f"({linear_index} + {term})"
×
876

877
                return name, [linear_index], slice_shape
×
878

879
        return None, None, None
×
880

881
    def visit_Attribute(self, node):
3✔
882
        if node.attr == "shape":
3✔
883
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
3✔
884
                return f"_shape_proxy_{node.value.id}"
3✔
885

886
        if isinstance(node.value, ast.Name) and node.value.id == "math":
3✔
887
            val = ""
3✔
888
            if node.attr == "pi":
3✔
889
                val = "M_PI"
3✔
890
            elif node.attr == "e":
3✔
891
                val = "M_E"
3✔
892

893
            if val:
3✔
894
                tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
895
                dtype = Scalar(PrimitiveType.Double)
3✔
896
                self.builder.add_container(tmp_name, dtype, False)
3✔
897
                self.symbol_table[tmp_name] = dtype
3✔
898
                self._add_assign_constant(tmp_name, val, dtype)
3✔
899
                return tmp_name
3✔
900

901
        # Handle class member access (e.g., obj.x, obj.y)
902
        if isinstance(node.value, ast.Name):
3✔
903
            obj_name = node.value.id
3✔
904
            attr_name = node.attr
3✔
905

906
            # Check if the object is a class instance (has a Structure type)
907
            if obj_name in self.symbol_table:
3✔
908
                obj_type = self.symbol_table[obj_name]
3✔
909
                if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
3✔
910
                    pointee_type = obj_type.pointee_type
3✔
911
                    if isinstance(pointee_type, Structure):
3✔
912
                        struct_name = pointee_type.name
3✔
913

914
                        # Look up member index and type from structure info
915
                        if (
3✔
916
                            struct_name in self.structure_member_info
917
                            and attr_name in self.structure_member_info[struct_name]
918
                        ):
919
                            member_index, member_type = self.structure_member_info[
3✔
920
                                struct_name
921
                            ][attr_name]
922
                        else:
923
                            # This should not happen if structure was registered properly
924
                            raise RuntimeError(
×
925
                                f"Member '{attr_name}' not found in structure '{struct_name}'. "
926
                                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
927
                            )
928

929
                        # Generate a tasklet to access the member
930
                        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
931

932
                        self.builder.add_container(tmp_name, member_type, False)
3✔
933
                        self.symbol_table[tmp_name] = member_type
3✔
934

935
                        # Create a tasklet that reads the member
936
                        block = self.builder.add_block()
3✔
937
                        obj_access = self.builder.add_access(block, obj_name)
3✔
938
                        tmp_access = self.builder.add_access(block, tmp_name)
3✔
939

940
                        # Use tasklet to pass through the value
941
                        # The actual member selection is done via the memlet subset
942
                        tasklet = self.builder.add_tasklet(
3✔
943
                            block, "assign", ["_in"], ["_out"]
944
                        )
945

946
                        # Use member index in the subset to select the correct member
947
                        subset = "0," + str(member_index)
3✔
948
                        self.builder.add_memlet(
3✔
949
                            block, obj_access, "", tasklet, "_in", subset
950
                        )
951
                        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")
3✔
952

953
                        return tmp_name
3✔
954

955
        raise NotImplementedError(f"Attribute access {node.attr} not supported")
×
956

957
    def visit_Subscript(self, node):
3✔
958
        value_str = self.visit(node.value)
3✔
959

960
        if value_str.startswith("_shape_proxy_"):
3✔
961
            array_name = value_str[len("_shape_proxy_") :]
3✔
962
            if isinstance(node.slice, ast.Constant):
3✔
963
                idx = node.slice.value
3✔
964
            elif isinstance(node.slice, ast.Index):
×
965
                idx = node.slice.value.value
×
966
            else:
967
                try:
×
968
                    idx = int(self.visit(node.slice))
×
969
                except:
×
970
                    raise NotImplementedError(
×
971
                        "Dynamic shape indexing not fully supported yet"
972
                    )
973

974
            if (
3✔
975
                array_name in self.array_info
976
                and "shapes" in self.array_info[array_name]
977
            ):
978
                return self.array_info[array_name]["shapes"][idx]
3✔
979

980
            return f"_{array_name}_shape_{idx}"
×
981

982
        if value_str in self.array_info:
3✔
983
            ndim = self.array_info[value_str]["ndim"]
3✔
984
            shapes = self.array_info[value_str].get("shapes", [])
3✔
985

986
            indices = []
3✔
987
            if isinstance(node.slice, ast.Tuple):
3✔
988
                indices_nodes = node.slice.elts
3✔
989
            else:
990
                indices_nodes = [node.slice]
3✔
991

992
            for idx in indices_nodes:
3✔
993
                if isinstance(idx, ast.Slice):
3✔
994
                    raise ValueError("Slices not supported in expression indexing")
×
995

996
            if isinstance(node.slice, ast.Tuple):
3✔
997
                indices = [self.visit(elt) for elt in node.slice.elts]
3✔
998
            else:
999
                indices = [self.visit(node.slice)]
3✔
1000

1001
            if len(indices) != ndim:
3✔
1002
                raise ValueError(
×
1003
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
1004
                )
1005

1006
            linear_index = ""
3✔
1007
            for i in range(ndim):
3✔
1008
                term = indices[i]
3✔
1009
                for j in range(i + 1, ndim):
3✔
1010
                    shape_val = shapes[j] if j < len(shapes) else None
3✔
1011
                    shape_sym = (
3✔
1012
                        shape_val
1013
                        if shape_val is not None
1014
                        else f"_{value_str}_shape_{j}"
1015
                    )
1016
                    term = f"(({term}) * {shape_sym})"
3✔
1017

1018
                if i == 0:
3✔
1019
                    linear_index = term
3✔
1020
                else:
1021
                    linear_index = f"({linear_index} + {term})"
3✔
1022

1023
            access_str = f"{value_str}({linear_index})"
3✔
1024

1025
            if self.builder and isinstance(node.ctx, ast.Load):
3✔
1026
                dtype = Scalar(PrimitiveType.Double)
3✔
1027
                if value_str in self.symbol_table:
3✔
1028
                    t = self.symbol_table[value_str]
3✔
1029
                    if type(t).__name__ == "Array" and hasattr(t, "element_type"):
3✔
1030
                        et = t.element_type
×
1031
                        if callable(et):
×
1032
                            et = et()
×
1033
                        dtype = et
×
1034
                    elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
3✔
1035
                        et = t.pointee_type
3✔
1036
                        if callable(et):
3✔
1037
                            et = et()
×
1038
                        dtype = et
3✔
1039

1040
                tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1041
                self.builder.add_container(tmp_name, dtype, False)
3✔
1042

1043
                block = self.builder.add_block()
3✔
1044
                t_src = self.builder.add_access(block, value_str)
3✔
1045
                t_dst = self.builder.add_access(block, tmp_name)
3✔
1046
                t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
3✔
1047

1048
                self.builder.add_memlet(
3✔
1049
                    block, t_src, "void", t_task, "_in", linear_index
1050
                )
1051
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
1052

1053
                self.symbol_table[tmp_name] = dtype
3✔
1054
                return tmp_name
3✔
1055

1056
            return access_str
3✔
1057

1058
        slice_val = self.visit(node.slice)
×
1059
        access_str = f"{value_str}({slice_val})"
×
1060

1061
        if (
×
1062
            self.builder
1063
            and isinstance(node.ctx, ast.Load)
1064
            and value_str in self.array_info
1065
        ):
1066
            tmp_name = f"_tmp_{self._get_unique_id()}"
×
1067
            self.builder.add_container(tmp_name, Scalar(PrimitiveType.Double), False)
×
1068
            self.builder.add_assignment(tmp_name, access_str)
×
1069
            self.symbol_table[tmp_name] = Scalar(PrimitiveType.Double)
×
1070
            return tmp_name
×
1071

1072
        return access_str
×
1073

1074
    def visit_Add(self, node):
3✔
1075
        return "+"
3✔
1076

1077
    def visit_Sub(self, node):
3✔
1078
        return "-"
3✔
1079

1080
    def visit_Mult(self, node):
3✔
1081
        return "*"
3✔
1082

1083
    def visit_Div(self, node):
3✔
1084
        return "/"
3✔
1085

1086
    def visit_FloorDiv(self, node):
3✔
1087
        return "//"
3✔
1088

1089
    def visit_Mod(self, node):
3✔
1090
        return "%"
3✔
1091

1092
    def visit_Pow(self, node):
3✔
1093
        return "**"
3✔
1094

1095
    def visit_Eq(self, node):
3✔
1096
        return "=="
×
1097

1098
    def visit_NotEq(self, node):
3✔
1099
        return "!="
×
1100

1101
    def visit_Lt(self, node):
3✔
1102
        return "<"
×
1103

1104
    def visit_LtE(self, node):
3✔
1105
        return "<="
×
1106

1107
    def visit_Gt(self, node):
3✔
1108
        return ">"
3✔
1109

1110
    def visit_GtE(self, node):
3✔
1111
        return ">="
×
1112

1113
    def visit_And(self, node):
3✔
1114
        return "&"
3✔
1115

1116
    def visit_Or(self, node):
3✔
1117
        return "|"
3✔
1118

1119
    def visit_BitAnd(self, node):
3✔
1120
        return "&"
×
1121

1122
    def visit_BitOr(self, node):
3✔
1123
        return "|"
3✔
1124

1125
    def visit_BitXor(self, node):
3✔
1126
        return "^"
3✔
1127

1128
    def visit_Not(self, node):
3✔
1129
        return "!"
3✔
1130

1131
    def visit_USub(self, node):
3✔
1132
        return "-"
3✔
1133

1134
    def visit_UAdd(self, node):
3✔
1135
        return "+"
×
1136

1137
    def visit_Invert(self, node):
3✔
1138
        return "~"
×
1139

1140
    def _get_dtype(self, name):
3✔
1141
        if name in self.symbol_table:
3✔
1142
            t = self.symbol_table[name]
3✔
1143
            if isinstance(t, Scalar):
3✔
1144
                return t
×
1145

1146
            if hasattr(t, "pointee_type"):
3✔
1147
                et = t.pointee_type
3✔
1148
                if callable(et):
3✔
1149
                    et = et()
×
1150
                if isinstance(et, Scalar):
3✔
1151
                    return et
3✔
1152

1153
            if hasattr(t, "element_type"):
×
1154
                et = t.element_type
×
1155
                if callable(et):
×
1156
                    et = et()
×
1157
                if isinstance(et, Scalar):
×
1158
                    return et
×
1159

1160
        if self._is_int(name):
3✔
1161
            return Scalar(PrimitiveType.Int64)
×
1162

1163
        return Scalar(PrimitiveType.Double)
3✔
1164

1165
    def _create_array_temp(self, shape, dtype, zero_init=False, ones_init=False):
3✔
1166
        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1167

1168
        # Calculate size
1169
        size_str = "1"
3✔
1170
        for dim in shape:
3✔
1171
            size_str = f"({size_str} * {dim})"
3✔
1172

1173
        element_size = self.builder.get_sizeof(dtype)
3✔
1174
        total_size = f"({size_str} * {element_size})"
3✔
1175

1176
        # Create pointer
1177
        ptr_type = Pointer(dtype)
3✔
1178
        self.builder.add_container(tmp_name, ptr_type, False)
3✔
1179
        self.symbol_table[tmp_name] = ptr_type
3✔
1180
        self.array_info[tmp_name] = {"ndim": len(shape), "shapes": shape}
3✔
1181

1182
        # Malloc
1183
        block1 = self.builder.add_block()
3✔
1184
        t_malloc = self.builder.add_malloc(block1, total_size)
3✔
1185
        t_ptr1 = self.builder.add_access(block1, tmp_name)
3✔
1186
        self.builder.add_memlet(block1, t_malloc, "_ret", t_ptr1, "void", "", ptr_type)
3✔
1187

1188
        if zero_init:
3✔
1189
            block2 = self.builder.add_block()
3✔
1190
            t_memset = self.builder.add_memset(block2, "0", total_size)
3✔
1191
            t_ptr2 = self.builder.add_access(block2, tmp_name)
3✔
1192
            self.builder.add_memlet(
3✔
1193
                block2, t_memset, "_ptr", t_ptr2, "void", "", ptr_type
1194
            )
1195
        elif ones_init:
3✔
1196
            # Initialize array with ones using a loop
1197
            loop_var = f"_i_{self._get_unique_id()}"
3✔
1198
            if not self.builder.has_container(loop_var):
3✔
1199
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
3✔
1200
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
3✔
1201

1202
            self.builder.begin_for(loop_var, "0", size_str, "1")
3✔
1203

1204
            # Determine the value to set based on dtype
1205
            val = "1.0"
3✔
1206
            if dtype.primitive_type in [
3✔
1207
                PrimitiveType.Int64,
1208
                PrimitiveType.Int32,
1209
                PrimitiveType.Int8,
1210
                PrimitiveType.Int16,
1211
                PrimitiveType.UInt64,
1212
                PrimitiveType.UInt32,
1213
                PrimitiveType.UInt8,
1214
                PrimitiveType.UInt16,
1215
            ]:
1216
                val = "1"
3✔
1217

1218
            block_assign = self.builder.add_block()
3✔
1219
            t_const = self.builder.add_constant(block_assign, val, dtype)
3✔
1220
            t_arr = self.builder.add_access(block_assign, tmp_name)
3✔
1221

1222
            t_task = self.builder.add_tasklet(block_assign, "assign", ["_in"], ["_out"])
3✔
1223
            self.builder.add_memlet(
3✔
1224
                block_assign, t_const, "void", t_task, "_in", "", dtype
1225
            )
1226
            self.builder.add_memlet(
3✔
1227
                block_assign, t_task, "_out", t_arr, "void", loop_var
1228
            )
1229

1230
            self.builder.end_for()
3✔
1231

1232
        return tmp_name
3✔
1233

1234
    def _handle_array_unary_op(self, op_type, operand):
3✔
1235
        # Determine output shape
1236
        shape = []
3✔
1237
        if operand in self.array_info:
3✔
1238
            shape = self.array_info[operand]["shapes"]
3✔
1239

1240
        # Determine dtype
1241
        dtype = self._get_dtype(operand)
3✔
1242

1243
        tmp_name = self._create_array_temp(shape, dtype)
3✔
1244

1245
        # Add operation
1246
        self.builder.add_elementwise_unary_op(op_type, operand, tmp_name, shape)
3✔
1247

1248
        return tmp_name
3✔
1249

1250
    def _handle_array_binary_op(self, op_type, left, right):
3✔
1251
        # Determine output shape
1252
        shape = []
3✔
1253
        if left in self.array_info:
3✔
1254
            shape = self.array_info[left]["shapes"]
3✔
1255
        elif right in self.array_info:
×
1256
            shape = self.array_info[right]["shapes"]
×
1257

1258
        # Determine dtype
1259
        dtype_left = self._get_dtype(left)
3✔
1260
        dtype_right = self._get_dtype(right)
3✔
1261

1262
        assert dtype_left.primitive_type == dtype_right.primitive_type
3✔
1263
        dtype = dtype_left
3✔
1264

1265
        tmp_name = self._create_array_temp(shape, dtype)
3✔
1266

1267
        # Add operation
1268
        self.builder.add_elementwise_op(op_type, left, right, tmp_name, shape)
3✔
1269

1270
        return tmp_name
3✔
1271

1272
    def _handle_numpy_alloc(self, node, func_name):
3✔
1273
        # Parse shape
1274
        shape_arg = node.args[0]
3✔
1275
        dims = []
3✔
1276
        if isinstance(shape_arg, ast.Tuple):
3✔
1277
            dims = [self.visit(elt) for elt in shape_arg.elts]
3✔
1278
        elif isinstance(shape_arg, ast.List):
3✔
1279
            dims = [self.visit(elt) for elt in shape_arg.elts]
×
1280
        else:
1281
            val = self.visit(shape_arg)
3✔
1282
            if val.startswith("_shape_proxy_"):
3✔
1283
                array_name = val[len("_shape_proxy_") :]
×
1284
                if array_name in self.array_info:
×
1285
                    dims = self.array_info[array_name]["shapes"]
×
1286
                else:
1287
                    dims = [val]
×
1288
            else:
1289
                dims = [val]
3✔
1290

1291
        # Parse dtype
1292
        dtype_arg = None
3✔
1293
        if len(node.args) > 1:
3✔
1294
            dtype_arg = node.args[1]
×
1295

1296
        for kw in node.keywords:
3✔
1297
            if kw.arg == "dtype":
3✔
1298
                dtype_arg = kw.value
3✔
1299
                break
3✔
1300

1301
        element_type = self._map_numpy_dtype(dtype_arg)
3✔
1302

1303
        return self._create_array_temp(
3✔
1304
            dims,
1305
            element_type,
1306
            zero_init=(func_name == "zeros"),
1307
            ones_init=(func_name == "ones"),
1308
        )
1309

1310
    def _handle_numpy_empty_like(self, node, func_name):
3✔
1311
        prototype_arg = node.args[0]
3✔
1312
        prototype_name = self.visit(prototype_arg)
3✔
1313

1314
        # Parse shape from prototype
1315
        dims = []
3✔
1316
        if prototype_name in self.array_info:
3✔
1317
            dims = self.array_info[prototype_name]["shapes"]
3✔
1318

1319
        # Parse dtype
1320
        dtype_arg = None
3✔
1321
        if len(node.args) > 1:
3✔
NEW
1322
            dtype_arg = node.args[1]
×
1323

1324
        for kw in node.keywords:
3✔
1325
            if kw.arg == "dtype":
3✔
1326
                dtype_arg = kw.value
3✔
1327
                break
3✔
1328

1329
        element_type = None
3✔
1330
        if dtype_arg:
3✔
1331
            element_type = self._map_numpy_dtype(dtype_arg)
3✔
1332
        else:
1333
            if prototype_name in self.symbol_table:
3✔
1334
                sym_type = self.symbol_table[prototype_name]
3✔
1335
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
3✔
1336
                    element_type = sym_type.pointee_type
3✔
1337

1338
        if element_type is None:
3✔
NEW
1339
            element_type = Scalar(PrimitiveType.Double)
×
1340

1341
        return self._create_array_temp(
3✔
1342
            dims,
1343
            element_type,
1344
            zero_init=False,
1345
            ones_init=False,
1346
        )
1347

1348
    def _handle_numpy_eye(self, node, func_name):
3✔
1349
        # Parse N
1350
        N_arg = node.args[0]
3✔
1351
        N_str = self.visit(N_arg)
3✔
1352

1353
        # Parse M
1354
        M_str = N_str
3✔
1355
        if len(node.args) > 1:
3✔
1356
            M_str = self.visit(node.args[1])
×
1357

1358
        # Parse k
1359
        k_str = "0"
3✔
1360
        if len(node.args) > 2:
3✔
1361
            k_str = self.visit(node.args[2])
×
1362

1363
        # Check keywords for M, k, dtype
1364
        dtype_arg = None
3✔
1365
        for kw in node.keywords:
3✔
1366
            if kw.arg == "M":
3✔
1367
                M_str = self.visit(kw.value)
3✔
1368
                if M_str == "None":
3✔
1369
                    M_str = N_str
3✔
1370
            elif kw.arg == "k":
3✔
1371
                k_str = self.visit(kw.value)
3✔
1372
            elif kw.arg == "dtype":
3✔
1373
                dtype_arg = kw.value
3✔
1374

1375
        element_type = self._map_numpy_dtype(dtype_arg)
3✔
1376

1377
        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)
3✔
1378

1379
        # Loop to set diagonal
1380
        loop_var = f"_i_{self._get_unique_id()}"
3✔
1381
        if not self.builder.has_container(loop_var):
3✔
1382
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
3✔
1383
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
3✔
1384

1385
        self.builder.begin_for(loop_var, "0", N_str, "1")
3✔
1386

1387
        # Condition: 0 <= i + k < M
1388
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
3✔
1389
        self.builder.begin_if(cond)
3✔
1390

1391
        # Assignment: A[i, i+k] = 1
1392
        val = "1.0"
3✔
1393
        if element_type.primitive_type in [
3✔
1394
            PrimitiveType.Int64,
1395
            PrimitiveType.Int32,
1396
            PrimitiveType.Int8,
1397
            PrimitiveType.Int16,
1398
            PrimitiveType.UInt64,
1399
            PrimitiveType.UInt32,
1400
            PrimitiveType.UInt8,
1401
            PrimitiveType.UInt16,
1402
        ]:
1403
            val = "1"
×
1404

1405
        block_assign = self.builder.add_block()
3✔
1406
        t_const = self.builder.add_constant(block_assign, val, element_type)
3✔
1407
        t_arr = self.builder.add_access(block_assign, ptr_name)
3✔
1408
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"
3✔
1409
        subset = flat_index
3✔
1410

1411
        t_task = self.builder.add_tasklet(block_assign, "assign", ["_in"], ["_out"])
3✔
1412
        self.builder.add_memlet(
3✔
1413
            block_assign, t_const, "void", t_task, "_in", "", element_type
1414
        )
1415
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", subset)
3✔
1416

1417
        self.builder.end_if()
3✔
1418
        self.builder.end_for()
3✔
1419

1420
        return ptr_name
3✔
1421

1422
    def _handle_numpy_binary_op(self, node, func_name):
3✔
1423
        args = [self.visit(arg) for arg in node.args]
3✔
1424
        if len(args) != 2:
3✔
1425
            raise NotImplementedError(
×
1426
                f"Numpy function {func_name} requires 2 arguments"
1427
            )
1428

1429
        op_map = {
3✔
1430
            "add": "add",
1431
            "subtract": "sub",
1432
            "multiply": "mul",
1433
            "divide": "div",
1434
            "power": "pow",
1435
            "minimum": "min",
1436
            "maximum": "max",
1437
        }
1438
        return self._handle_array_binary_op(op_map[func_name], args[0], args[1])
3✔
1439

1440
    def _handle_numpy_matmul_op(self, left_node, right_node):
3✔
1441
        return self._handle_matmul_helper(left_node, right_node)
3✔
1442

1443
    def _handle_numpy_matmul(self, node, func_name):
3✔
1444
        if len(node.args) != 2:
3✔
1445
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
1446
        return self._handle_matmul_helper(node.args[0], node.args[1])
3✔
1447

1448
    def _handle_matmul_helper(self, left_node, right_node):
3✔
1449
        if not self.la_handler:
3✔
1450
            raise RuntimeError("LinearAlgebraHandler not initialized")
×
1451

1452
        res_a = self.la_handler.parse_arg(left_node)
3✔
1453
        res_b = self.la_handler.parse_arg(right_node)
3✔
1454

1455
        if not res_a[0]:
3✔
1456
            left_name = self.visit(left_node)
×
1457
            left_node = ast.Name(id=left_name)
×
1458
            res_a = self.la_handler.parse_arg(left_node)
×
1459

1460
        if not res_b[0]:
3✔
1461
            right_name = self.visit(right_node)
×
1462
            right_node = ast.Name(id=right_name)
×
1463
            res_b = self.la_handler.parse_arg(right_node)
×
1464

1465
        name_a, subset_a, shape_a, indices_a = res_a
3✔
1466
        name_b, subset_b, shape_b, indices_b = res_b
3✔
1467

1468
        if not name_a or not name_b:
3✔
1469
            raise NotImplementedError("Could not resolve matmul operands")
×
1470

1471
        real_shape_a = shape_a
3✔
1472
        real_shape_b = shape_b
3✔
1473

1474
        ndim_a = len(real_shape_a)
3✔
1475
        ndim_b = len(real_shape_b)
3✔
1476

1477
        output_shape = []
3✔
1478
        is_scalar = False
3✔
1479

1480
        if ndim_a == 1 and ndim_b == 1:
3✔
1481
            is_scalar = True
3✔
1482
            output_shape = []
3✔
1483
        elif ndim_a == 2 and ndim_b == 2:
3✔
1484
            output_shape = [real_shape_a[0], real_shape_b[1]]
3✔
1485
        elif ndim_a == 2 and ndim_b == 1:
3✔
1486
            output_shape = [real_shape_a[0]]
3✔
1487
        elif ndim_a == 1 and ndim_b == 2:
3✔
1488
            output_shape = [real_shape_b[1]]
×
1489
        elif ndim_a > 2 or ndim_b > 2:
3✔
1490
            if ndim_a == ndim_b:
3✔
1491
                output_shape = list(real_shape_a[:-2]) + [
3✔
1492
                    real_shape_a[-2],
1493
                    real_shape_b[-1],
1494
                ]
1495
            else:
1496
                raise NotImplementedError(
×
1497
                    "Broadcasting with different ranks not fully supported yet"
1498
                )
1499
        else:
1500
            raise NotImplementedError(
×
1501
                f"Matmul with ranks {ndim_a} and {ndim_b} not supported"
1502
            )
1503

1504
        dtype = Scalar(PrimitiveType.Double)
3✔
1505

1506
        if is_scalar:
3✔
1507
            tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1508
            self.builder.add_container(tmp_name, dtype, False)
3✔
1509
            self.symbol_table[tmp_name] = dtype
3✔
1510
        else:
1511
            tmp_name = self._create_array_temp(output_shape, dtype)
3✔
1512

1513
        if ndim_a > 2 or ndim_b > 2:
3✔
1514
            # Generate loops for broadcasting
1515
            batch_dims = ndim_a - 2
3✔
1516
            loop_vars = []
3✔
1517

1518
            for i in range(batch_dims):
3✔
1519
                loop_var = f"_i{self._get_unique_id()}"
3✔
1520
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
3✔
1521
                loop_vars.append(loop_var)
3✔
1522
                dim_size = real_shape_a[i]
3✔
1523
                self.builder.begin_for(loop_var, "0", str(dim_size), "1")
3✔
1524

1525
            def make_slice(name, indices):
3✔
1526
                elts = []
3✔
1527
                for idx in indices:
3✔
1528
                    if idx == ":":
3✔
1529
                        elts.append(ast.Slice())
3✔
1530
                    else:
1531
                        elts.append(ast.Name(id=idx))
3✔
1532

1533
                return ast.Subscript(
3✔
1534
                    value=ast.Name(id=name), slice=ast.Tuple(elts=elts), ctx=ast.Load()
1535
                )
1536

1537
            indices = loop_vars + [":", ":"]
3✔
1538
            slice_a = make_slice(name_a, indices)
3✔
1539
            slice_b = make_slice(name_b, indices)
3✔
1540
            slice_c = make_slice(tmp_name, indices)
3✔
1541

1542
            self.la_handler.handle_gemm(
3✔
1543
                slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
1544
            )
1545

1546
            for _ in range(batch_dims):
3✔
1547
                self.builder.end_for()
3✔
1548
        else:
1549
            if is_scalar:
3✔
1550
                self.la_handler.handle_dot(
3✔
1551
                    tmp_name,
1552
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
1553
                )
1554
            else:
1555
                self.la_handler.handle_gemm(
3✔
1556
                    tmp_name,
1557
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
1558
                )
1559

1560
        return tmp_name
3✔
1561

1562
    def _handle_numpy_unary_op(self, node, func_name):
3✔
1563
        args = [self.visit(arg) for arg in node.args]
3✔
1564
        if len(args) != 1:
3✔
1565
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
1566

1567
        op_name = func_name
3✔
1568
        if op_name == "absolute":
3✔
1569
            op_name = "abs"
×
1570

1571
        return self._handle_array_unary_op(op_name, args[0])
3✔
1572

1573
    def _handle_numpy_reduce(self, node, func_name):
3✔
1574
        args = node.args
3✔
1575
        keywords = {kw.arg: kw.value for kw in node.keywords}
3✔
1576

1577
        array_node = args[0]
3✔
1578
        array_name = self.visit(array_node)
3✔
1579

1580
        if array_name not in self.array_info:
3✔
1581
            raise ValueError(f"Reduction input must be an array, got {array_name}")
×
1582

1583
        input_shape = self.array_info[array_name]["shapes"]
3✔
1584
        ndim = len(input_shape)
3✔
1585

1586
        axis = None
3✔
1587
        if len(args) > 1:
3✔
1588
            axis = args[1]
×
1589
        elif "axis" in keywords:
3✔
1590
            axis = keywords["axis"]
3✔
1591

1592
        keepdims = False
3✔
1593
        if "keepdims" in keywords:
3✔
1594
            keepdims_node = keywords["keepdims"]
3✔
1595
            if isinstance(keepdims_node, ast.Constant):
3✔
1596
                keepdims = bool(keepdims_node.value)
3✔
1597

1598
        axes = []
3✔
1599
        if axis is None:
3✔
1600
            axes = list(range(ndim))
3✔
1601
        elif isinstance(axis, ast.Constant):  # Single axis
3✔
1602
            val = axis.value
3✔
1603
            if val < 0:
3✔
1604
                val += ndim
×
1605
            axes = [val]
3✔
1606
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
1607
            for elt in axis.elts:
×
1608
                if isinstance(elt, ast.Constant):
×
1609
                    val = elt.value
×
1610
                    if val < 0:
×
1611
                        val += ndim
×
1612
                    axes.append(val)
×
1613
        elif (
×
1614
            isinstance(axis, ast.UnaryOp)
1615
            and isinstance(axis.op, ast.USub)
1616
            and isinstance(axis.operand, ast.Constant)
1617
        ):
1618
            val = -axis.operand.value
×
1619
            if val < 0:
×
1620
                val += ndim
×
1621
            axes = [val]
×
1622
        else:
1623
            # Try to evaluate simple expression
1624
            try:
×
1625
                val = int(self.visit(axis))
×
1626
                if val < 0:
×
1627
                    val += ndim
×
1628
                axes = [val]
×
1629
            except:
×
1630
                raise NotImplementedError("Dynamic axis not supported")
×
1631

1632
        # Calculate output shape
1633
        output_shape = []
3✔
1634
        for i in range(ndim):
3✔
1635
            if i in axes:
3✔
1636
                if keepdims:
3✔
1637
                    output_shape.append("1")
3✔
1638
            else:
1639
                output_shape.append(input_shape[i])
3✔
1640

1641
        dtype = self._get_dtype(array_name)
3✔
1642

1643
        if not output_shape:
3✔
1644
            tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1645
            self.builder.add_container(tmp_name, dtype, False)
3✔
1646
            self.symbol_table[tmp_name] = dtype
3✔
1647
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
3✔
1648
        else:
1649
            tmp_name = self._create_array_temp(output_shape, dtype)
3✔
1650

1651
        self.builder.add_reduce_op(
3✔
1652
            func_name, array_name, tmp_name, input_shape, axes, keepdims
1653
        )
1654

1655
        return tmp_name
3✔
1656

1657
    def _handle_numpy_astype(self, node, array_name):
3✔
1658
        """Handle numpy array.astype(dtype) method calls."""
1659
        if len(node.args) < 1:
3✔
1660
            raise ValueError("astype requires at least one argument (dtype)")
×
1661

1662
        dtype_arg = node.args[0]
3✔
1663
        target_dtype = self._map_numpy_dtype(dtype_arg)
3✔
1664

1665
        # Get input array shape
1666
        if array_name not in self.array_info:
3✔
1667
            raise ValueError(f"Array {array_name} not found in array_info")
×
1668

1669
        input_shape = self.array_info[array_name]["shapes"]
3✔
1670

1671
        # Create output array with target dtype
1672
        tmp_name = self._create_array_temp(input_shape, target_dtype)
3✔
1673

1674
        # Add cast operation
1675
        self.builder.add_cast_op(
3✔
1676
            array_name, tmp_name, input_shape, target_dtype.primitive_type
1677
        )
1678

1679
        return tmp_name
3✔
1680

1681
    def _handle_scipy_softmax(self, node, func_name):
3✔
1682
        args = node.args
3✔
1683
        keywords = {kw.arg: kw.value for kw in node.keywords}
3✔
1684

1685
        array_node = args[0]
3✔
1686
        array_name = self.visit(array_node)
3✔
1687

1688
        if array_name not in self.array_info:
3✔
1689
            raise ValueError(f"Softmax input must be an array, got {array_name}")
×
1690

1691
        input_shape = self.array_info[array_name]["shapes"]
3✔
1692
        ndim = len(input_shape)
3✔
1693

1694
        axis = None
3✔
1695
        if len(args) > 1:
3✔
1696
            axis = args[1]
×
1697
        elif "axis" in keywords:
3✔
1698
            axis = keywords["axis"]
3✔
1699

1700
        axes = []
3✔
1701
        if axis is None:
3✔
1702
            axes = list(range(ndim))
3✔
1703
        elif isinstance(axis, ast.Constant):  # Single axis
3✔
1704
            val = axis.value
3✔
1705
            if val < 0:
3✔
1706
                val += ndim
×
1707
            axes = [val]
3✔
1708
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
1709
            for elt in axis.elts:
×
1710
                if isinstance(elt, ast.Constant):
×
1711
                    val = elt.value
×
1712
                    if val < 0:
×
1713
                        val += ndim
×
1714
                    axes.append(val)
×
1715
        elif (
×
1716
            isinstance(axis, ast.UnaryOp)
1717
            and isinstance(axis.op, ast.USub)
1718
            and isinstance(axis.operand, ast.Constant)
1719
        ):
1720
            val = -axis.operand.value
×
1721
            if val < 0:
×
1722
                val += ndim
×
1723
            axes = [val]
×
1724
        else:
1725
            # Try to evaluate simple expression
1726
            try:
×
1727
                val = int(self.visit(axis))
×
1728
                if val < 0:
×
1729
                    val += ndim
×
1730
                axes = [val]
×
1731
            except:
×
1732
                raise NotImplementedError("Dynamic axis not supported")
×
1733

1734
        # Create output array
1735
        # Assume double for now, or infer from input
1736
        dtype = Scalar(PrimitiveType.Double)  # TODO: infer
3✔
1737

1738
        tmp_name = self._create_array_temp(input_shape, dtype)
3✔
1739

1740
        self.builder.add_reduce_op(
3✔
1741
            func_name, array_name, tmp_name, input_shape, axes, False
1742
        )
1743

1744
        return tmp_name
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc