• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21113623600

18 Jan 2026 02:50PM UTC coverage: 64.425% (+0.3%) from 64.154%
21113623600

Pull #462

github

web-flow
Merge d503e5691 into 92e9cbdc3
Pull Request #462: adds syntax support for multi-assignments and np.empty_like

221 of 258 new or added lines in 5 files covered. (85.66%)

21 existing lines in 4 files now uncovered.

19678 of 30544 relevant lines covered (64.43%)

385.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.41
/python/docc/expression_visitor.py
1
import ast
3✔
2
import inspect
3✔
3
import textwrap
3✔
4
from ._sdfg import Scalar, PrimitiveType, Pointer, Type, DebugInfo, Structure
3✔
5

6

7
class ExpressionVisitor(ast.NodeVisitor):
3✔
8
    def __init__(
        self,
        array_info=None,
        builder=None,
        symbol_table=None,
        globals_dict=None,
        inliner=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Set up a visitor that lowers Python expressions through *builder*.

        Each mapping argument (``array_info``, ``symbol_table``,
        ``globals_dict``, ``structure_member_info``) defaults to a fresh empty
        dict when omitted; objects supplied by the caller are kept as-is (not
        copied), so they are shared with the creator.  ``unique_counter_ref``
        is a one-element list used as a shared mutable counter for fresh
        names and defaults to ``[0]``.  The NumPy dispatch table is built
        last.
        """

        def _or_empty(mapping):
            # Keep caller-provided objects (even empty ones) by identity.
            return {} if mapping is None else mapping

        self.array_info = _or_empty(array_info)
        self.builder = builder
        self.symbol_table = _or_empty(symbol_table)
        self.globals_dict = _or_empty(globals_dict)
        self.inliner = inliner
        if unique_counter_ref is None:
            unique_counter_ref = [0]
        self._unique_counter_ref = unique_counter_ref
        self._access_cache = {}
        self.la_handler = None
        self.structure_member_info = _or_empty(structure_member_info)
        self._init_numpy_handlers()
32

33
    def _get_unique_id(self):
3✔
34
        self._unique_counter_ref[0] += 1
3✔
35
        return self._unique_counter_ref[0]
3✔
36

37
    def _get_temp_name(self, prefix="_tmp_"):
3✔
38
        if hasattr(self.builder, "find_new_name"):
3✔
39
            return self.builder.find_new_name(prefix)
×
40
        return f"{prefix}{self._get_unique_id()}"
3✔
41

42
    def _init_numpy_handlers(self):
        """Build ``self.numpy_handlers``, mapping NumPy function names to lowerings.

        Each handler is later invoked as ``handler(call_node, func_name)``;
        several names share one generic handler (allocation, element-wise
        binary/unary ops, reductions, matrix products).
        """
        alloc = self._handle_numpy_alloc
        binary = self._handle_numpy_binary_op
        unary = self._handle_numpy_unary_op
        reduction = self._handle_numpy_reduce
        matmul = self._handle_numpy_matmul

        # Keyword order below reproduces the table's insertion order.
        self.numpy_handlers = dict(
            empty=alloc,
            empty_like=self._handle_numpy_empty_like,
            zeros=alloc,
            zeros_like=self._handle_numpy_zeros_like,
            ones=alloc,
            eye=self._handle_numpy_eye,
            add=binary,
            subtract=binary,
            multiply=binary,
            divide=binary,
            power=binary,
            exp=unary,
            abs=unary,
            absolute=unary,
            sqrt=unary,
            tanh=unary,
            sum=reduction,
            max=reduction,
            min=reduction,
            mean=reduction,
            std=reduction,
            matmul=matmul,
            dot=matmul,
            matvec=matmul,
            minimum=binary,
            maximum=binary,
        )
71

72
    def generic_visit(self, node):
3✔
73
        return super().generic_visit(node)
×
74

75
    def visit_Constant(self, node):
3✔
76
        if isinstance(node.value, bool):
3✔
77
            return "true" if node.value else "false"
×
78
        return str(node.value)
3✔
79

80
    def visit_Name(self, node):
3✔
81
        return node.id
3✔
82

83
    def _map_numpy_dtype(self, dtype_node):
        """Translate a ``dtype=`` AST node into an sdfg type.

        Recognised forms:

        * ``None`` (no dtype given) -> Double;
        * the builtins ``float`` / ``int`` / ``bool``;
        * ``np.<name>`` / ``numpy.<name>`` for the NumPy scalar types listed
          in ``numpy_types`` below (e.g. ``np.float32``, ``np.uint8``);
        * the same names as string literals, e.g. ``dtype="int32"``;
        * ``arr.dtype`` where ``arr`` is a pointer in ``self.symbol_table``
          with a pointee type -> that pointee type.

        Anything unrecognised falls back to Double (unchanged legacy
        behaviour).
        """
        builtin_types = {
            "float": PrimitiveType.Double,
            "int": PrimitiveType.Int64,
            "bool": PrimitiveType.Bool,
        }
        numpy_types = {
            "float64": PrimitiveType.Double,
            "double": PrimitiveType.Double,
            "float32": PrimitiveType.Float,
            "single": PrimitiveType.Float,
            "int64": PrimitiveType.Int64,
            "int32": PrimitiveType.Int32,
            "int16": PrimitiveType.Int16,
            "int8": PrimitiveType.Int8,
            "uint64": PrimitiveType.UInt64,
            "uint32": PrimitiveType.UInt32,
            "uint16": PrimitiveType.UInt16,
            "uint8": PrimitiveType.UInt8,
            "bool_": PrimitiveType.Bool,
            "bool": PrimitiveType.Bool,  # NumPy >= 2 alias of bool_
        }

        # Default to double
        if dtype_node is None:
            return Scalar(PrimitiveType.Double)

        if isinstance(dtype_node, ast.Name) and dtype_node.id in builtin_types:
            return Scalar(builtin_types[dtype_node.id])

        if isinstance(dtype_node, ast.Attribute) and isinstance(
            dtype_node.value, ast.Name
        ):
            base = dtype_node.value.id

            # Handle array.dtype
            if dtype_node.attr == "dtype" and base in self.symbol_table:
                sym_type = self.symbol_table[base]
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                    return sym_type.pointee_type

            if base in ("numpy", "np") and dtype_node.attr in numpy_types:
                return Scalar(numpy_types[dtype_node.attr])

        # String dtypes, as accepted by NumPy: dtype="float32", dtype="int"
        if isinstance(dtype_node, ast.Constant) and isinstance(dtype_node.value, str):
            name = dtype_node.value
            if name in numpy_types:
                return Scalar(numpy_types[name])
            if name in builtin_types:
                return Scalar(builtin_types[name])

        # Fallback
        return Scalar(PrimitiveType.Double)
124

125
    def _is_int(self, operand):
3✔
126
        try:
3✔
127
            if operand.lstrip("-").isdigit():
3✔
128
                return True
3✔
129
        except ValueError:
×
130
            pass
×
131

132
        name = operand
3✔
133
        if "(" in operand and operand.endswith(")"):
3✔
134
            name = operand.split("(")[0]
×
135

136
        if name in self.symbol_table:
3✔
137
            t = self.symbol_table[name]
3✔
138

139
            def is_int_ptype(pt):
3✔
140
                return pt in [
3✔
141
                    PrimitiveType.Int64,
142
                    PrimitiveType.Int32,
143
                    PrimitiveType.Int8,
144
                    PrimitiveType.Int16,
145
                    PrimitiveType.UInt64,
146
                    PrimitiveType.UInt32,
147
                    PrimitiveType.UInt8,
148
                    PrimitiveType.UInt16,
149
                ]
150

151
            if isinstance(t, Scalar):
3✔
152
                return is_int_ptype(t.primitive_type)
3✔
153

154
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
×
155
                et = t.element_type
×
156
                if callable(et):
×
157
                    et = et()
×
158
                if isinstance(et, Scalar):
×
159
                    return is_int_ptype(et.primitive_type)
×
160

161
            if type(t).__name__ == "Pointer":
×
162
                if hasattr(t, "pointee_type"):
×
163
                    et = t.pointee_type
×
164
                    if callable(et):
×
165
                        et = et()
×
166
                    if isinstance(et, Scalar):
×
167
                        return is_int_ptype(et.primitive_type)
×
168
                # Fallback: check if it has element_type (maybe alias?)
169
                if hasattr(t, "element_type"):
×
170
                    et = t.element_type
×
171
                    if callable(et):
×
172
                        et = et()
×
173
                    if isinstance(et, Scalar):
×
174
                        return is_int_ptype(et.primitive_type)
×
175

176
        return False
3✔
177

178
    def _add_read(self, block, expr_str, debug_info=None):
3✔
179
        # Try to reuse access node
180
        try:
3✔
181
            if (block, expr_str) in self._access_cache:
3✔
182
                return self._access_cache[(block, expr_str)]
3✔
183
        except TypeError:
×
184
            # block might not be hashable
185
            pass
×
186

187
        if debug_info is None:
3✔
188
            debug_info = DebugInfo()
3✔
189

190
        if "(" in expr_str and expr_str.endswith(")"):
3✔
191
            name = expr_str.split("(")[0]
3✔
192
            subset = expr_str[expr_str.find("(") + 1 : -1]
3✔
193
            access = self.builder.add_access(block, name, debug_info)
3✔
194
            try:
3✔
195
                self._access_cache[(block, expr_str)] = (access, subset)
3✔
196
            except TypeError:
×
197
                pass
×
198
            return access, subset
3✔
199

200
        if self.builder.has_container(expr_str):
3✔
201
            access = self.builder.add_access(block, expr_str, debug_info)
3✔
202
            try:
3✔
203
                self._access_cache[(block, expr_str)] = (access, "")
3✔
204
            except TypeError:
×
205
                pass
×
206
            return access, ""
3✔
207

208
        dtype = Scalar(PrimitiveType.Double)
3✔
209
        if self._is_int(expr_str):
3✔
210
            dtype = Scalar(PrimitiveType.Int64)
3✔
211
        elif expr_str == "true" or expr_str == "false":
3✔
212
            dtype = Scalar(PrimitiveType.Bool)
×
213

214
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
3✔
215
        try:
3✔
216
            self._access_cache[(block, expr_str)] = (const_node, "")
3✔
217
        except TypeError:
×
218
            pass
×
219
        return const_node, ""
3✔
220

221
    def _handle_min_max(self, node, func_name):
3✔
222
        args = [self.visit(arg) for arg in node.args]
3✔
223
        if len(args) != 2:
3✔
224
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")
×
225

226
        # Check types
227
        is_float = False
3✔
228
        arg_types = []
3✔
229

230
        for arg in args:
3✔
231
            name = arg
3✔
232
            if "(" in arg and arg.endswith(")"):
3✔
233
                name = arg.split("(")[0]
×
234

235
            if name in self.symbol_table:
3✔
236
                t = self.symbol_table[name]
3✔
237
                if isinstance(t, Pointer):
3✔
238
                    t = t.base_type
×
239

240
                if t.primitive_type == PrimitiveType.Double:
3✔
241
                    is_float = True
3✔
242
                    arg_types.append(PrimitiveType.Double)
3✔
243
                else:
244
                    arg_types.append(PrimitiveType.Int64)
3✔
245
            elif self._is_int(arg):
×
246
                arg_types.append(PrimitiveType.Int64)
×
247
            else:
248
                # Assume float constant
249
                is_float = True
×
250
                arg_types.append(PrimitiveType.Double)
×
251

252
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)
3✔
253

254
        tmp_name = self._get_temp_name("_tmp_")
3✔
255
        self.builder.add_container(tmp_name, dtype, False)
3✔
256
        self.symbol_table[tmp_name] = dtype
3✔
257

258
        if is_float:
3✔
259
            # Cast args if necessary
260
            casted_args = []
3✔
261
            for i, arg in enumerate(args):
3✔
262
                if arg_types[i] != PrimitiveType.Double:
3✔
263
                    # Create temp double
264
                    tmp_cast = self._get_temp_name("_cast_")
3✔
265
                    self.builder.add_container(
3✔
266
                        tmp_cast, Scalar(PrimitiveType.Double), False
267
                    )
268
                    self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
3✔
269

270
                    # Assign int to double (implicit cast)
271
                    self.builder.add_assignment(tmp_cast, arg)
3✔
272
                    casted_args.append(tmp_cast)
3✔
273
                else:
274
                    casted_args.append(arg)
3✔
275

276
            block = self.builder.add_block()
3✔
277
            t_out = self.builder.add_access(block, tmp_name)
3✔
278

279
            intrinsic_name = "fmax" if func_name == "max" else "fmin"
3✔
280
            t_task = self.builder.add_intrinsic(block, intrinsic_name)
3✔
281

282
            for i, arg in enumerate(casted_args):
3✔
283
                t_arg, arg_sub = self._add_read(block, arg)
3✔
284
                self.builder.add_memlet(
3✔
285
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
286
                )
287
        else:
288
            block = self.builder.add_block()
3✔
289
            t_out = self.builder.add_access(block, tmp_name)
3✔
290

291
            # Use int_smax/int_smin tasklet
292
            opcode = "int_smax" if func_name == "max" else "int_smin"
3✔
293
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])
3✔
294

295
            for i, arg in enumerate(args):
3✔
296
                t_arg, arg_sub = self._add_read(block, arg)
3✔
297
                self.builder.add_memlet(
3✔
298
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
299
                )
300

301
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
302
        return tmp_name
3✔
303

304
    def _handle_python_cast(self, node, func_name):
3✔
305
        """Handle Python type casts: int(), float(), bool()"""
306
        if len(node.args) != 1:
3✔
307
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")
×
308

309
        arg = self.visit(node.args[0])
3✔
310

311
        # Determine target type based on cast function
312
        if func_name == "int":
3✔
313
            target_dtype = Scalar(PrimitiveType.Int64)
3✔
314
        elif func_name == "float":
3✔
315
            target_dtype = Scalar(PrimitiveType.Double)
3✔
316
        elif func_name == "bool":
3✔
317
            target_dtype = Scalar(PrimitiveType.Bool)
3✔
318
        else:
319
            raise NotImplementedError(f"Cast to {func_name} not supported")
×
320

321
        # Determine source type
322
        source_dtype = None
3✔
323
        name = arg
3✔
324
        if "(" in arg and arg.endswith(")"):
3✔
325
            name = arg.split("(")[0]
×
326

327
        if name in self.symbol_table:
3✔
328
            source_dtype = self.symbol_table[name]
3✔
329
            if isinstance(source_dtype, Pointer):
3✔
330
                source_dtype = source_dtype.base_type
×
331
        elif self._is_int(arg):
×
332
            source_dtype = Scalar(PrimitiveType.Int64)
×
333
        elif arg == "true" or arg == "false":
×
334
            source_dtype = Scalar(PrimitiveType.Bool)
×
335
        else:
336
            # Assume float constant
337
            source_dtype = Scalar(PrimitiveType.Double)
×
338

339
        # Create temporary variable for result
340
        tmp_name = self._get_temp_name("_tmp_")
3✔
341
        self.builder.add_container(tmp_name, target_dtype, False)
3✔
342
        self.symbol_table[tmp_name] = target_dtype
3✔
343

344
        # Use tasklet assign opcode for casting (as specified in problem statement)
345
        block = self.builder.add_block()
3✔
346
        t_src, src_sub = self._add_read(block, arg)
3✔
347
        t_dst = self.builder.add_access(block, tmp_name)
3✔
348
        t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
3✔
349
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
3✔
350
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
351

352
        return tmp_name
3✔
353

354
    def visit_Call(self, node):
3✔
355
        func_name = ""
3✔
356
        module_name = ""
3✔
357
        if isinstance(node.func, ast.Attribute):
3✔
358
            if isinstance(node.func.value, ast.Name):
3✔
359
                if node.func.value.id == "math":
3✔
360
                    module_name = "math"
3✔
361
                    func_name = node.func.attr
3✔
362
                elif node.func.value.id in ["numpy", "np"]:
3✔
363
                    module_name = "numpy"
3✔
364
                    func_name = node.func.attr
3✔
365
                else:
366
                    # Check if it's a method call on an array (e.g., arr.astype(...))
367
                    array_name = node.func.value.id
3✔
368
                    method_name = node.func.attr
3✔
369
                    if array_name in self.array_info and method_name == "astype":
3✔
370
                        return self._handle_numpy_astype(node, array_name)
3✔
371
            elif isinstance(node.func.value, ast.Attribute):
3✔
372
                if (
3✔
373
                    isinstance(node.func.value.value, ast.Name)
374
                    and node.func.value.value.id == "scipy"
375
                    and node.func.value.attr == "special"
376
                ):
377
                    if node.func.attr == "softmax":
3✔
378
                        return self._handle_scipy_softmax(node, "softmax")
3✔
379

380
        elif isinstance(node.func, ast.Name):
3✔
381
            func_name = node.func.id
3✔
382

383
        if module_name == "numpy":
3✔
384
            if func_name in self.numpy_handlers:
3✔
385
                return self.numpy_handlers[func_name](node, func_name)
3✔
386

387
        if func_name in ["max", "min"]:
3✔
388
            return self._handle_min_max(node, func_name)
3✔
389

390
        # Handle Python type casts (int, float, bool)
391
        if func_name in ["int", "float", "bool"]:
3✔
392
            return self._handle_python_cast(node, func_name)
3✔
393

394
        math_funcs = [
3✔
395
            "sin",
396
            "cos",
397
            "tan",
398
            "exp",
399
            "log",
400
            "sqrt",
401
            "pow",
402
            "abs",
403
            "ceil",
404
            "floor",
405
            "asin",
406
            "acos",
407
            "atan",
408
            "sinh",
409
            "cosh",
410
            "tanh",
411
        ]
412

413
        if func_name in math_funcs:
3✔
414
            args = [self.visit(arg) for arg in node.args]
3✔
415

416
            tmp_name = self._get_temp_name("_tmp_")
3✔
417
            dtype = Scalar(PrimitiveType.Double)
3✔
418
            self.builder.add_container(tmp_name, dtype, False)
3✔
419
            self.symbol_table[tmp_name] = dtype
3✔
420

421
            block = self.builder.add_block()
3✔
422
            t_out = self.builder.add_access(block, tmp_name)
3✔
423

424
            t_task = self.builder.add_intrinsic(block, func_name)
3✔
425

426
            for i, arg in enumerate(args):
3✔
427
                t_arg, arg_sub = self._add_read(block, arg)
3✔
428
                self.builder.add_memlet(
3✔
429
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
430
                )
431

432
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
433
            return tmp_name
3✔
434

435
        if func_name in self.globals_dict:
3✔
436
            obj = self.globals_dict[func_name]
3✔
437
            if inspect.isfunction(obj):
3✔
438
                return self._handle_inline_call(node, obj)
3✔
439

440
        raise NotImplementedError(f"Function call {func_name} not supported")
×
441

442
    def _handle_inline_call(self, node, func_obj):
3✔
443
        # 1. Parse function source
444
        try:
3✔
445
            source_lines, start_line = inspect.getsourcelines(func_obj)
3✔
446
            source = textwrap.dedent("".join(source_lines))
3✔
447
            tree = ast.parse(source)
3✔
448
            func_def = tree.body[0]
3✔
449
        except Exception as e:
×
450
            raise NotImplementedError(
×
451
                f"Could not parse function {func_obj.__name__}: {e}"
452
            )
453

454
        # 2. Evaluate arguments
455
        arg_vars = [self.visit(arg) for arg in node.args]
3✔
456

457
        if len(arg_vars) != len(func_def.args.args):
3✔
458
            raise NotImplementedError(
×
459
                f"Argument count mismatch for {func_obj.__name__}"
460
            )
461

462
        # 3. Generate unique suffix
463
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
3✔
464
        res_name = f"_res{suffix}"
3✔
465

466
        # Assume Int64 for now as match returns 0/1
467
        dtype = Scalar(PrimitiveType.Int64)
3✔
468
        self.builder.add_container(res_name, dtype, False)
3✔
469
        self.symbol_table[res_name] = dtype
3✔
470

471
        # 4. Rename variables
472
        class VariableRenamer(ast.NodeTransformer):
3✔
473
            def __init__(self, suffix, globals_dict):
3✔
474
                self.suffix = suffix
3✔
475
                self.globals_dict = globals_dict
3✔
476

477
            def visit_Name(self, node):
3✔
478
                if node.id in self.globals_dict:
3✔
479
                    return node
3✔
480
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
3✔
481

482
            def visit_Return(self, node):
3✔
483
                if node.value:
3✔
484
                    val = self.visit(node.value)
3✔
485
                    return ast.Assign(
3✔
486
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
487
                        value=val,
488
                    )
489
                return node
×
490

491
        renamer = VariableRenamer(suffix, self.globals_dict)
3✔
492
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
3✔
493

494
        # 5. Assign arguments to parameters
495
        param_assignments = []
3✔
496
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
3✔
497
            param_name = f"{arg_def.arg}{suffix}"
3✔
498

499
            # Infer type and create container
500
            if arg_val in self.symbol_table:
3✔
501
                self.symbol_table[param_name] = self.symbol_table[arg_val]
3✔
502
                self.builder.add_container(
3✔
503
                    param_name, self.symbol_table[arg_val], False
504
                )
505
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
3✔
506
            elif self._is_int(arg_val):
×
507
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
508
                self.builder.add_container(
×
509
                    param_name, Scalar(PrimitiveType.Int64), False
510
                )
511
                val_node = ast.Constant(value=int(arg_val))
×
512
            else:
513
                # Assume float constant
514
                try:
×
515
                    val = float(arg_val)
×
516
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
517
                    self.builder.add_container(
×
518
                        param_name, Scalar(PrimitiveType.Double), False
519
                    )
520
                    val_node = ast.Constant(value=val)
×
521
                except ValueError:
×
522
                    # Fallback to Name, might fail later if not in symbol table
523
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
524

525
            assign = ast.Assign(
3✔
526
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
527
            )
528
            param_assignments.append(assign)
3✔
529

530
        final_body = param_assignments + new_body
3✔
531

532
        # 6. Visit new body using ASTParser
533
        from .ast_parser import ASTParser
3✔
534

535
        parser = ASTParser(
3✔
536
            self.builder,
537
            self.array_info,
538
            self.symbol_table,
539
            globals_dict=self.globals_dict,
540
            unique_counter_ref=self._unique_counter_ref,
541
        )
542

543
        for stmt in final_body:
3✔
544
            parser.visit(stmt)
3✔
545

546
        return res_name
3✔
547

548
    def visit_BinOp(self, node):
3✔
549
        if isinstance(node.op, ast.MatMult):
3✔
550
            return self._handle_numpy_matmul_op(node.left, node.right)
3✔
551

552
        left = self.visit(node.left)
3✔
553
        right = self.visit(node.right)
3✔
554
        op = self.visit(node.op)
3✔
555

556
        # Check if left or right are arrays
557
        left_is_array = left in self.array_info
3✔
558
        right_is_array = right in self.array_info
3✔
559

560
        if left_is_array or right_is_array:
3✔
561
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
3✔
562
            if op in op_map:
3✔
563
                return self._handle_array_binary_op(op_map[op], left, right)
3✔
564
            else:
565
                raise NotImplementedError(f"Array operation {op} not supported")
×
566

567
        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
568

569
        dtype = Scalar(PrimitiveType.Double)  # Default
3✔
570

571
        left_is_int = self._is_int(left)
3✔
572
        right_is_int = self._is_int(right)
3✔
573

574
        if left_is_int and right_is_int and op not in ["/", "**"]:
3✔
575
            dtype = Scalar(PrimitiveType.Int64)
3✔
576

577
        self.builder.add_container(tmp_name, dtype, False)
3✔
578
        self.symbol_table[tmp_name] = dtype
3✔
579

580
        real_left = left
3✔
581
        real_right = right
3✔
582

583
        if dtype.primitive_type == PrimitiveType.Double:
3✔
584
            if left_is_int:
3✔
585
                left_cast = f"_tmp_{self._get_unique_id()}"
3✔
586
                self.builder.add_container(
3✔
587
                    left_cast, Scalar(PrimitiveType.Double), False
588
                )
589
                self.symbol_table[left_cast] = Scalar(PrimitiveType.Double)
3✔
590

591
                c_block = self.builder.add_block()
3✔
592
                t_src, src_sub = self._add_read(c_block, left)
3✔
593
                t_dst = self.builder.add_access(c_block, left_cast)
3✔
594
                t_task = self.builder.add_tasklet(c_block, "assign", ["_in"], ["_out"])
3✔
595
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
3✔
596
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
3✔
597

598
                real_left = left_cast
3✔
599

600
            if right_is_int:
3✔
601
                right_cast = f"_tmp_{self._get_unique_id()}"
3✔
602
                self.builder.add_container(
3✔
603
                    right_cast, Scalar(PrimitiveType.Double), False
604
                )
605
                self.symbol_table[right_cast] = Scalar(PrimitiveType.Double)
3✔
606

607
                c_block = self.builder.add_block()
3✔
608
                t_src, src_sub = self._add_read(c_block, right)
3✔
609
                t_dst = self.builder.add_access(c_block, right_cast)
3✔
610
                t_task = self.builder.add_tasklet(c_block, "assign", ["_in"], ["_out"])
3✔
611
                self.builder.add_memlet(c_block, t_src, "void", t_task, "_in", src_sub)
3✔
612
                self.builder.add_memlet(c_block, t_task, "_out", t_dst, "void", "")
3✔
613

614
                real_right = right_cast
3✔
615

616
        # Special cases
617
        if op == "**":
3✔
618
            block = self.builder.add_block()
3✔
619
            t_left, left_sub = self._add_read(block, real_left)
3✔
620
            t_right, right_sub = self._add_read(block, real_right)
3✔
621
            t_out = self.builder.add_access(block, tmp_name)
3✔
622

623
            t_task = self.builder.add_intrinsic(block, "pow")
3✔
624
            self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
3✔
625
            self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
3✔
626
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
627

628
            return tmp_name
3✔
629
        elif op == "%":
3✔
630
            block = self.builder.add_block()
3✔
631
            t_left, left_sub = self._add_read(block, real_left)
3✔
632
            t_right, right_sub = self._add_read(block, real_right)
3✔
633
            t_out = self.builder.add_access(block, tmp_name)
3✔
634

635
            if dtype.primitive_type == PrimitiveType.Int64:
3✔
636
                # Implement ((a % b) + b) % b to match Python's modulo behavior
637

638
                # 1. rem1 = a % b
639
                t_rem1 = self.builder.add_tasklet(
3✔
640
                    block, "int_rem", ["_in1", "_in2"], ["_out"]
641
                )
642
                self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
3✔
643
                self.builder.add_memlet(
3✔
644
                    block, t_right, "void", t_rem1, "_in2", right_sub
645
                )
646

647
                rem1_name = f"_tmp_{self._get_unique_id()}"
3✔
648
                self.builder.add_container(rem1_name, dtype, False)
3✔
649
                t_rem1_out = self.builder.add_access(block, rem1_name)
3✔
650
                self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")
3✔
651

652
                # 2. add = rem1 + b
653
                t_add = self.builder.add_tasklet(
3✔
654
                    block, "int_add", ["_in1", "_in2"], ["_out"]
655
                )
656
                self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
3✔
657
                self.builder.add_memlet(
3✔
658
                    block, t_right, "void", t_add, "_in2", right_sub
659
                )
660

661
                add_name = f"_tmp_{self._get_unique_id()}"
3✔
662
                self.builder.add_container(add_name, dtype, False)
3✔
663
                t_add_out = self.builder.add_access(block, add_name)
3✔
664
                self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")
3✔
665

666
                # 3. res = add % b
667
                t_rem2 = self.builder.add_tasklet(
3✔
668
                    block, "int_rem", ["_in1", "_in2"], ["_out"]
669
                )
670
                self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
3✔
671
                self.builder.add_memlet(
3✔
672
                    block, t_right, "void", t_rem2, "_in2", right_sub
673
                )
674
                self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")
3✔
675

676
                return tmp_name
3✔
677
            else:
678
                t_task = self.builder.add_intrinsic(block, "fmod")
3✔
679
                self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
3✔
680
                self.builder.add_memlet(
3✔
681
                    block, t_right, "void", t_task, "_in2", right_sub
682
                )
683
                self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
684
                return tmp_name
3✔
685

686
        prefix = "int" if dtype.primitive_type == PrimitiveType.Int64 else "fp"
3✔
687
        op_name = ""
3✔
688
        if op == "+":
3✔
689
            op_name = "add"
3✔
690
        elif op == "-":
3✔
691
            op_name = "sub"
3✔
692
        elif op == "*":
3✔
693
            op_name = "mul"
3✔
694
        elif op == "/":
3✔
695
            op_name = "div"
3✔
696
        elif op == "//":
3✔
697
            op_name = "div"
3✔
698
        elif op == "|":
3✔
699
            op_name = "or"
3✔
700
        elif op == "^":
3✔
701
            op_name = "xor"
3✔
702

703
        block = self.builder.add_block()
3✔
704
        t_left, left_sub = self._add_read(block, real_left)
3✔
705
        t_right, right_sub = self._add_read(block, real_right)
3✔
706
        t_out = self.builder.add_access(block, tmp_name)
3✔
707

708
        tasklet_code = f"{prefix}_{op_name}"
3✔
709
        t_task = self.builder.add_tasklet(
3✔
710
            block, tasklet_code, ["_in1", "_in2"], ["_out"]
711
        )
712

713
        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
3✔
714
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
3✔
715
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
716

717
        return tmp_name
3✔
718

719
    def _add_assign_constant(self, target_name, value_str, dtype):
3✔
720
        block = self.builder.add_block()
3✔
721
        t_const = self.builder.add_constant(block, value_str, dtype)
3✔
722
        t_dst = self.builder.add_access(block, target_name)
3✔
723
        t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
3✔
724
        self.builder.add_memlet(block, t_const, "void", t_task, "_in", "")
3✔
725
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
726

727
    def visit_BoolOp(self, node):
        """Lower ``a and b`` / ``a or b`` into a fresh boolean temporary.

        Every operand is visited and normalised with ``!= 0``. The normalised
        operands are joined with the symbol produced by visiting ``node.op``.
        The resulting condition drives an if/else that writes the constant
        ``true`` or ``false`` into a new ``Bool`` container. Its name is
        registered in ``symbol_table`` and returned.
        """
        joiner = f" {self.visit(node.op)} "
        conditions = []
        for operand in node.values:
            conditions.append(f"({self.visit(operand)} != 0)")
        condition = joiner.join(conditions)

        result = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result, bool_type, False)

        # Materialise the condition's outcome through control flow.
        builder = self.builder
        builder.begin_if(condition)
        self._add_assign_constant(result, "true", bool_type)
        builder.begin_else()
        self._add_assign_constant(result, "false", bool_type)
        builder.end_if()

        self.symbol_table[result] = bool_type
        return result
745

746
    def visit_Compare(self, node):
        """Lower a (possibly chained) comparison into a fresh boolean temporary.

        A single comparison ``a < b`` becomes the condition string ``a < b``.
        A chained comparison such as ``a < b <= c`` becomes
        ``(a < b) & (b <= c)``. Every operand is visited exactly once, so the
        shared middle operand is evaluated once, as in Python. The operands
        are evaluated eagerly (no short-circuit).

        The condition drives an if/else that assigns ``true`` or ``false`` to a
        new ``Bool`` container. Its name is registered in ``symbol_table`` and
        returned.
        """
        left = self.visit(node.left)

        if len(node.ops) == 1:
            op = self.visit(node.ops[0])
            right = self.visit(node.comparators[0])
            expr_str = f"{left} {op} {right}"
        else:
            parts = []
            for op_node, comparator in zip(node.ops, node.comparators):
                op = self.visit(op_node)
                right = self.visit(comparator)
                parts.append(f"({left} {op} {right})")
                # The right operand of this link is the left operand of the next.
                left = right
            # "&" is the conjunction this visitor emits for ``and`` (visit_And).
            expr_str = " & ".join(parts)

        tmp_name = f"_tmp_{self._get_unique_id()}"
        dtype = Scalar(PrimitiveType.Bool)
        self.builder.add_container(tmp_name, dtype, False)

        # Use control flow to assign boolean value
        self.builder.begin_if(expr_str)
        self.builder.add_assignment(tmp_name, "true")
        self.builder.begin_else()
        self.builder.add_assignment(tmp_name, "false")
        self.builder.end_if()

        self.symbol_table[tmp_name] = dtype
        return tmp_name
768

769
    def visit_UnaryOp(self, node):
        """Lower a unary operator on a scalar operand into a fresh temporary.

        Supported operators:
          * ``not x`` -- ``x ^ true`` via the ``int_xor`` tasklet.
          * ``-x``    -- ``0 - x`` via ``int_sub`` for integer types,
                         ``fp_neg`` for every other type.
          * ``~x``    -- ``x ^ -1`` via ``int_xor`` (signed integers only).
          * ``+x``    -- plain copy via ``assign``.

        The temporary's type is chosen as follows:
          * the operand's ``symbol_table`` type when it has one;
          * ``Int64`` when ``_is_int`` accepts the operand;
          * ``Bool`` for ``not`` on an otherwise untyped operand;
          * ``Double`` otherwise.

        Raises:
            NotImplementedError: ``~`` applied to a non-signed-integer operand.
        """
        signed_int_types = (
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.Int32,
            PrimitiveType.Int64,
        )
        unsigned_int_types = (
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
            PrimitiveType.UInt32,
            PrimitiveType.UInt64,
        )

        op = self.visit(node.op)
        operand = self.visit(node.operand)

        tmp_name = f"_tmp_{self._get_unique_id()}"
        dtype = Scalar(PrimitiveType.Double)
        if operand in self.symbol_table:
            dtype = self.symbol_table[operand]
        elif self._is_int(operand):
            dtype = Scalar(PrimitiveType.Int64)
        elif isinstance(node.op, ast.Not):
            dtype = Scalar(PrimitiveType.Bool)

        prim = getattr(dtype, "primitive_type", None)
        is_signed_int = prim in signed_int_types
        is_int = is_signed_int or prim in unsigned_int_types

        # Reject before creating any container so no dangling state is left.
        # Previously ``~x`` silently compiled to a plain copy of ``x``.
        if op == "~" and not is_signed_int:
            raise NotImplementedError(
                "Bitwise inversion (~) is only supported for signed integer operands"
            )

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)

        if isinstance(node.op, ast.Not):
            t_const = self.builder.add_constant(
                block, "true", Scalar(PrimitiveType.Bool)
            )
            t_task = self.builder.add_tasklet(
                block, "int_xor", ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif op == "-":
            if is_int:
                # Integer negation as 0 - x; fp_neg is for floating types only.
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, "int_sub", ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
            else:
                t_task = self.builder.add_tasklet(block, "fp_neg", ["_in"], ["_out"])
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif op == "~":
            # Two's complement identity: ~x == x ^ -1.
            t_const = self.builder.add_constant(block, "-1", dtype)
            t_task = self.builder.add_tasklet(
                block, "int_xor", ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        else:
            # Unary plus: identity copy.
            t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
819

820
    def _parse_array_arg(self, node, simple_visitor):
3✔
821
        if isinstance(node, ast.Name):
×
822
            if node.id in self.array_info:
×
823
                return node.id, [], self.array_info[node.id]["shapes"]
×
824
        elif isinstance(node, ast.Subscript):
×
825
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
×
826
                name = node.value.id
×
827
                ndim = self.array_info[name]["ndim"]
×
828

829
                indices = []
×
830
                if isinstance(node.slice, ast.Tuple):
×
831
                    indices = list(node.slice.elts)
×
832
                else:
833
                    indices = [node.slice]
×
834

835
                while len(indices) < ndim:
×
836
                    indices.append(ast.Slice(lower=None, upper=None, step=None))
×
837

838
                start_indices = []
×
839
                slice_shape = []
×
840

841
                for i, idx in enumerate(indices):
×
842
                    if isinstance(idx, ast.Slice):
×
843
                        start = "0"
×
844
                        if idx.lower:
×
845
                            start = simple_visitor.visit(idx.lower)
×
846
                        start_indices.append(start)
×
847

848
                        shapes = self.array_info[name]["shapes"]
×
849
                        dim_size = (
×
850
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
851
                        )
852
                        stop = dim_size
×
853
                        if idx.upper:
×
854
                            stop = simple_visitor.visit(idx.upper)
×
855

856
                        size = f"({stop} - {start})"
×
857
                        slice_shape.append(size)
×
858
                    else:
859
                        val = simple_visitor.visit(idx)
×
860
                        start_indices.append(val)
×
861

862
                shapes = self.array_info[name]["shapes"]
×
863
                linear_index = ""
×
864
                for i in range(ndim):
×
865
                    term = start_indices[i]
×
866
                    for j in range(i + 1, ndim):
×
867
                        shape_val = shapes[j] if j < len(shapes) else None
×
868
                        shape_sym = (
×
869
                            shape_val if shape_val is not None else f"_{name}_shape_{j}"
870
                        )
871
                        term = f"({term} * {shape_sym})"
×
872

873
                    if i == 0:
×
874
                        linear_index = term
×
875
                    else:
876
                        linear_index = f"({linear_index} + {term})"
×
877

878
                return name, [linear_index], slice_shape
×
879

880
        return None, None, None
×
881

882
    def visit_Attribute(self, node):
        """Lower an attribute access ``value.attr``.

        Handled forms:
          * ``arr.shape`` on a known array returns the marker
            ``_shape_proxy_<arr>`` (consumed when the shape is subscripted).
          * ``math.pi`` / ``math.e`` return a new ``Double`` temporary
            assigned ``M_PI`` / ``M_E``.
          * ``obj.member``, where ``obj`` is a pointer to a ``Structure``,
            returns a temporary of the member's type. The temporary is filled
            through an ``assign`` tasklet, and the member is picked by the
            subset ``"0,<member_index>"``.

        Raises:
            RuntimeError: the member is not registered for the structure.
            NotImplementedError: any other attribute access.
        """
        base = node.value
        base_name = base.id if isinstance(base, ast.Name) else None

        if (
            node.attr == "shape"
            and base_name is not None
            and base_name in self.array_info
        ):
            return f"_shape_proxy_{base_name}"

        if base_name == "math":
            constant = {"pi": "M_PI", "e": "M_E"}.get(node.attr)
            if constant:
                result = f"_tmp_{self._get_unique_id()}"
                double_type = Scalar(PrimitiveType.Double)
                self.builder.add_container(result, double_type, False)
                self.symbol_table[result] = double_type
                self._add_assign_constant(result, constant, double_type)
                return result

        # Resolve the structure type behind a pointer-typed symbol, if any.
        struct_type = None
        if base_name is not None and base_name in self.symbol_table:
            obj_type = self.symbol_table[base_name]
            if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
                candidate = obj_type.pointee_type
                if isinstance(candidate, Structure):
                    struct_type = candidate

        if struct_type is None:
            raise NotImplementedError(f"Attribute access {node.attr} not supported")

        struct_name = struct_type.name
        attr_name = node.attr
        registered = (
            struct_name in self.structure_member_info
            and attr_name in self.structure_member_info[struct_name]
        )
        if not registered:
            # Only reachable if the structure was not registered properly.
            raise RuntimeError(
                f"Member '{attr_name}' not found in structure '{struct_name}'. "
                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
            )
        member_index, member_type = self.structure_member_info[struct_name][attr_name]

        result = f"_tmp_{self._get_unique_id()}"
        self.builder.add_container(result, member_type, False)
        self.symbol_table[result] = member_type

        blk = self.builder.add_block()
        src_node = self.builder.add_access(blk, base_name)
        dst_node = self.builder.add_access(blk, result)
        copy_task = self.builder.add_tasklet(blk, "assign", ["_in"], ["_out"])

        # The memlet subset (element 0, field member_index) selects the member.
        self.builder.add_memlet(
            blk, src_node, "", copy_task, "_in", "0," + str(member_index)
        )
        self.builder.add_memlet(blk, copy_task, "_out", dst_node, "")
        return result
957

958
    def visit_Subscript(self, node):
        """Lower ``value[index]`` to a symbolic expression or a loaded temporary.

        Cases:
          * ``arr.shape[k]``: ``value`` visits to ``_shape_proxy_<arr>``. The
            result is entry ``k`` of ``array_info[arr]["shapes"]``, or the
            symbol ``_<arr>_shape_<k>`` when no shape list is registered.
            ``k`` must be a compile-time integer.
          * ``arr[i, j, ...]`` on a known array: the indices are flattened to a
            row-major linear offset. In a load context with a builder, the
            element is copied into a new temporary whose name is returned.
            Otherwise the access string ``arr(<offset>)`` is returned.
          * anything else: the access string ``value(<index>)``.

        Raises:
            NotImplementedError: a non-constant shape index.
            ValueError: slice indices, or an index count different from ndim.
        """

        def _operand(expr):
            # Parenthesise compound shape expressions; without this a shape of
            # "n + 1" yields "((i) * n + 1)", which binds as "(i * n) + 1".
            text = str(expr)
            if text.isidentifier() or text.isdigit():
                return text
            return f"({text})"

        value_str = self.visit(node.value)

        if value_str.startswith("_shape_proxy_"):
            array_name = value_str[len("_shape_proxy_") :]
            # ast.Index only appears on Python < 3.9 and is deprecated, so it is
            # looked up defensively rather than referenced directly.
            index_cls = getattr(ast, "Index", None)
            if isinstance(node.slice, ast.Constant):
                idx = node.slice.value
            elif index_cls is not None and isinstance(node.slice, index_cls):
                idx = node.slice.value.value
            else:
                try:
                    idx = int(self.visit(node.slice))
                except Exception as err:
                    # Only compile-time integer indices can select a shape entry.
                    raise NotImplementedError(
                        "Dynamic shape indexing not fully supported yet"
                    ) from err

            info = self.array_info.get(array_name)
            if info is not None and "shapes" in info:
                return info["shapes"][idx]
            return f"_{array_name}_shape_{idx}"

        if value_str in self.array_info:
            info = self.array_info[value_str]
            ndim = info["ndim"]
            shapes = info.get("shapes", [])

            if isinstance(node.slice, ast.Tuple):
                index_nodes = node.slice.elts
            else:
                index_nodes = [node.slice]

            if any(isinstance(idx, ast.Slice) for idx in index_nodes):
                raise ValueError("Slices not supported in expression indexing")

            indices = [self.visit(idx) for idx in index_nodes]
            if len(indices) != ndim:
                raise ValueError(
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
                )

            # Row-major flattening: each index is scaled by the extents of all
            # trailing dimensions.
            linear_index = ""
            for i in range(ndim):
                term = indices[i]
                for j in range(i + 1, ndim):
                    shape_val = shapes[j] if j < len(shapes) else None
                    shape_sym = (
                        shape_val
                        if shape_val is not None
                        else f"_{value_str}_shape_{j}"
                    )
                    term = f"(({term}) * {_operand(shape_sym)})"
                linear_index = term if i == 0 else f"({linear_index} + {term})"

            access_str = f"{value_str}({linear_index})"
            if not (self.builder and isinstance(node.ctx, ast.Load)):
                return access_str

            # Element type from the container's type when known, else Double.
            dtype = Scalar(PrimitiveType.Double)
            if value_str in self.symbol_table:
                t = self.symbol_table[value_str]
                if type(t).__name__ == "Array" and hasattr(t, "element_type"):
                    et = t.element_type
                    if callable(et):
                        et = et()
                    dtype = et
                elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
                    et = t.pointee_type
                    if callable(et):
                        et = et()
                    dtype = et

            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)

            block = self.builder.add_block()
            t_src = self.builder.add_access(block, value_str)
            t_dst = self.builder.add_access(block, tmp_name)
            t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])

            self.builder.add_memlet(block, t_src, "void", t_task, "_in", linear_index)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

            self.symbol_table[tmp_name] = dtype
            return tmp_name

        # Not a known array: render a symbolic call-style access string.
        slice_val = self.visit(node.slice)
        return f"{value_str}({slice_val})"
1074

1075
    def visit_Add(self, node):
        """Symbol for binary ``+``."""
        return "+"

    def visit_Sub(self, node):
        """Symbol for binary ``-``."""
        return "-"

    def visit_Mult(self, node):
        """Symbol for ``*``."""
        return "*"

    def visit_Div(self, node):
        """Symbol for true division ``/``."""
        return "/"

    def visit_FloorDiv(self, node):
        """Symbol for floor division ``//``."""
        return "//"

    def visit_Mod(self, node):
        """Symbol for ``%``."""
        return "%"

    def visit_Pow(self, node):
        """Symbol for ``**``."""
        return "**"

    def visit_Eq(self, node):
        """Symbol for ``==``."""
        return "=="

    def visit_NotEq(self, node):
        """Symbol for ``!=``."""
        return "!="

    def visit_Lt(self, node):
        """Symbol for ``<``."""
        return "<"

    def visit_LtE(self, node):
        """Symbol for ``<=``."""
        return "<="

    def visit_Gt(self, node):
        """Symbol for ``>``."""
        return ">"

    def visit_GtE(self, node):
        """Symbol for ``>=``."""
        return ">="

    def visit_And(self, node):
        """Symbol emitted for logical ``and`` (rendered as ``&``)."""
        return "&"

    def visit_Or(self, node):
        """Symbol emitted for logical ``or`` (rendered as ``|``)."""
        return "|"

    def visit_BitAnd(self, node):
        """Symbol for bitwise ``&``."""
        return "&"

    def visit_BitOr(self, node):
        """Symbol for bitwise ``|``."""
        return "|"

    def visit_BitXor(self, node):
        """Symbol for bitwise ``^``."""
        return "^"

    def visit_Not(self, node):
        """Symbol for logical ``not`` (rendered as ``!``)."""
        return "!"

    def visit_USub(self, node):
        """Symbol for unary minus."""
        return "-"

    def visit_UAdd(self, node):
        """Symbol for unary plus."""
        return "+"

    def visit_Invert(self, node):
        """Symbol for bitwise inversion ``~``."""
        return "~"
1140

1141
    def _get_dtype(self, name):
        """Return the scalar element type associated with ``name``.

        When ``name`` is in ``symbol_table``:
          * a ``Scalar`` symbol type is returned as-is;
          * otherwise its ``pointee_type``, then its ``element_type``, is tried.
            Each is called first if callable and returned if it is a ``Scalar``.

        Anything unresolved falls back to ``Int64`` when ``_is_int(name)``
        holds, and to ``Double`` otherwise.
        """
        if name in self.symbol_table:
            sym_type = self.symbol_table[name]
            if isinstance(sym_type, Scalar):
                return sym_type
            for attr in ("pointee_type", "element_type"):
                if not hasattr(sym_type, attr):
                    continue
                inner = getattr(sym_type, attr)
                if callable(inner):
                    inner = inner()
                if isinstance(inner, Scalar):
                    return inner

        if self._is_int(name):
            return Scalar(PrimitiveType.Int64)
        return Scalar(PrimitiveType.Double)
1165

1166
    def _create_array_temp(self, shape, dtype, zero_init=False, ones_init=False):
        """Allocate a temporary array and register it with the visitor.

        Args:
            shape: per-dimension extents (symbolic strings or ints).
            dtype: element type. It is used for ``get_sizeof`` and as the
                pointee of the new pointer container.
            zero_init: if true, ``memset`` the buffer to 0 after ``malloc``.
            ones_init: if true and ``zero_init`` is false, write one into every
                element in a loop: ``1`` for integer types, ``1.0`` otherwise.

        Returns:
            Name of the new pointer container. Its type is recorded in
            ``symbol_table`` and its ndim/shape in ``array_info``.
        """
        # Copy so the recorded shape never aliases a caller-owned list; callers
        # may pass another array's ``array_info[...]["shapes"]`` directly.
        shape = list(shape)

        tmp_name = f"_tmp_{self._get_unique_id()}"

        # Byte size = product of extents * sizeof(element).
        size_str = "1"
        for dim in shape:
            size_str = f"({size_str} * {dim})"

        element_size = self.builder.get_sizeof(dtype)
        total_size = f"({size_str} * {element_size})"

        # Create pointer
        ptr_type = Pointer(dtype)
        self.builder.add_container(tmp_name, ptr_type, False)
        self.symbol_table[tmp_name] = ptr_type
        self.array_info[tmp_name] = {"ndim": len(shape), "shapes": shape}

        # Malloc
        block1 = self.builder.add_block()
        t_malloc = self.builder.add_malloc(block1, total_size)
        t_ptr1 = self.builder.add_access(block1, tmp_name)
        self.builder.add_memlet(block1, t_malloc, "_ret", t_ptr1, "void", "", ptr_type)

        if zero_init:
            block2 = self.builder.add_block()
            t_memset = self.builder.add_memset(block2, "0", total_size)
            t_ptr2 = self.builder.add_access(block2, tmp_name)
            self.builder.add_memlet(
                block2, t_memset, "_ptr", t_ptr2, "void", "", ptr_type
            )
        elif ones_init:
            # Initialize array with ones using a loop over the flat buffer
            loop_var = f"_i_{self._get_unique_id()}"
            if not self.builder.has_container(loop_var):
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
                self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

            self.builder.begin_for(loop_var, "0", size_str, "1")

            # Integer element types get an integer literal.
            val = "1.0"
            if dtype.primitive_type in [
                PrimitiveType.Int64,
                PrimitiveType.Int32,
                PrimitiveType.Int8,
                PrimitiveType.Int16,
                PrimitiveType.UInt64,
                PrimitiveType.UInt32,
                PrimitiveType.UInt8,
                PrimitiveType.UInt16,
            ]:
                val = "1"

            block_assign = self.builder.add_block()
            t_const = self.builder.add_constant(block_assign, val, dtype)
            t_arr = self.builder.add_access(block_assign, tmp_name)

            t_task = self.builder.add_tasklet(block_assign, "assign", ["_in"], ["_out"])
            self.builder.add_memlet(
                block_assign, t_const, "void", t_task, "_in", "", dtype
            )
            self.builder.add_memlet(
                block_assign, t_task, "_out", t_arr, "void", loop_var
            )

            self.builder.end_for()

        return tmp_name
1234

1235
    def _handle_array_unary_op(self, op_type, operand):
3✔
1236
        # Determine output shape
1237
        shape = []
3✔
1238
        if operand in self.array_info:
3✔
1239
            shape = self.array_info[operand]["shapes"]
3✔
1240

1241
        # Determine dtype
1242
        dtype = self._get_dtype(operand)
3✔
1243

1244
        tmp_name = self._create_array_temp(shape, dtype)
3✔
1245

1246
        # Add operation
1247
        self.builder.add_elementwise_unary_op(op_type, operand, tmp_name, shape)
3✔
1248

1249
        return tmp_name
3✔
1250

1251
    def _handle_array_binary_op(self, op_type, left, right):
3✔
1252
        # Determine output shape
1253
        shape = []
3✔
1254
        if left in self.array_info:
3✔
1255
            shape = self.array_info[left]["shapes"]
3✔
1256
        elif right in self.array_info:
×
1257
            shape = self.array_info[right]["shapes"]
×
1258

1259
        # Determine dtype
1260
        dtype_left = self._get_dtype(left)
3✔
1261
        dtype_right = self._get_dtype(right)
3✔
1262

1263
        assert dtype_left.primitive_type == dtype_right.primitive_type
3✔
1264
        dtype = dtype_left
3✔
1265

1266
        tmp_name = self._create_array_temp(shape, dtype)
3✔
1267

1268
        # Add operation
1269
        self.builder.add_elementwise_op(op_type, left, right, tmp_name, shape)
3✔
1270

1271
        return tmp_name
3✔
1272

1273
    def _handle_numpy_alloc(self, node, func_name):
3✔
1274
        # Parse shape
1275
        shape_arg = node.args[0]
3✔
1276
        dims = []
3✔
1277
        if isinstance(shape_arg, ast.Tuple):
3✔
1278
            dims = [self.visit(elt) for elt in shape_arg.elts]
3✔
1279
        elif isinstance(shape_arg, ast.List):
3✔
1280
            dims = [self.visit(elt) for elt in shape_arg.elts]
×
1281
        else:
1282
            val = self.visit(shape_arg)
3✔
1283
            if val.startswith("_shape_proxy_"):
3✔
1284
                array_name = val[len("_shape_proxy_") :]
×
1285
                if array_name in self.array_info:
×
1286
                    dims = self.array_info[array_name]["shapes"]
×
1287
                else:
1288
                    dims = [val]
×
1289
            else:
1290
                dims = [val]
3✔
1291

1292
        # Parse dtype
1293
        dtype_arg = None
3✔
1294
        if len(node.args) > 1:
3✔
1295
            dtype_arg = node.args[1]
×
1296

1297
        for kw in node.keywords:
3✔
1298
            if kw.arg == "dtype":
3✔
1299
                dtype_arg = kw.value
3✔
1300
                break
3✔
1301

1302
        element_type = self._map_numpy_dtype(dtype_arg)
3✔
1303

1304
        return self._create_array_temp(
3✔
1305
            dims,
1306
            element_type,
1307
            zero_init=(func_name == "zeros"),
1308
            ones_init=(func_name == "ones"),
1309
        )
1310

1311
    def _handle_numpy_empty_like(self, node, func_name):
3✔
1312
        prototype_arg = node.args[0]
3✔
1313
        prototype_name = self.visit(prototype_arg)
3✔
1314

1315
        # Parse shape from prototype
1316
        dims = []
3✔
1317
        if prototype_name in self.array_info:
3✔
1318
            dims = self.array_info[prototype_name]["shapes"]
3✔
1319

1320
        # Parse dtype
1321
        dtype_arg = None
3✔
1322
        if len(node.args) > 1:
3✔
NEW
1323
            dtype_arg = node.args[1]
×
1324

1325
        for kw in node.keywords:
3✔
1326
            if kw.arg == "dtype":
3✔
1327
                dtype_arg = kw.value
3✔
1328
                break
3✔
1329

1330
        element_type = None
3✔
1331
        if dtype_arg:
3✔
1332
            element_type = self._map_numpy_dtype(dtype_arg)
3✔
1333
        else:
1334
            if prototype_name in self.symbol_table:
3✔
1335
                sym_type = self.symbol_table[prototype_name]
3✔
1336
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
3✔
1337
                    element_type = sym_type.pointee_type
3✔
1338

1339
        if element_type is None:
3✔
NEW
1340
            element_type = Scalar(PrimitiveType.Double)
×
1341

1342
        return self._create_array_temp(
3✔
1343
            dims,
1344
            element_type,
1345
            zero_init=False,
1346
            ones_init=False,
1347
        )
1348

1349
    def _handle_numpy_zeros_like(self, node, func_name):
3✔
1350
        prototype_arg = node.args[0]
3✔
1351
        prototype_name = self.visit(prototype_arg)
3✔
1352

1353
        # Parse shape from prototype
1354
        dims = []
3✔
1355
        if prototype_name in self.array_info:
3✔
1356
            dims = self.array_info[prototype_name]["shapes"]
3✔
1357

1358
        # Parse dtype
1359
        dtype_arg = None
3✔
1360
        if len(node.args) > 1:
3✔
NEW
1361
            dtype_arg = node.args[1]
×
1362

1363
        for kw in node.keywords:
3✔
1364
            if kw.arg == "dtype":
3✔
1365
                dtype_arg = kw.value
3✔
1366
                break
3✔
1367

1368
        element_type = None
3✔
1369
        if dtype_arg:
3✔
1370
            element_type = self._map_numpy_dtype(dtype_arg)
3✔
1371
        else:
1372
            if prototype_name in self.symbol_table:
3✔
1373
                sym_type = self.symbol_table[prototype_name]
3✔
1374
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
3✔
1375
                    element_type = sym_type.pointee_type
3✔
1376

1377
        if element_type is None:
3✔
NEW
1378
            element_type = Scalar(PrimitiveType.Double)
×
1379

1380
        return self._create_array_temp(
3✔
1381
            dims,
1382
            element_type,
1383
            zero_init=True,
1384
            ones_init=False,
1385
        )
1386

1387
    def _handle_numpy_eye(self, node, func_name):
        """Lower ``np.eye(N, M=None, k=0, dtype=...)`` to a 2-D temporary array.

        A zero-initialised ``[N, M]`` array is allocated. ``M`` defaults to
        ``N``, including when ``None`` is passed explicitly. A loop over rows
        ``i in [0, N)`` then writes one at flat index ``i * M + i + k``
        whenever ``0 <= i + k < M``.

        ``M``, ``k`` and ``dtype`` may be given positionally (the order of
        NumPy's signature) or by keyword. Keywords are processed after the
        positional arguments.

        Returns:
            Name of the new array container.
        """
        N_str = self.visit(node.args[0])

        # ``visit`` is assumed to render a literal None as "None"; the keyword
        # path below relies on the same convention.
        M_str = N_str
        if len(node.args) > 1:
            M_str = self.visit(node.args[1])
            if M_str == "None":
                M_str = N_str

        k_str = "0"
        if len(node.args) > 2:
            k_str = self.visit(node.args[2])

        dtype_arg = node.args[3] if len(node.args) > 3 else None

        for kw in node.keywords:
            if kw.arg == "M":
                M_str = self.visit(kw.value)
                if M_str == "None":
                    M_str = N_str
            elif kw.arg == "k":
                k_str = self.visit(kw.value)
            elif kw.arg == "dtype":
                dtype_arg = kw.value

        element_type = self._map_numpy_dtype(dtype_arg)

        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)

        # Loop over rows to set the (shifted) diagonal
        loop_var = f"_i_{self._get_unique_id()}"
        if not self.builder.has_container(loop_var):
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        self.builder.begin_for(loop_var, "0", N_str, "1")

        # Condition: 0 <= i + k < M
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
        self.builder.begin_if(cond)

        # Assignment: A[i, i+k] = 1 (integer literal for integer element types)
        val = "1.0"
        if element_type.primitive_type in [
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        ]:
            val = "1"

        block_assign = self.builder.add_block()
        t_const = self.builder.add_constant(block_assign, val, element_type)
        t_arr = self.builder.add_access(block_assign, ptr_name)
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"

        t_task = self.builder.add_tasklet(block_assign, "assign", ["_in"], ["_out"])
        self.builder.add_memlet(
            block_assign, t_const, "void", t_task, "_in", "", element_type
        )
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", flat_index)

        self.builder.end_if()
        self.builder.end_for()

        return ptr_name
1460

1461
    def _handle_numpy_binary_op(self, node, func_name):
3✔
1462
        args = [self.visit(arg) for arg in node.args]
3✔
1463
        if len(args) != 2:
3✔
1464
            raise NotImplementedError(
×
1465
                f"Numpy function {func_name} requires 2 arguments"
1466
            )
1467

1468
        op_map = {
3✔
1469
            "add": "add",
1470
            "subtract": "sub",
1471
            "multiply": "mul",
1472
            "divide": "div",
1473
            "power": "pow",
1474
            "minimum": "min",
1475
            "maximum": "max",
1476
        }
1477
        return self._handle_array_binary_op(op_map[func_name], args[0], args[1])
3✔
1478

1479
    def _handle_numpy_matmul_op(self, left_node, right_node):
3✔
1480
        return self._handle_matmul_helper(left_node, right_node)
3✔
1481

1482
    def _handle_numpy_matmul(self, node, func_name):
3✔
1483
        if len(node.args) != 2:
3✔
1484
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
1485
        return self._handle_matmul_helper(node.args[0], node.args[1])
3✔
1486

1487
    def _handle_matmul_helper(self, left_node, right_node):
        """Emit a matrix/vector product ``left @ right`` via the linear-algebra handler.

        Operands that ``la_handler.parse_arg`` cannot resolve (falsy first
        element) are first visited into a named value and then re-parsed as an
        ``ast.Name``.

        Supported operand ranks and their result shapes:
            * 1-D @ 1-D -> scalar container (emitted through ``handle_dot``)
            * 2-D @ 2-D -> ``[a0, b1]``
            * 2-D @ 1-D -> ``[a0]``
            * 1-D @ 2-D -> ``[b1]``
            * N-D @ N-D with equal N > 2 -> ``a[:-2] + [a[-2], b[-1]]``, emitted
              as one GEMM per batch index inside nested loops.

        The result element type is always double; it is not inferred from the
        operands.

        Returns:
            Name of the new container that holds the result.

        Raises:
            RuntimeError: If no linear-algebra handler is attached.
            NotImplementedError: If an operand cannot be resolved, or the rank
                combination is unsupported.
        """
        if not self.la_handler:
            raise RuntimeError("LinearAlgebraHandler not initialized")

        parsed_left = self.la_handler.parse_arg(left_node)
        parsed_right = self.la_handler.parse_arg(right_node)

        # Materialise operands parse_arg could not handle, then retry.
        if not parsed_left[0]:
            left_node = ast.Name(id=self.visit(left_node))
            parsed_left = self.la_handler.parse_arg(left_node)
        if not parsed_right[0]:
            right_node = ast.Name(id=self.visit(right_node))
            parsed_right = self.la_handler.parse_arg(right_node)

        # parse_arg yields (name, subset, shape, indices); only name/shape are used.
        lhs_name, _, lhs_shape, _ = parsed_left
        rhs_name, _, rhs_shape, _ = parsed_right
        if not (lhs_name and rhs_name):
            raise NotImplementedError("Could not resolve matmul operands")

        rank_l, rank_r = len(lhs_shape), len(rhs_shape)
        scalar_result = rank_l == 1 and rank_r == 1
        if scalar_result:
            out_shape = []
        elif rank_l == 2 and rank_r == 2:
            out_shape = [lhs_shape[0], rhs_shape[1]]
        elif rank_l == 2 and rank_r == 1:
            out_shape = [lhs_shape[0]]
        elif rank_l == 1 and rank_r == 2:
            out_shape = [rhs_shape[1]]
        elif rank_l > 2 or rank_r > 2:
            if rank_l != rank_r:
                raise NotImplementedError(
                    "Broadcasting with different ranks not fully supported yet"
                )
            out_shape = list(lhs_shape[:-2]) + [lhs_shape[-2], rhs_shape[-1]]
        else:
            raise NotImplementedError(
                f"Matmul with ranks {rank_l} and {rank_r} not supported"
            )

        dtype = Scalar(PrimitiveType.Double)
        if scalar_result:
            tmp_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(tmp_name, dtype, False)
            self.symbol_table[tmp_name] = dtype
        else:
            tmp_name = self._create_array_temp(out_shape, dtype)

        if rank_l > 2 or rank_r > 2:
            # Batched case (ranks are equal here): open one loop per leading
            # batch dimension of the left operand, then emit a GEMM on the
            # trailing two dimensions of each operand/result slice.
            batch_vars = []
            for dim in range(rank_l - 2):
                var = f"_i{self._get_unique_id()}"
                self.builder.add_container(var, Scalar(PrimitiveType.Int64), False)
                batch_vars.append(var)
                self.builder.begin_for(var, "0", str(lhs_shape[dim]), "1")

            def batch_slice(array_name):
                # Builds the AST for array_name[i0, i1, ..., :, :].
                elts = [ast.Name(id=v) for v in batch_vars]
                elts += [ast.Slice(), ast.Slice()]
                return ast.Subscript(
                    value=ast.Name(id=array_name),
                    slice=ast.Tuple(elts=elts),
                    ctx=ast.Load(),
                )

            lhs_slice = batch_slice(lhs_name)
            rhs_slice = batch_slice(rhs_name)
            out_slice = batch_slice(tmp_name)
            self.la_handler.handle_gemm(
                out_slice,
                ast.BinOp(left=lhs_slice, op=ast.MatMult(), right=rhs_slice),
            )

            for _ in batch_vars:
                self.builder.end_for()
        else:
            product = ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node)
            if scalar_result:
                self.la_handler.handle_dot(tmp_name, product)
            else:
                self.la_handler.handle_gemm(tmp_name, product)

        return tmp_name
1600

1601
    def _handle_numpy_unary_op(self, node, func_name):
3✔
1602
        args = [self.visit(arg) for arg in node.args]
3✔
1603
        if len(args) != 1:
3✔
1604
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
1605

1606
        op_name = func_name
3✔
1607
        if op_name == "absolute":
3✔
1608
            op_name = "abs"
×
1609

1610
        return self._handle_array_unary_op(op_name, args[0])
3✔
1611

1612
    def _handle_numpy_reduce(self, node, func_name):
3✔
1613
        args = node.args
3✔
1614
        keywords = {kw.arg: kw.value for kw in node.keywords}
3✔
1615

1616
        array_node = args[0]
3✔
1617
        array_name = self.visit(array_node)
3✔
1618

1619
        if array_name not in self.array_info:
3✔
1620
            raise ValueError(f"Reduction input must be an array, got {array_name}")
×
1621

1622
        input_shape = self.array_info[array_name]["shapes"]
3✔
1623
        ndim = len(input_shape)
3✔
1624

1625
        axis = None
3✔
1626
        if len(args) > 1:
3✔
1627
            axis = args[1]
×
1628
        elif "axis" in keywords:
3✔
1629
            axis = keywords["axis"]
3✔
1630

1631
        keepdims = False
3✔
1632
        if "keepdims" in keywords:
3✔
1633
            keepdims_node = keywords["keepdims"]
3✔
1634
            if isinstance(keepdims_node, ast.Constant):
3✔
1635
                keepdims = bool(keepdims_node.value)
3✔
1636

1637
        axes = []
3✔
1638
        if axis is None:
3✔
1639
            axes = list(range(ndim))
3✔
1640
        elif isinstance(axis, ast.Constant):  # Single axis
3✔
1641
            val = axis.value
3✔
1642
            if val < 0:
3✔
1643
                val += ndim
×
1644
            axes = [val]
3✔
1645
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
1646
            for elt in axis.elts:
×
1647
                if isinstance(elt, ast.Constant):
×
1648
                    val = elt.value
×
1649
                    if val < 0:
×
1650
                        val += ndim
×
1651
                    axes.append(val)
×
1652
        elif (
×
1653
            isinstance(axis, ast.UnaryOp)
1654
            and isinstance(axis.op, ast.USub)
1655
            and isinstance(axis.operand, ast.Constant)
1656
        ):
1657
            val = -axis.operand.value
×
1658
            if val < 0:
×
1659
                val += ndim
×
1660
            axes = [val]
×
1661
        else:
1662
            # Try to evaluate simple expression
1663
            try:
×
1664
                val = int(self.visit(axis))
×
1665
                if val < 0:
×
1666
                    val += ndim
×
1667
                axes = [val]
×
1668
            except:
×
1669
                raise NotImplementedError("Dynamic axis not supported")
×
1670

1671
        # Calculate output shape
1672
        output_shape = []
3✔
1673
        for i in range(ndim):
3✔
1674
            if i in axes:
3✔
1675
                if keepdims:
3✔
1676
                    output_shape.append("1")
3✔
1677
            else:
1678
                output_shape.append(input_shape[i])
3✔
1679

1680
        dtype = self._get_dtype(array_name)
3✔
1681

1682
        if not output_shape:
3✔
1683
            tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1684
            self.builder.add_container(tmp_name, dtype, False)
3✔
1685
            self.symbol_table[tmp_name] = dtype
3✔
1686
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
3✔
1687
        else:
1688
            tmp_name = self._create_array_temp(output_shape, dtype)
3✔
1689

1690
        self.builder.add_reduce_op(
3✔
1691
            func_name, array_name, tmp_name, input_shape, axes, keepdims
1692
        )
1693

1694
        return tmp_name
3✔
1695

1696
    def _handle_numpy_astype(self, node, array_name):
3✔
1697
        """Handle numpy array.astype(dtype) method calls."""
1698
        if len(node.args) < 1:
3✔
1699
            raise ValueError("astype requires at least one argument (dtype)")
×
1700

1701
        dtype_arg = node.args[0]
3✔
1702
        target_dtype = self._map_numpy_dtype(dtype_arg)
3✔
1703

1704
        # Get input array shape
1705
        if array_name not in self.array_info:
3✔
1706
            raise ValueError(f"Array {array_name} not found in array_info")
×
1707

1708
        input_shape = self.array_info[array_name]["shapes"]
3✔
1709

1710
        # Create output array with target dtype
1711
        tmp_name = self._create_array_temp(input_shape, target_dtype)
3✔
1712

1713
        # Add cast operation
1714
        self.builder.add_cast_op(
3✔
1715
            array_name, tmp_name, input_shape, target_dtype.primitive_type
1716
        )
1717

1718
        return tmp_name
3✔
1719

1720
    def _handle_scipy_softmax(self, node, func_name):
3✔
1721
        args = node.args
3✔
1722
        keywords = {kw.arg: kw.value for kw in node.keywords}
3✔
1723

1724
        array_node = args[0]
3✔
1725
        array_name = self.visit(array_node)
3✔
1726

1727
        if array_name not in self.array_info:
3✔
1728
            raise ValueError(f"Softmax input must be an array, got {array_name}")
×
1729

1730
        input_shape = self.array_info[array_name]["shapes"]
3✔
1731
        ndim = len(input_shape)
3✔
1732

1733
        axis = None
3✔
1734
        if len(args) > 1:
3✔
1735
            axis = args[1]
×
1736
        elif "axis" in keywords:
3✔
1737
            axis = keywords["axis"]
3✔
1738

1739
        axes = []
3✔
1740
        if axis is None:
3✔
1741
            axes = list(range(ndim))
3✔
1742
        elif isinstance(axis, ast.Constant):  # Single axis
3✔
1743
            val = axis.value
3✔
1744
            if val < 0:
3✔
1745
                val += ndim
×
1746
            axes = [val]
3✔
1747
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
1748
            for elt in axis.elts:
×
1749
                if isinstance(elt, ast.Constant):
×
1750
                    val = elt.value
×
1751
                    if val < 0:
×
1752
                        val += ndim
×
1753
                    axes.append(val)
×
1754
        elif (
×
1755
            isinstance(axis, ast.UnaryOp)
1756
            and isinstance(axis.op, ast.USub)
1757
            and isinstance(axis.operand, ast.Constant)
1758
        ):
1759
            val = -axis.operand.value
×
1760
            if val < 0:
×
1761
                val += ndim
×
1762
            axes = [val]
×
1763
        else:
1764
            # Try to evaluate simple expression
1765
            try:
×
1766
                val = int(self.visit(axis))
×
1767
                if val < 0:
×
1768
                    val += ndim
×
1769
                axes = [val]
×
1770
            except:
×
1771
                raise NotImplementedError("Dynamic axis not supported")
×
1772

1773
        # Create output array
1774
        # Assume double for now, or infer from input
1775
        dtype = Scalar(PrimitiveType.Double)  # TODO: infer
3✔
1776

1777
        tmp_name = self._create_array_temp(input_shape, dtype)
3✔
1778

1779
        self.builder.add_reduce_op(
3✔
1780
            func_name, array_name, tmp_name, input_shape, axes, False
1781
        )
1782

1783
        return tmp_name
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc