• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21125207138

19 Jan 2026 12:03AM UTC coverage: 64.513% (-0.01%) from 64.525%
21125207138

push

github

web-flow
Merge pull request #464 from daisytuner/python-frontend-syntax

adds validation checks

9 of 17 new or added lines in 3 files covered. (52.94%)

3 existing lines in 1 file now uncovered.

19797 of 30687 relevant lines covered (64.51%)

390.36 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.97
/python/docc/expression_visitor.py
1
import ast
3✔
2
import inspect
3✔
3
import textwrap
3✔
4
from ._sdfg import Scalar, PrimitiveType, Pointer, Type, DebugInfo, Structure
3✔
5

6

7
class ExpressionVisitor(ast.NodeVisitor):
3✔
8
    def __init__(
        self,
        array_info=None,
        builder=None,
        symbol_table=None,
        globals_dict=None,
        inliner=None,
        unique_counter_ref=None,
        structure_member_info=None,
    ):
        """Store the collaborators and per-visitor state.

        Mapping arguments left as ``None`` are replaced by a fresh empty
        dict and a missing ``unique_counter_ref`` by a fresh ``[0]`` list,
        so no mutable default is shared between instances. The numpy
        dispatch table is built last.
        """
        if array_info is None:
            array_info = {}
        if symbol_table is None:
            symbol_table = {}
        if globals_dict is None:
            globals_dict = {}
        if unique_counter_ref is None:
            unique_counter_ref = [0]
        if structure_member_info is None:
            structure_member_info = {}

        # Collaborators handed in by the caller.
        self.builder = builder
        self.inliner = inliner
        self.la_handler = None

        # Lookup tables describing the program being translated.
        self.array_info = array_info
        self.symbol_table = symbol_table
        self.globals_dict = globals_dict
        self.structure_member_info = structure_member_info

        # One-element list so the counter can be shared by reference.
        self._unique_counter_ref = unique_counter_ref
        # Filled by _add_read, keyed by (block, expression string).
        self._access_cache = {}

        self._init_numpy_handlers()
32

33
    def _get_unique_id(self):
3✔
34
        self._unique_counter_ref[0] += 1
3✔
35
        return self._unique_counter_ref[0]
3✔
36

37
    def _get_temp_name(self, prefix="_tmp_"):
3✔
38
        if hasattr(self.builder, "find_new_name"):
3✔
39
            return self.builder.find_new_name(prefix)
×
40
        return f"{prefix}{self._get_unique_id()}"
3✔
41

42
    def _init_numpy_handlers(self):
        """Build ``self.numpy_handlers``: numpy function name -> handler.

        Each handler is later called as ``handler(node, func_name)`` by
        ``visit_Call``. The groups below are listed in the order in which
        the names are inserted into the table.
        """
        handler_groups = (
            (self._handle_numpy_alloc, ("empty",)),
            (self._handle_numpy_empty_like, ("empty_like",)),
            (self._handle_numpy_alloc, ("zeros",)),
            (self._handle_numpy_zeros_like, ("zeros_like",)),
            (self._handle_numpy_alloc, ("ones",)),
            (self._handle_numpy_eye, ("eye",)),
            (
                self._handle_numpy_binary_op,
                ("add", "subtract", "multiply", "divide", "power"),
            ),
            (
                self._handle_numpy_unary_op,
                ("exp", "abs", "absolute", "sqrt", "tanh"),
            ),
            (
                self._handle_numpy_reduce,
                ("sum", "max", "min", "mean", "std"),
            ),
            (self._handle_numpy_matmul, ("matmul", "dot", "matvec")),
            (self._handle_numpy_outer, ("outer",)),
            (self._handle_numpy_binary_op, ("minimum", "maximum")),
        )

        table = {}
        for handler, names in handler_groups:
            for name in names:
                table[name] = handler
        self.numpy_handlers = table
72

73
    def generic_visit(self, node):
3✔
74
        return super().generic_visit(node)
×
75

76
    def visit_Constant(self, node):
        """Render a literal as an expression string.

        Booleans become ``"true"`` / ``"false"``; every other value is
        rendered with ``str``.
        """
        value = node.value
        if not isinstance(value, bool):
            return str(value)
        if value:
            return "true"
        return "false"
80

81
    def visit_Name(self, node):
        """Return the identifier of a name node unchanged."""
        identifier = node.id
        return identifier
83

84
    def _map_numpy_dtype(self, dtype_node):
        """Translate a dtype AST node into an element type.

        Recognized forms:

        * the bare names ``float``, ``int`` and ``bool``;
        * ``<name>.dtype`` where ``<name>`` is in the symbol table as a
          ``Pointer`` with a pointee type (that pointee type is returned);
        * ``np.`` / ``numpy.`` followed by ``float64``, ``float32``,
          ``int64``, ``int32`` or ``bool_``.

        A missing node and every unrecognized form yield a double scalar.
        """
        builtin_names = {
            "float": PrimitiveType.Double,
            "int": PrimitiveType.Int64,
            "bool": PrimitiveType.Bool,
        }
        numpy_names = {
            "float64": PrimitiveType.Double,
            "float32": PrimitiveType.Float,
            "int64": PrimitiveType.Int64,
            "int32": PrimitiveType.Int32,
            "bool_": PrimitiveType.Bool,
        }

        if isinstance(dtype_node, ast.Name):
            primitive = builtin_names.get(dtype_node.id)
            if primitive is not None:
                return Scalar(primitive)
        elif isinstance(dtype_node, ast.Attribute) and isinstance(
            dtype_node.value, ast.Name
        ):
            owner = dtype_node.value.id
            attr = dtype_node.attr

            # Handle array.dtype
            if owner in self.symbol_table and attr == "dtype":
                sym_type = self.symbol_table[owner]
                if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                    return sym_type.pointee_type

            if owner in ("numpy", "np"):
                primitive = numpy_names.get(attr)
                if primitive is not None:
                    return Scalar(primitive)

        # Default and fallback: double
        return Scalar(PrimitiveType.Double)
125

126
    def _is_int(self, operand):
3✔
127
        try:
3✔
128
            if operand.lstrip("-").isdigit():
3✔
129
                return True
3✔
130
        except ValueError:
×
131
            pass
×
132

133
        name = operand
3✔
134
        if "(" in operand and operand.endswith(")"):
3✔
135
            name = operand.split("(")[0]
×
136

137
        if name in self.symbol_table:
3✔
138
            t = self.symbol_table[name]
3✔
139

140
            def is_int_ptype(pt):
3✔
141
                return pt in [
3✔
142
                    PrimitiveType.Int64,
143
                    PrimitiveType.Int32,
144
                    PrimitiveType.Int8,
145
                    PrimitiveType.Int16,
146
                    PrimitiveType.UInt64,
147
                    PrimitiveType.UInt32,
148
                    PrimitiveType.UInt8,
149
                    PrimitiveType.UInt16,
150
                ]
151

152
            if isinstance(t, Scalar):
3✔
153
                return is_int_ptype(t.primitive_type)
3✔
154

155
            if type(t).__name__ == "Array" and hasattr(t, "element_type"):
×
156
                et = t.element_type
×
157
                if callable(et):
×
158
                    et = et()
×
159
                if isinstance(et, Scalar):
×
160
                    return is_int_ptype(et.primitive_type)
×
161

162
            if type(t).__name__ == "Pointer":
×
163
                if hasattr(t, "pointee_type"):
×
164
                    et = t.pointee_type
×
165
                    if callable(et):
×
166
                        et = et()
×
167
                    if isinstance(et, Scalar):
×
168
                        return is_int_ptype(et.primitive_type)
×
169
                # Fallback: check if it has element_type (maybe alias?)
170
                if hasattr(t, "element_type"):
×
171
                    et = t.element_type
×
172
                    if callable(et):
×
173
                        et = et()
×
174
                    if isinstance(et, Scalar):
×
175
                        return is_int_ptype(et.primitive_type)
×
176

177
        return False
3✔
178

179
    def _add_read(self, block, expr_str, debug_info=None):
3✔
180
        # Try to reuse access node
181
        try:
3✔
182
            if (block, expr_str) in self._access_cache:
3✔
183
                return self._access_cache[(block, expr_str)]
3✔
184
        except TypeError:
×
185
            # block might not be hashable
186
            pass
×
187

188
        if debug_info is None:
3✔
189
            debug_info = DebugInfo()
3✔
190

191
        if "(" in expr_str and expr_str.endswith(")"):
3✔
192
            name = expr_str.split("(")[0]
3✔
193
            subset = expr_str[expr_str.find("(") + 1 : -1]
3✔
194
            access = self.builder.add_access(block, name, debug_info)
3✔
195
            try:
3✔
196
                self._access_cache[(block, expr_str)] = (access, subset)
3✔
197
            except TypeError:
×
198
                pass
×
199
            return access, subset
3✔
200

201
        if self.builder.has_container(expr_str):
3✔
202
            access = self.builder.add_access(block, expr_str, debug_info)
3✔
203
            try:
3✔
204
                self._access_cache[(block, expr_str)] = (access, "")
3✔
205
            except TypeError:
×
206
                pass
×
207
            return access, ""
3✔
208

209
        dtype = Scalar(PrimitiveType.Double)
3✔
210
        if self._is_int(expr_str):
3✔
211
            dtype = Scalar(PrimitiveType.Int64)
3✔
212
        elif expr_str == "true" or expr_str == "false":
3✔
213
            dtype = Scalar(PrimitiveType.Bool)
×
214

215
        const_node = self.builder.add_constant(block, expr_str, dtype, debug_info)
3✔
216
        try:
3✔
217
            self._access_cache[(block, expr_str)] = (const_node, "")
3✔
218
        except TypeError:
×
219
            pass
×
220
        return const_node, ""
3✔
221

222
    def _handle_min_max(self, node, func_name):
3✔
223
        args = [self.visit(arg) for arg in node.args]
3✔
224
        if len(args) != 2:
3✔
225
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")
×
226

227
        # Check types
228
        is_float = False
3✔
229
        arg_types = []
3✔
230

231
        for arg in args:
3✔
232
            name = arg
3✔
233
            if "(" in arg and arg.endswith(")"):
3✔
234
                name = arg.split("(")[0]
×
235

236
            if name in self.symbol_table:
3✔
237
                t = self.symbol_table[name]
3✔
238
                if isinstance(t, Pointer):
3✔
239
                    t = t.base_type
×
240

241
                if t.primitive_type == PrimitiveType.Double:
3✔
242
                    is_float = True
3✔
243
                    arg_types.append(PrimitiveType.Double)
3✔
244
                else:
245
                    arg_types.append(PrimitiveType.Int64)
3✔
246
            elif self._is_int(arg):
×
247
                arg_types.append(PrimitiveType.Int64)
×
248
            else:
249
                # Assume float constant
250
                is_float = True
×
251
                arg_types.append(PrimitiveType.Double)
×
252

253
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)
3✔
254

255
        tmp_name = self._get_temp_name("_tmp_")
3✔
256
        self.builder.add_container(tmp_name, dtype, False)
3✔
257
        self.symbol_table[tmp_name] = dtype
3✔
258

259
        if is_float:
3✔
260
            # Cast args if necessary
261
            casted_args = []
3✔
262
            for i, arg in enumerate(args):
3✔
263
                if arg_types[i] != PrimitiveType.Double:
3✔
264
                    # Create temp double
265
                    tmp_cast = self._get_temp_name("_cast_")
3✔
266
                    self.builder.add_container(
3✔
267
                        tmp_cast, Scalar(PrimitiveType.Double), False
268
                    )
269
                    self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
3✔
270

271
                    # Assign int to double (implicit cast)
272
                    self.builder.add_assignment(tmp_cast, arg)
3✔
273
                    casted_args.append(tmp_cast)
3✔
274
                else:
275
                    casted_args.append(arg)
3✔
276

277
            block = self.builder.add_block()
3✔
278
            t_out = self.builder.add_access(block, tmp_name)
3✔
279

280
            intrinsic_name = "fmax" if func_name == "max" else "fmin"
3✔
281
            t_task = self.builder.add_intrinsic(block, intrinsic_name)
3✔
282

283
            for i, arg in enumerate(casted_args):
3✔
284
                t_arg, arg_sub = self._add_read(block, arg)
3✔
285
                self.builder.add_memlet(
3✔
286
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
287
                )
288
        else:
289
            block = self.builder.add_block()
3✔
290
            t_out = self.builder.add_access(block, tmp_name)
3✔
291

292
            # Use int_smax/int_smin tasklet
293
            opcode = "int_smax" if func_name == "max" else "int_smin"
3✔
294
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])
3✔
295

296
            for i, arg in enumerate(args):
3✔
297
                t_arg, arg_sub = self._add_read(block, arg)
3✔
298
                self.builder.add_memlet(
3✔
299
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
300
                )
301

302
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
303
        return tmp_name
3✔
304

305
    def _handle_python_cast(self, node, func_name):
3✔
306
        """Handle Python type casts: int(), float(), bool()"""
307
        if len(node.args) != 1:
3✔
308
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")
×
309

310
        arg = self.visit(node.args[0])
3✔
311

312
        # Determine target type based on cast function
313
        if func_name == "int":
3✔
314
            target_dtype = Scalar(PrimitiveType.Int64)
3✔
315
        elif func_name == "float":
3✔
316
            target_dtype = Scalar(PrimitiveType.Double)
3✔
317
        elif func_name == "bool":
3✔
318
            target_dtype = Scalar(PrimitiveType.Bool)
3✔
319
        else:
320
            raise NotImplementedError(f"Cast to {func_name} not supported")
×
321

322
        # Determine source type
323
        source_dtype = None
3✔
324
        name = arg
3✔
325
        if "(" in arg and arg.endswith(")"):
3✔
326
            name = arg.split("(")[0]
×
327

328
        if name in self.symbol_table:
3✔
329
            source_dtype = self.symbol_table[name]
3✔
330
            if isinstance(source_dtype, Pointer):
3✔
331
                source_dtype = source_dtype.base_type
×
332
        elif self._is_int(arg):
×
333
            source_dtype = Scalar(PrimitiveType.Int64)
×
334
        elif arg == "true" or arg == "false":
×
335
            source_dtype = Scalar(PrimitiveType.Bool)
×
336
        else:
337
            # Assume float constant
338
            source_dtype = Scalar(PrimitiveType.Double)
×
339

340
        # Create temporary variable for result
341
        tmp_name = self._get_temp_name("_tmp_")
3✔
342
        self.builder.add_container(tmp_name, target_dtype, False)
3✔
343
        self.symbol_table[tmp_name] = target_dtype
3✔
344

345
        # Use tasklet assign opcode for casting (as specified in problem statement)
346
        block = self.builder.add_block()
3✔
347
        t_src, src_sub = self._add_read(block, arg)
3✔
348
        t_dst = self.builder.add_access(block, tmp_name)
3✔
349
        t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
3✔
350
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
3✔
351
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
352

353
        return tmp_name
3✔
354

355
    def visit_Call(self, node):
3✔
356
        func_name = ""
3✔
357
        module_name = ""
3✔
358
        if isinstance(node.func, ast.Attribute):
3✔
359
            if isinstance(node.func.value, ast.Name):
3✔
360
                if node.func.value.id == "math":
3✔
361
                    module_name = "math"
3✔
362
                    func_name = node.func.attr
3✔
363
                elif node.func.value.id in ["numpy", "np"]:
3✔
364
                    module_name = "numpy"
3✔
365
                    func_name = node.func.attr
3✔
366
                else:
367
                    # Check if it's a method call on an array (e.g., arr.astype(...))
368
                    array_name = node.func.value.id
3✔
369
                    method_name = node.func.attr
3✔
370
                    if array_name in self.array_info and method_name == "astype":
3✔
371
                        return self._handle_numpy_astype(node, array_name)
3✔
372
            elif isinstance(node.func.value, ast.Attribute):
3✔
373
                if (
3✔
374
                    isinstance(node.func.value.value, ast.Name)
375
                    and node.func.value.value.id == "scipy"
376
                    and node.func.value.attr == "special"
377
                ):
378
                    if node.func.attr == "softmax":
3✔
379
                        return self._handle_scipy_softmax(node, "softmax")
3✔
380

381
        elif isinstance(node.func, ast.Name):
3✔
382
            func_name = node.func.id
3✔
383

384
        if module_name == "numpy":
3✔
385
            if func_name in self.numpy_handlers:
3✔
386
                return self.numpy_handlers[func_name](node, func_name)
3✔
387

388
        if func_name in ["max", "min"]:
3✔
389
            return self._handle_min_max(node, func_name)
3✔
390

391
        # Handle Python type casts (int, float, bool)
392
        if func_name in ["int", "float", "bool"]:
3✔
393
            return self._handle_python_cast(node, func_name)
3✔
394

395
        math_funcs = [
3✔
396
            "sin",
397
            "cos",
398
            "tan",
399
            "exp",
400
            "log",
401
            "sqrt",
402
            "pow",
403
            "abs",
404
            "ceil",
405
            "floor",
406
            "asin",
407
            "acos",
408
            "atan",
409
            "sinh",
410
            "cosh",
411
            "tanh",
412
        ]
413

414
        if func_name in math_funcs:
3✔
415
            args = [self.visit(arg) for arg in node.args]
3✔
416

417
            tmp_name = self._get_temp_name("_tmp_")
3✔
418
            dtype = Scalar(PrimitiveType.Double)
3✔
419
            self.builder.add_container(tmp_name, dtype, False)
3✔
420
            self.symbol_table[tmp_name] = dtype
3✔
421

422
            block = self.builder.add_block()
3✔
423
            t_out = self.builder.add_access(block, tmp_name)
3✔
424

425
            t_task = self.builder.add_intrinsic(block, func_name)
3✔
426

427
            for i, arg in enumerate(args):
3✔
428
                t_arg, arg_sub = self._add_read(block, arg)
3✔
429
                self.builder.add_memlet(
3✔
430
                    block, t_arg, "void", t_task, f"_in{i+1}", arg_sub
431
                )
432

433
            self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
3✔
434
            return tmp_name
3✔
435

436
        if func_name in self.globals_dict:
3✔
437
            obj = self.globals_dict[func_name]
3✔
438
            if inspect.isfunction(obj):
3✔
439
                return self._handle_inline_call(node, obj)
3✔
440

441
        raise NotImplementedError(f"Function call {func_name} not supported")
×
442

443
    def _handle_inline_call(self, node, func_obj):
3✔
444
        # 1. Parse function source
445
        try:
3✔
446
            source_lines, start_line = inspect.getsourcelines(func_obj)
3✔
447
            source = textwrap.dedent("".join(source_lines))
3✔
448
            tree = ast.parse(source)
3✔
449
            func_def = tree.body[0]
3✔
450
        except Exception as e:
×
451
            raise NotImplementedError(
×
452
                f"Could not parse function {func_obj.__name__}: {e}"
453
            )
454

455
        # 2. Evaluate arguments
456
        arg_vars = [self.visit(arg) for arg in node.args]
3✔
457

458
        if len(arg_vars) != len(func_def.args.args):
3✔
459
            raise NotImplementedError(
×
460
                f"Argument count mismatch for {func_obj.__name__}"
461
            )
462

463
        # 3. Generate unique suffix
464
        suffix = f"_{func_obj.__name__}_{self._get_unique_id()}"
3✔
465
        res_name = f"_res{suffix}"
3✔
466

467
        # Assume Int64 for now as match returns 0/1
468
        dtype = Scalar(PrimitiveType.Int64)
3✔
469
        self.builder.add_container(res_name, dtype, False)
3✔
470
        self.symbol_table[res_name] = dtype
3✔
471

472
        # 4. Rename variables
473
        class VariableRenamer(ast.NodeTransformer):
3✔
474
            def __init__(self, suffix, globals_dict):
3✔
475
                self.suffix = suffix
3✔
476
                self.globals_dict = globals_dict
3✔
477

478
            def visit_Name(self, node):
3✔
479
                if node.id in self.globals_dict:
3✔
480
                    return node
3✔
481
                return ast.Name(id=f"{node.id}{self.suffix}", ctx=node.ctx)
3✔
482

483
            def visit_Return(self, node):
3✔
484
                if node.value:
3✔
485
                    val = self.visit(node.value)
3✔
486
                    return ast.Assign(
3✔
487
                        targets=[ast.Name(id=res_name, ctx=ast.Store())],
488
                        value=val,
489
                    )
490
                return node
×
491

492
        renamer = VariableRenamer(suffix, self.globals_dict)
3✔
493
        new_body = [renamer.visit(stmt) for stmt in func_def.body]
3✔
494

495
        # 5. Assign arguments to parameters
496
        param_assignments = []
3✔
497
        for arg_def, arg_val in zip(func_def.args.args, arg_vars):
3✔
498
            param_name = f"{arg_def.arg}{suffix}"
3✔
499

500
            # Infer type and create container
501
            if arg_val in self.symbol_table:
3✔
502
                self.symbol_table[param_name] = self.symbol_table[arg_val]
3✔
503
                self.builder.add_container(
3✔
504
                    param_name, self.symbol_table[arg_val], False
505
                )
506
                val_node = ast.Name(id=arg_val, ctx=ast.Load())
3✔
507
            elif self._is_int(arg_val):
×
508
                self.symbol_table[param_name] = Scalar(PrimitiveType.Int64)
×
509
                self.builder.add_container(
×
510
                    param_name, Scalar(PrimitiveType.Int64), False
511
                )
512
                val_node = ast.Constant(value=int(arg_val))
×
513
            else:
514
                # Assume float constant
515
                try:
×
516
                    val = float(arg_val)
×
517
                    self.symbol_table[param_name] = Scalar(PrimitiveType.Double)
×
518
                    self.builder.add_container(
×
519
                        param_name, Scalar(PrimitiveType.Double), False
520
                    )
521
                    val_node = ast.Constant(value=val)
×
522
                except ValueError:
×
523
                    # Fallback to Name, might fail later if not in symbol table
524
                    val_node = ast.Name(id=arg_val, ctx=ast.Load())
×
525

526
            assign = ast.Assign(
3✔
527
                targets=[ast.Name(id=param_name, ctx=ast.Store())], value=val_node
528
            )
529
            param_assignments.append(assign)
3✔
530

531
        final_body = param_assignments + new_body
3✔
532

533
        # 6. Visit new body using ASTParser
534
        from .ast_parser import ASTParser
3✔
535

536
        parser = ASTParser(
3✔
537
            self.builder,
538
            self.array_info,
539
            self.symbol_table,
540
            globals_dict=self.globals_dict,
541
            unique_counter_ref=self._unique_counter_ref,
542
        )
543

544
        for stmt in final_body:
3✔
545
            parser.visit(stmt)
3✔
546

547
        return res_name
3✔
548

549
    def visit_BinOp(self, node):
        """Lower a binary operation and return the name holding its result.

        ``@`` is delegated to the matmul handler, and operations with an
        array operand to the array handler. Scalar operations produce a new
        temporary: Int64 when both operands are integers and the operator is
        neither ``/`` nor ``**``, Double otherwise. For a Double result,
        integer operands are first copied into double temporaries.

        Raises:
            NotImplementedError: For an array operation outside
                ``+ - * / **`` and for a scalar operator that has no
                lowering here.
        """
        if isinstance(node.op, ast.MatMult):
            return self._handle_numpy_matmul_op(node.left, node.right)

        left = self.visit(node.left)
        right = self.visit(node.right)
        op = self.visit(node.op)

        # Check if left or right are arrays
        if left in self.array_info or right in self.array_info:
            op_map = {"+": "add", "-": "sub", "*": "mul", "/": "div", "**": "pow"}
            if op in op_map:
                return self._handle_array_binary_op(op_map[op], left, right)
            raise NotImplementedError(f"Array operation {op} not supported")

        # Operators lowered to a single "<int|fp>_<name>" tasklet.
        # "//" is lowered with the same div opcode as "/".
        # NOTE(review): no flooring is applied here - confirm this matches
        # the intended Python semantics.
        tasklet_ops = {
            "+": "add",
            "-": "sub",
            "*": "mul",
            "/": "div",
            "//": "div",
            "|": "or",
            "^": "xor",
        }
        # Reject unknown operators before emitting anything, instead of
        # building a tasklet code with an empty operation name.
        if op not in tasklet_ops and op not in ("**", "%"):
            raise NotImplementedError(f"Binary operation {op} not supported")

        tmp_name = f"_tmp_{self._get_unique_id()}"

        left_is_int = self._is_int(left)
        right_is_int = self._is_int(right)

        if left_is_int and right_is_int and op not in ["/", "**"]:
            dtype = Scalar(PrimitiveType.Int64)
        else:
            dtype = Scalar(PrimitiveType.Double)  # Default

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        def cast_to_double(operand):
            """Copy ``operand`` into a new double temporary; return its name."""
            cast_name = f"_tmp_{self._get_unique_id()}"
            double_type = Scalar(PrimitiveType.Double)
            self.builder.add_container(cast_name, double_type, False)
            self.symbol_table[cast_name] = double_type

            c_block = self.builder.add_block()
            t_src, src_sub = self._add_read(c_block, operand)
            t_dst = self.builder.add_access(c_block, cast_name)
            t_cast = self.builder.add_tasklet(c_block, "assign", ["_in"], ["_out"])
            self.builder.add_memlet(c_block, t_src, "void", t_cast, "_in", src_sub)
            self.builder.add_memlet(c_block, t_cast, "_out", t_dst, "void", "")
            return cast_name

        real_left = left
        real_right = right
        if dtype.primitive_type == PrimitiveType.Double:
            if left_is_int:
                real_left = cast_to_double(left)
            if right_is_int:
                real_right = cast_to_double(right)

        # Every remaining case reads both operands in one new block.
        block = self.builder.add_block()
        t_left, left_sub = self._add_read(block, real_left)
        t_right, right_sub = self._add_read(block, real_right)
        t_out = self.builder.add_access(block, tmp_name)

        # Special cases
        if op == "**":
            t_task = self.builder.add_intrinsic(block, "pow")
        elif op == "%" and dtype.primitive_type == PrimitiveType.Int64:
            # Implement ((a % b) + b) % b to match Python's modulo behavior

            # 1. rem1 = a % b
            t_rem1 = self.builder.add_tasklet(
                block, "int_rem", ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_left, "void", t_rem1, "_in1", left_sub)
            self.builder.add_memlet(block, t_right, "void", t_rem1, "_in2", right_sub)

            rem1_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(rem1_name, dtype, False)
            t_rem1_out = self.builder.add_access(block, rem1_name)
            self.builder.add_memlet(block, t_rem1, "_out", t_rem1_out, "void", "")

            # 2. add = rem1 + b
            t_add = self.builder.add_tasklet(
                block, "int_add", ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_rem1_out, "void", t_add, "_in1", "")
            self.builder.add_memlet(block, t_right, "void", t_add, "_in2", right_sub)

            add_name = f"_tmp_{self._get_unique_id()}"
            self.builder.add_container(add_name, dtype, False)
            t_add_out = self.builder.add_access(block, add_name)
            self.builder.add_memlet(block, t_add, "_out", t_add_out, "void", "")

            # 3. res = add % b
            t_rem2 = self.builder.add_tasklet(
                block, "int_rem", ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_add_out, "void", t_rem2, "_in1", "")
            self.builder.add_memlet(block, t_right, "void", t_rem2, "_in2", right_sub)
            self.builder.add_memlet(block, t_rem2, "_out", t_out, "void", "")

            return tmp_name
        elif op == "%":
            t_task = self.builder.add_intrinsic(block, "fmod")
        else:
            prefix = "int" if dtype.primitive_type == PrimitiveType.Int64 else "fp"
            tasklet_code = f"{prefix}_{tasklet_ops[op]}"
            t_task = self.builder.add_tasklet(
                block, tasklet_code, ["_in1", "_in2"], ["_out"]
            )

        self.builder.add_memlet(block, t_left, "void", t_task, "_in1", left_sub)
        self.builder.add_memlet(block, t_right, "void", t_task, "_in2", right_sub)
        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")

        return tmp_name
719

720
    def _add_assign_constant(self, target_name, value_str, dtype):
        """Emit one block that stores the constant ``value_str`` into ``target_name``.

        The block contains a constant node of type ``dtype``, an access node
        for the target container and an ``assign`` tasklet wired between them.
        """
        builder = self.builder
        blk = builder.add_block()
        source = builder.add_constant(blk, value_str, dtype)
        sink = builder.add_access(blk, target_name)
        copy_task = builder.add_tasklet(blk, "assign", ["_in"], ["_out"])
        # constant -> tasklet -> target container
        builder.add_memlet(blk, source, "void", copy_task, "_in", "")
        builder.add_memlet(blk, copy_task, "_out", sink, "void", "")
727

728
    def visit_BoolOp(self, node):
        """Lower ``a and b`` / ``a or b`` into a Bool temporary and return its name.

        Every operand is visited and wrapped as ``(<operand> != 0)``; the
        wrapped operands are joined with the operator string produced by
        visiting ``node.op``. The resulting condition drives an if/else that
        assigns ``true`` or ``false`` to a fresh ``_tmp_<id>`` container.
        """
        operator = self.visit(node.op)
        operands = [f"({self.visit(value)} != 0)" for value in node.values]
        separator = f" {operator} "
        condition = separator.join(operands)

        result_name = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        self.builder.add_container(result_name, bool_type, False)

        # Materialize the condition through control flow rather than a tasklet.
        self.builder.begin_if(condition)
        self._add_assign_constant(result_name, "true", bool_type)
        self.builder.begin_else()
        self._add_assign_constant(result_name, "false", bool_type)
        self.builder.end_if()

        self.symbol_table[result_name] = bool_type
        return result_name
746

747
    def visit_Compare(self, node):
        """Lower a single comparison into a Bool temporary and return its name.

        The left operand is visited first; chained comparisons
        (``a < b < c``) are then rejected with ``NotImplementedError``.
        The comparison string ``"<left> <op> <right>"`` is used as the
        condition of an if/else that assigns ``true`` or ``false`` to a fresh
        ``_tmp_<id>`` container.
        """
        lhs = self.visit(node.left)
        if len(node.ops) > 1:
            raise NotImplementedError("Chained comparisons not supported yet")

        comparator = self.visit(node.ops[0])
        rhs = self.visit(node.comparators[0])
        condition = f"{lhs} {comparator} {rhs}"

        result_name = f"_tmp_{self._get_unique_id()}"
        bool_type = Scalar(PrimitiveType.Bool)
        builder = self.builder
        builder.add_container(result_name, bool_type, False)

        # Materialize the condition through control flow.
        builder.begin_if(condition)
        builder.add_assignment(result_name, "true")
        builder.begin_else()
        builder.add_assignment(result_name, "false")
        builder.end_if()

        self.symbol_table[result_name] = bool_type
        return result_name
769

770
    def visit_UnaryOp(self, node):
        """Lower a unary operation and return the name (or literal) of its value.

        A negated numeric literal is folded directly into a string such as
        ``"-5"``. Otherwise the operand is visited, a fresh ``_tmp_<id>``
        container is registered with the builder and the symbol table, and a
        single block computing the result is emitted:

        * ``not x`` -> ``int_xor`` of the operand with the constant ``true``
        * ``-x``    -> ``0 - x`` via ``int_sub`` for Int64, ``fp_neg`` otherwise
        * ``+x``    -> plain ``assign``

        Raises:
            NotImplementedError: for ``~x``. No lowering exists for bitwise
                inversion; previously it fell through to the ``assign``
                branch and silently produced ``x`` instead of ``~x``.
        """
        if (
            isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
            and isinstance(node.operand.value, (int, float))
        ):
            return f"-{node.operand.value}"

        # Reject before visiting the operand so no containers or blocks are
        # emitted for an expression that cannot be lowered.
        if isinstance(node.op, ast.Invert):
            raise NotImplementedError("Bitwise inversion (~) not supported yet")

        op = self.visit(node.op)
        operand = self.visit(node.operand)

        # Result type: the operand's registered type wins, then integer
        # literals, then Bool for `not`; Double is the fallback.
        tmp_name = f"_tmp_{self._get_unique_id()}"
        dtype = Scalar(PrimitiveType.Double)
        if operand in self.symbol_table:
            dtype = self.symbol_table[operand]
        elif self._is_int(operand):
            dtype = Scalar(PrimitiveType.Int64)
        elif isinstance(node.op, ast.Not):
            dtype = Scalar(PrimitiveType.Bool)

        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, operand)
        t_dst = self.builder.add_access(block, tmp_name)

        if isinstance(node.op, ast.Not):
            # Logical negation is expressed as `operand xor true`.
            t_const = self.builder.add_constant(
                block, "true", Scalar(PrimitiveType.Bool)
            )
            t_task = self.builder.add_tasklet(
                block, "int_xor", ["_in1", "_in2"], ["_out"]
            )
            self.builder.add_memlet(block, t_src, "void", t_task, "_in1", src_sub)
            self.builder.add_memlet(block, t_const, "void", t_task, "_in2", "")
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        elif op == "-":
            # NOTE(review): only Int64 takes the integer path; any other
            # integer width falls into the `fp_neg` branch below - confirm
            # this is intended.
            if dtype.primitive_type == PrimitiveType.Int64:
                # Integer negation is emitted as `0 - operand`.
                t_const = self.builder.add_constant(block, "0", dtype)
                t_task = self.builder.add_tasklet(
                    block, "int_sub", ["_in1", "_in2"], ["_out"]
                )
                self.builder.add_memlet(block, t_const, "void", t_task, "_in1", "")
                self.builder.add_memlet(block, t_src, "void", t_task, "_in2", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
            else:
                t_task = self.builder.add_tasklet(block, "fp_neg", ["_in"], ["_out"])
                self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
        else:
            # Remaining case is unary plus: copy the operand unchanged.
            t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
            self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
            self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
827

828
    def _parse_array_arg(self, node, simple_visitor):
3✔
829
        if isinstance(node, ast.Name):
×
830
            if node.id in self.array_info:
×
831
                return node.id, [], self.array_info[node.id]["shapes"]
×
832
        elif isinstance(node, ast.Subscript):
×
833
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
×
834
                name = node.value.id
×
835
                ndim = self.array_info[name]["ndim"]
×
836

837
                indices = []
×
838
                if isinstance(node.slice, ast.Tuple):
×
839
                    indices = list(node.slice.elts)
×
840
                else:
841
                    indices = [node.slice]
×
842

843
                while len(indices) < ndim:
×
844
                    indices.append(ast.Slice(lower=None, upper=None, step=None))
×
845

846
                start_indices = []
×
847
                slice_shape = []
×
848

849
                for i, idx in enumerate(indices):
×
850
                    if isinstance(idx, ast.Slice):
×
851
                        start = "0"
×
852
                        if idx.lower:
×
853
                            start = simple_visitor.visit(idx.lower)
×
854
                        start_indices.append(start)
×
855

856
                        shapes = self.array_info[name]["shapes"]
×
857
                        dim_size = (
×
858
                            shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
859
                        )
860
                        stop = dim_size
×
861
                        if idx.upper:
×
862
                            stop = simple_visitor.visit(idx.upper)
×
863

864
                        size = f"({stop} - {start})"
×
865
                        slice_shape.append(size)
×
866
                    else:
867
                        val = simple_visitor.visit(idx)
×
868
                        start_indices.append(val)
×
869

870
                shapes = self.array_info[name]["shapes"]
×
871
                linear_index = ""
×
872
                for i in range(ndim):
×
873
                    term = start_indices[i]
×
874
                    for j in range(i + 1, ndim):
×
875
                        shape_val = shapes[j] if j < len(shapes) else None
×
876
                        shape_sym = (
×
877
                            shape_val if shape_val is not None else f"_{name}_shape_{j}"
878
                        )
879
                        term = f"({term} * {shape_sym})"
×
880

881
                    if i == 0:
×
882
                        linear_index = term
×
883
                    else:
884
                        linear_index = f"({linear_index} + {term})"
×
885

886
                return name, [linear_index], slice_shape
×
887

888
        return None, None, None
×
889

890
    def visit_Attribute(self, node):
3✔
891
        if node.attr == "shape":
3✔
892
            if isinstance(node.value, ast.Name) and node.value.id in self.array_info:
3✔
893
                return f"_shape_proxy_{node.value.id}"
3✔
894

895
        if isinstance(node.value, ast.Name) and node.value.id == "math":
3✔
896
            val = ""
3✔
897
            if node.attr == "pi":
3✔
898
                val = "M_PI"
3✔
899
            elif node.attr == "e":
3✔
900
                val = "M_E"
3✔
901

902
            if val:
3✔
903
                tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
904
                dtype = Scalar(PrimitiveType.Double)
3✔
905
                self.builder.add_container(tmp_name, dtype, False)
3✔
906
                self.symbol_table[tmp_name] = dtype
3✔
907
                self._add_assign_constant(tmp_name, val, dtype)
3✔
908
                return tmp_name
3✔
909

910
        # Handle class member access (e.g., obj.x, obj.y)
911
        if isinstance(node.value, ast.Name):
3✔
912
            obj_name = node.value.id
3✔
913
            attr_name = node.attr
3✔
914

915
            # Check if the object is a class instance (has a Structure type)
916
            if obj_name in self.symbol_table:
3✔
917
                obj_type = self.symbol_table[obj_name]
3✔
918
                if isinstance(obj_type, Pointer) and obj_type.has_pointee_type():
3✔
919
                    pointee_type = obj_type.pointee_type
3✔
920
                    if isinstance(pointee_type, Structure):
3✔
921
                        struct_name = pointee_type.name
3✔
922

923
                        # Look up member index and type from structure info
924
                        if (
3✔
925
                            struct_name in self.structure_member_info
926
                            and attr_name in self.structure_member_info[struct_name]
927
                        ):
928
                            member_index, member_type = self.structure_member_info[
3✔
929
                                struct_name
930
                            ][attr_name]
931
                        else:
932
                            # This should not happen if structure was registered properly
933
                            raise RuntimeError(
×
934
                                f"Member '{attr_name}' not found in structure '{struct_name}'. "
935
                                f"Available members: {list(self.structure_member_info.get(struct_name, {}).keys())}"
936
                            )
937

938
                        # Generate a tasklet to access the member
939
                        tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
940

941
                        self.builder.add_container(tmp_name, member_type, False)
3✔
942
                        self.symbol_table[tmp_name] = member_type
3✔
943

944
                        # Create a tasklet that reads the member
945
                        block = self.builder.add_block()
3✔
946
                        obj_access = self.builder.add_access(block, obj_name)
3✔
947
                        tmp_access = self.builder.add_access(block, tmp_name)
3✔
948

949
                        # Use tasklet to pass through the value
950
                        # The actual member selection is done via the memlet subset
951
                        tasklet = self.builder.add_tasklet(
3✔
952
                            block, "assign", ["_in"], ["_out"]
953
                        )
954

955
                        # Use member index in the subset to select the correct member
956
                        subset = "0," + str(member_index)
3✔
957
                        self.builder.add_memlet(
3✔
958
                            block, obj_access, "", tasklet, "_in", subset
959
                        )
960
                        self.builder.add_memlet(block, tasklet, "_out", tmp_access, "")
3✔
961

962
                        return tmp_name
3✔
963

964
        raise NotImplementedError(f"Attribute access {node.attr} not supported")
×
965

966
    def visit_Subscript(self, node):
3✔
967
        value_str = self.visit(node.value)
3✔
968

969
        if value_str.startswith("_shape_proxy_"):
3✔
970
            array_name = value_str[len("_shape_proxy_") :]
3✔
971
            if isinstance(node.slice, ast.Constant):
3✔
972
                idx = node.slice.value
3✔
973
            elif isinstance(node.slice, ast.Index):
×
974
                idx = node.slice.value.value
×
975
            else:
976
                try:
×
977
                    idx = int(self.visit(node.slice))
×
978
                except:
×
979
                    raise NotImplementedError(
×
980
                        "Dynamic shape indexing not fully supported yet"
981
                    )
982

983
            if (
3✔
984
                array_name in self.array_info
985
                and "shapes" in self.array_info[array_name]
986
            ):
987
                return self.array_info[array_name]["shapes"][idx]
3✔
988

989
            return f"_{array_name}_shape_{idx}"
×
990

991
        if value_str in self.array_info:
3✔
992
            ndim = self.array_info[value_str]["ndim"]
3✔
993
            shapes = self.array_info[value_str].get("shapes", [])
3✔
994

995
            indices = []
3✔
996
            if isinstance(node.slice, ast.Tuple):
3✔
997
                indices_nodes = node.slice.elts
3✔
998
            else:
999
                indices_nodes = [node.slice]
3✔
1000

1001
            for idx in indices_nodes:
3✔
1002
                if isinstance(idx, ast.Slice):
3✔
1003
                    raise ValueError("Slices not supported in expression indexing")
×
1004

1005
            if isinstance(node.slice, ast.Tuple):
3✔
1006
                indices = [self.visit(elt) for elt in node.slice.elts]
3✔
1007
            else:
1008
                indices = [self.visit(node.slice)]
3✔
1009

1010
            if len(indices) != ndim:
3✔
1011
                raise ValueError(
×
1012
                    f"Array {value_str} has {ndim} dimensions, but accessed with {len(indices)} indices"
1013
                )
1014

1015
            linear_index = ""
3✔
1016
            for i in range(ndim):
3✔
1017
                term = indices[i]
3✔
1018
                for j in range(i + 1, ndim):
3✔
1019
                    shape_val = shapes[j] if j < len(shapes) else None
3✔
1020
                    shape_sym = (
3✔
1021
                        shape_val
1022
                        if shape_val is not None
1023
                        else f"_{value_str}_shape_{j}"
1024
                    )
1025
                    term = f"(({term}) * {shape_sym})"
3✔
1026

1027
                if i == 0:
3✔
1028
                    linear_index = term
3✔
1029
                else:
1030
                    linear_index = f"({linear_index} + {term})"
3✔
1031

1032
            access_str = f"{value_str}({linear_index})"
3✔
1033

1034
            if self.builder and isinstance(node.ctx, ast.Load):
3✔
1035
                dtype = Scalar(PrimitiveType.Double)
3✔
1036
                if value_str in self.symbol_table:
3✔
1037
                    t = self.symbol_table[value_str]
3✔
1038
                    if type(t).__name__ == "Array" and hasattr(t, "element_type"):
3✔
1039
                        et = t.element_type
×
1040
                        if callable(et):
×
1041
                            et = et()
×
1042
                        dtype = et
×
1043
                    elif type(t).__name__ == "Pointer" and hasattr(t, "pointee_type"):
3✔
1044
                        et = t.pointee_type
3✔
1045
                        if callable(et):
3✔
1046
                            et = et()
×
1047
                        dtype = et
3✔
1048

1049
                tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1050
                self.builder.add_container(tmp_name, dtype, False)
3✔
1051

1052
                block = self.builder.add_block()
3✔
1053
                t_src = self.builder.add_access(block, value_str)
3✔
1054
                t_dst = self.builder.add_access(block, tmp_name)
3✔
1055
                t_task = self.builder.add_tasklet(block, "assign", ["_in"], ["_out"])
3✔
1056

1057
                self.builder.add_memlet(
3✔
1058
                    block, t_src, "void", t_task, "_in", linear_index
1059
                )
1060
                self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")
3✔
1061

1062
                self.symbol_table[tmp_name] = dtype
3✔
1063
                return tmp_name
3✔
1064

1065
            return access_str
3✔
1066

1067
        slice_val = self.visit(node.slice)
×
1068
        access_str = f"{value_str}({slice_val})"
×
1069

1070
        if (
×
1071
            self.builder
1072
            and isinstance(node.ctx, ast.Load)
1073
            and value_str in self.array_info
1074
        ):
1075
            tmp_name = f"_tmp_{self._get_unique_id()}"
×
1076
            self.builder.add_container(tmp_name, Scalar(PrimitiveType.Double), False)
×
1077
            self.builder.add_assignment(tmp_name, access_str)
×
1078
            self.symbol_table[tmp_name] = Scalar(PrimitiveType.Double)
×
1079
            return tmp_name
×
1080

1081
        return access_str
×
1082

1083
    def visit_Add(self, node):
        """Return the operator string for ``ast.Add``."""
        return "+"
1085

1086
    def visit_Sub(self, node):
        """Return the operator string for ``ast.Sub``."""
        return "-"
1088

1089
    def visit_Mult(self, node):
        """Return the operator string for ``ast.Mult``."""
        return "*"
1091

1092
    def visit_Div(self, node):
        """Return the operator string for ``ast.Div``."""
        return "/"
1094

1095
    def visit_FloorDiv(self, node):
        """Return the operator string for ``ast.FloorDiv``."""
        return "//"
1097

1098
    def visit_Mod(self, node):
        """Return the operator string for ``ast.Mod``."""
        return "%"
1100

1101
    def visit_Pow(self, node):
        """Return the operator string for ``ast.Pow``."""
        return "**"
1103

1104
    def visit_Eq(self, node):
        """Return the operator string for ``ast.Eq``."""
        return "=="
1106

1107
    def visit_NotEq(self, node):
        """Return the operator string for ``ast.NotEq``."""
        return "!="
1109

1110
    def visit_Lt(self, node):
        """Return the operator string for ``ast.Lt``."""
        return "<"
1112

1113
    def visit_LtE(self, node):
        """Return the operator string for ``ast.LtE``."""
        return "<="
1115

1116
    def visit_Gt(self, node):
        """Return the operator string for ``ast.Gt``."""
        return ">"
1118

1119
    def visit_GtE(self, node):
        """Return the operator string for ``ast.GtE``."""
        return ">="
1121

1122
    def visit_And(self, node):
        """Return the operator string used for a boolean ``and`` (``"&"``)."""
        return "&"
1124

1125
    def visit_Or(self, node):
        """Return the operator string used for a boolean ``or`` (``"|"``)."""
        return "|"
1127

1128
    def visit_BitAnd(self, node):
        """Return the operator string for ``ast.BitAnd``."""
        return "&"
1130

1131
    def visit_BitOr(self, node):
        """Return the operator string for ``ast.BitOr``."""
        return "|"
1133

1134
    def visit_BitXor(self, node):
        """Return the operator string for ``ast.BitXor``."""
        return "^"
1136

1137
    def visit_Not(self, node):
        """Return the operator string for ``ast.Not``."""
        return "!"
1139

1140
    def visit_USub(self, node):
        """Return the operator string for ``ast.USub`` (unary minus)."""
        return "-"
1142

1143
    def visit_UAdd(self, node):
        """Return the operator string for ``ast.UAdd`` (unary plus)."""
        return "+"
1145

1146
    def visit_Invert(self, node):
        """Return the operator string for ``ast.Invert`` (bitwise not)."""
        return "~"
1148

1149
    def _get_dtype(self, name):
        """Return the scalar type associated with ``name``.

        For a symbol-table entry: the entry itself if it is a ``Scalar``,
        otherwise its ``pointee_type`` or ``element_type`` (called first when
        callable) if that is a ``Scalar``. Names not resolved that way are
        Int64 when ``self._is_int(name)`` holds and Double otherwise.
        """
        if name in self.symbol_table:
            entry = self.symbol_table[name]
            if isinstance(entry, Scalar):
                return entry

            # Pointer-like types are probed before array-like ones.
            for attr in ("pointee_type", "element_type"):
                if not hasattr(entry, attr):
                    continue
                inner = getattr(entry, attr)
                if callable(inner):
                    inner = inner()
                if isinstance(inner, Scalar):
                    return inner

        fallback = (
            PrimitiveType.Int64 if self._is_int(name) else PrimitiveType.Double
        )
        return Scalar(fallback)
1173

1174
    def _create_array_temp(self, shape, dtype, zero_init=False, ones_init=False):
        """Allocate a temporary array via the builder and return its name.

        A ``Pointer(dtype)`` container named ``_tmp_<id>`` is registered with
        the builder, ``self.symbol_table`` and ``self.array_info``; a malloc
        of ``product(shape) * sizeof(dtype)`` bytes is emitted for it.

        Args:
            shape: sequence of dimension expressions.
            dtype: element type of the array.
            zero_init: emit a memset to 0 over the whole allocation.
            ones_init: emit a loop storing 1 into every element; ignored when
                ``zero_init`` is set.
        """
        builder = self.builder
        tmp_name = f"_tmp_{self._get_unique_id()}"

        # Element count as a nested product expression.
        num_elements = "1"
        for extent in shape:
            num_elements = f"({num_elements} * {extent})"
        num_bytes = f"({num_elements} * {builder.get_sizeof(dtype)})"

        # Register the pointer container and its array metadata.
        ptr_type = Pointer(dtype)
        builder.add_container(tmp_name, ptr_type, False)
        self.symbol_table[tmp_name] = ptr_type
        self.array_info[tmp_name] = {"ndim": len(shape), "shapes": shape}

        # Malloc block.
        malloc_block = builder.add_block()
        malloc_node = builder.add_malloc(malloc_block, num_bytes)
        malloc_dst = builder.add_access(malloc_block, tmp_name)
        builder.add_memlet(
            malloc_block, malloc_node, "_ret", malloc_dst, "void", "", ptr_type
        )

        if zero_init:
            memset_block = builder.add_block()
            memset_node = builder.add_memset(memset_block, "0", num_bytes)
            memset_dst = builder.add_access(memset_block, tmp_name)
            builder.add_memlet(
                memset_block, memset_node, "_ptr", memset_dst, "void", "", ptr_type
            )
            return tmp_name

        if not ones_init:
            return tmp_name

        # Fill with ones using a loop over the flattened index range.
        loop_var = f"_i_{self._get_unique_id()}"
        if not builder.has_container(loop_var):
            builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)

        builder.begin_for(loop_var, "0", num_elements, "1")

        # Integer element types get the literal "1", everything else "1.0".
        integer_kinds = (
            PrimitiveType.Int64,
            PrimitiveType.Int32,
            PrimitiveType.Int8,
            PrimitiveType.Int16,
            PrimitiveType.UInt64,
            PrimitiveType.UInt32,
            PrimitiveType.UInt8,
            PrimitiveType.UInt16,
        )
        one = "1" if dtype.primitive_type in integer_kinds else "1.0"

        fill_block = builder.add_block()
        one_node = builder.add_constant(fill_block, one, dtype)
        array_node = builder.add_access(fill_block, tmp_name)
        store = builder.add_tasklet(fill_block, "assign", ["_in"], ["_out"])
        builder.add_memlet(fill_block, one_node, "void", store, "_in", "", dtype)
        builder.add_memlet(fill_block, store, "_out", array_node, "void", loop_var)

        builder.end_for()

        return tmp_name
1242

1243
    def _handle_array_unary_op(self, op_type, operand):
        """Apply an elementwise unary op to ``operand`` and return the result name.

        The result is a new temporary array with the operand's registered
        shape (empty when the operand is not a registered array) and the
        operand's scalar type.
        """
        out_shape = (
            self.array_info[operand]["shapes"] if operand in self.array_info else []
        )
        elem_type = self._get_dtype(operand)
        result = self._create_array_temp(out_shape, elem_type)
        self.builder.add_elementwise_unary_op(op_type, operand, result, out_shape)
        return result
1258

1259
    def _handle_array_binary_op(self, op_type, left, right):
        """Apply an elementwise binary op and return the result array's name.

        The output shape is taken from ``left`` if it is a registered array,
        else from ``right``, else it is empty. Both operands must resolve to
        the same primitive type; that type is used for the result.

        Raises:
            AssertionError: the operands' primitive types differ. The type is
                kept from the former ``assert`` statement, but the check is
                now an explicit raise so it is not stripped under ``-O`` and
                carries a message.
        """
        # Determine output shape
        if left in self.array_info:
            shape = self.array_info[left]["shapes"]
        elif right in self.array_info:
            shape = self.array_info[right]["shapes"]
        else:
            shape = []

        # Determine dtype
        dtype_left = self._get_dtype(left)
        dtype_right = self._get_dtype(right)

        if not (dtype_left.primitive_type == dtype_right.primitive_type):
            raise AssertionError(
                f"Elementwise '{op_type}' requires operands of the same type, "
                f"got {dtype_left.primitive_type} ({left}) and "
                f"{dtype_right.primitive_type} ({right})"
            )
        dtype = dtype_left

        tmp_name = self._create_array_temp(shape, dtype)

        # Add operation
        self.builder.add_elementwise_op(op_type, left, right, tmp_name, shape)

        return tmp_name
1280

1281
    def _handle_numpy_alloc(self, node, func_name):
3✔
1282
        # Parse shape
1283
        shape_arg = node.args[0]
3✔
1284
        dims = []
3✔
1285
        if isinstance(shape_arg, ast.Tuple):
3✔
1286
            dims = [self.visit(elt) for elt in shape_arg.elts]
3✔
1287
        elif isinstance(shape_arg, ast.List):
3✔
1288
            dims = [self.visit(elt) for elt in shape_arg.elts]
×
1289
        else:
1290
            val = self.visit(shape_arg)
3✔
1291
            if val.startswith("_shape_proxy_"):
3✔
1292
                array_name = val[len("_shape_proxy_") :]
×
1293
                if array_name in self.array_info:
×
1294
                    dims = self.array_info[array_name]["shapes"]
×
1295
                else:
1296
                    dims = [val]
×
1297
            else:
1298
                dims = [val]
3✔
1299

1300
        # Parse dtype
1301
        dtype_arg = None
3✔
1302
        if len(node.args) > 1:
3✔
1303
            dtype_arg = node.args[1]
×
1304

1305
        for kw in node.keywords:
3✔
1306
            if kw.arg == "dtype":
3✔
1307
                dtype_arg = kw.value
3✔
1308
                break
3✔
1309

1310
        element_type = self._map_numpy_dtype(dtype_arg)
3✔
1311

1312
        return self._create_array_temp(
3✔
1313
            dims,
1314
            element_type,
1315
            zero_init=(func_name == "zeros"),
1316
            ones_init=(func_name == "ones"),
1317
        )
1318

1319
    def _handle_numpy_empty_like(self, node, func_name):
        """Lower ``np.empty_like(prototype[, dtype])`` to an uninitialized temp array.

        The shape is the prototype's registered shape (empty if it is not a
        registered array). The element type is the explicit dtype (second
        positional argument or ``dtype=`` keyword) when given, else the
        prototype's pointee type, else Double.
        """
        prototype_name = self.visit(node.args[0])

        # Shape comes from the prototype.
        dims = (
            self.array_info[prototype_name]["shapes"]
            if prototype_name in self.array_info
            else []
        )

        # The dtype keyword takes precedence over the positional dtype.
        positional_dtype = node.args[1] if len(node.args) > 1 else None
        dtype_arg = next(
            (kw.value for kw in node.keywords if kw.arg == "dtype"),
            positional_dtype,
        )

        element_type = None
        if dtype_arg:
            element_type = self._map_numpy_dtype(dtype_arg)
        elif prototype_name in self.symbol_table:
            sym_type = self.symbol_table[prototype_name]
            if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                element_type = sym_type.pointee_type

        if element_type is None:
            element_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            dims, element_type, zero_init=False, ones_init=False
        )
1356

1357
    def _handle_numpy_zeros_like(self, node, func_name):
        """Lower ``np.zeros_like(prototype[, dtype])`` to a zero-filled temp array.

        The shape is the prototype's registered shape (empty if it is not a
        registered array). The element type is the explicit dtype (second
        positional argument or ``dtype=`` keyword) when given, else the
        prototype's pointee type, else Double.
        """
        prototype_name = self.visit(node.args[0])

        # Shape comes from the prototype.
        dims = (
            self.array_info[prototype_name]["shapes"]
            if prototype_name in self.array_info
            else []
        )

        # The dtype keyword takes precedence over the positional dtype.
        positional_dtype = node.args[1] if len(node.args) > 1 else None
        dtype_arg = next(
            (kw.value for kw in node.keywords if kw.arg == "dtype"),
            positional_dtype,
        )

        element_type = None
        if dtype_arg:
            element_type = self._map_numpy_dtype(dtype_arg)
        elif prototype_name in self.symbol_table:
            sym_type = self.symbol_table[prototype_name]
            if isinstance(sym_type, Pointer) and sym_type.has_pointee_type():
                element_type = sym_type.pointee_type

        if element_type is None:
            element_type = Scalar(PrimitiveType.Double)

        return self._create_array_temp(
            dims, element_type, zero_init=True, ones_init=False
        )
1394

1395
    def _handle_numpy_eye(self, node, func_name):
3✔
1396
        # Parse N
1397
        N_arg = node.args[0]
3✔
1398
        N_str = self.visit(N_arg)
3✔
1399

1400
        # Parse M
1401
        M_str = N_str
3✔
1402
        if len(node.args) > 1:
3✔
1403
            M_str = self.visit(node.args[1])
×
1404

1405
        # Parse k
1406
        k_str = "0"
3✔
1407
        if len(node.args) > 2:
3✔
1408
            k_str = self.visit(node.args[2])
×
1409

1410
        # Check keywords for M, k, dtype
1411
        dtype_arg = None
3✔
1412
        for kw in node.keywords:
3✔
1413
            if kw.arg == "M":
3✔
1414
                M_str = self.visit(kw.value)
3✔
1415
                if M_str == "None":
3✔
1416
                    M_str = N_str
3✔
1417
            elif kw.arg == "k":
3✔
1418
                k_str = self.visit(kw.value)
3✔
1419
            elif kw.arg == "dtype":
3✔
1420
                dtype_arg = kw.value
3✔
1421

1422
        element_type = self._map_numpy_dtype(dtype_arg)
3✔
1423

1424
        ptr_name = self._create_array_temp([N_str, M_str], element_type, zero_init=True)
3✔
1425

1426
        # Loop to set diagonal
1427
        loop_var = f"_i_{self._get_unique_id()}"
3✔
1428
        if not self.builder.has_container(loop_var):
3✔
1429
            self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
3✔
1430
            self.symbol_table[loop_var] = Scalar(PrimitiveType.Int64)
3✔
1431

1432
        self.builder.begin_for(loop_var, "0", N_str, "1")
3✔
1433

1434
        # Condition: 0 <= i + k < M
1435
        cond = f"(({loop_var} + {k_str}) >= 0) & (({loop_var} + {k_str}) < {M_str})"
3✔
1436
        self.builder.begin_if(cond)
3✔
1437

1438
        # Assignment: A[i, i+k] = 1
1439
        val = "1.0"
3✔
1440
        if element_type.primitive_type in [
3✔
1441
            PrimitiveType.Int64,
1442
            PrimitiveType.Int32,
1443
            PrimitiveType.Int8,
1444
            PrimitiveType.Int16,
1445
            PrimitiveType.UInt64,
1446
            PrimitiveType.UInt32,
1447
            PrimitiveType.UInt8,
1448
            PrimitiveType.UInt16,
1449
        ]:
1450
            val = "1"
×
1451

1452
        block_assign = self.builder.add_block()
3✔
1453
        t_const = self.builder.add_constant(block_assign, val, element_type)
3✔
1454
        t_arr = self.builder.add_access(block_assign, ptr_name)
3✔
1455
        flat_index = f"(({loop_var}) * ({M_str}) + ({loop_var}) + ({k_str}))"
3✔
1456
        subset = flat_index
3✔
1457

1458
        t_task = self.builder.add_tasklet(block_assign, "assign", ["_in"], ["_out"])
3✔
1459
        self.builder.add_memlet(
3✔
1460
            block_assign, t_const, "void", t_task, "_in", "", element_type
1461
        )
1462
        self.builder.add_memlet(block_assign, t_task, "_out", t_arr, "void", subset)
3✔
1463

1464
        self.builder.end_if()
3✔
1465
        self.builder.end_for()
3✔
1466

1467
        return ptr_name
3✔
1468

1469
    def _handle_numpy_binary_op(self, node, func_name):
3✔
1470
        args = [self.visit(arg) for arg in node.args]
3✔
1471
        if len(args) != 2:
3✔
1472
            raise NotImplementedError(
×
1473
                f"Numpy function {func_name} requires 2 arguments"
1474
            )
1475

1476
        op_map = {
3✔
1477
            "add": "add",
1478
            "subtract": "sub",
1479
            "multiply": "mul",
1480
            "divide": "div",
1481
            "power": "pow",
1482
            "minimum": "min",
1483
            "maximum": "max",
1484
        }
1485
        return self._handle_array_binary_op(op_map[func_name], args[0], args[1])
3✔
1486

1487
    def _handle_numpy_matmul_op(self, left_node, right_node):
3✔
1488
        return self._handle_matmul_helper(left_node, right_node)
3✔
1489

1490
    def _handle_numpy_matmul(self, node, func_name):
3✔
1491
        if len(node.args) != 2:
3✔
1492
            raise NotImplementedError("matmul/dot requires 2 arguments")
×
1493
        return self._handle_matmul_helper(node.args[0], node.args[1])
3✔
1494

1495
    def _handle_numpy_outer(self, node, func_name):
3✔
1496
        if len(node.args) != 2:
3✔
1497
            raise NotImplementedError("outer requires 2 arguments")
×
1498

1499
        arg0 = node.args[0]
3✔
1500
        arg1 = node.args[1]
3✔
1501

1502
        if not self.la_handler:
3✔
1503
            raise RuntimeError("LinearAlgebraHandler not initialized")
×
1504

1505
        res_a = self.la_handler.parse_arg(arg0)
3✔
1506
        res_b = self.la_handler.parse_arg(arg1)
3✔
1507

1508
        # Resolve standard names if parse_arg failed (likely complex expression)
1509
        if not res_a[0]:
3✔
1510
            left_name = self.visit(arg0)
×
1511
            arg0 = ast.Name(id=left_name)
×
1512
            res_a = self.la_handler.parse_arg(arg0)
×
1513

1514
        if not res_b[0]:
3✔
1515
            right_name = self.visit(arg1)
×
1516
            arg1 = ast.Name(id=right_name)
×
1517
            res_b = self.la_handler.parse_arg(arg1)
×
1518

1519
        name_a, subset_a, shape_a, indices_a = res_a
3✔
1520
        name_b, subset_b, shape_b, indices_b = res_b
3✔
1521

1522
        if not name_a or not name_b:
3✔
1523
            raise NotImplementedError("Could not resolve outer operands")
×
1524

1525
        def get_flattened_size_expr(name, indices, shapes):
3✔
1526
            # Simplified: if slice, we use parse_arg's returned `shapes` (which are dim sizes of the slice)
1527
            # And multiply them.
1528
            size_expr = "1"
3✔
1529
            for s in shapes:
3✔
1530
                if size_expr == "1":
3✔
1531
                    size_expr = str(s)
3✔
1532
                else:
1533
                    size_expr = f"({size_expr} * {str(s)})"
×
1534
            return size_expr
3✔
1535

1536
        m_expr = get_flattened_size_expr(name_a, indices_a, shape_a)
3✔
1537
        n_expr = get_flattened_size_expr(name_b, indices_b, shape_b)
3✔
1538

1539
        # Create temporary container
1540
        # Since outer usually promotes types or uses standard types, we default to double for now.
1541
        dtype = Scalar(PrimitiveType.Double)
3✔
1542

1543
        # Use helper to create array temp which handles symbol table and array info
1544
        tmp_name = self._create_array_temp([m_expr, n_expr], dtype)
3✔
1545

1546
        new_call_node = ast.Call(
3✔
1547
            func=node.func, args=[arg0, arg1], keywords=node.keywords
1548
        )
1549

1550
        self.la_handler.handle_outer(tmp_name, new_call_node)
3✔
1551

1552
        return tmp_name
3✔
1553

1554
    def _handle_matmul_helper(self, left_node, right_node):
3✔
1555
        if not self.la_handler:
3✔
1556
            raise RuntimeError("LinearAlgebraHandler not initialized")
×
1557

1558
        res_a = self.la_handler.parse_arg(left_node)
3✔
1559
        res_b = self.la_handler.parse_arg(right_node)
3✔
1560

1561
        if not res_a[0]:
3✔
1562
            left_name = self.visit(left_node)
×
1563
            left_node = ast.Name(id=left_name)
×
1564
            res_a = self.la_handler.parse_arg(left_node)
×
1565

1566
        if not res_b[0]:
3✔
1567
            right_name = self.visit(right_node)
×
1568
            right_node = ast.Name(id=right_name)
×
1569
            res_b = self.la_handler.parse_arg(right_node)
×
1570

1571
        name_a, subset_a, shape_a, indices_a = res_a
3✔
1572
        name_b, subset_b, shape_b, indices_b = res_b
3✔
1573

1574
        if not name_a or not name_b:
3✔
1575
            raise NotImplementedError("Could not resolve matmul operands")
×
1576

1577
        real_shape_a = shape_a
3✔
1578
        real_shape_b = shape_b
3✔
1579

1580
        ndim_a = len(real_shape_a)
3✔
1581
        ndim_b = len(real_shape_b)
3✔
1582

1583
        output_shape = []
3✔
1584
        is_scalar = False
3✔
1585

1586
        if ndim_a == 1 and ndim_b == 1:
3✔
1587
            is_scalar = True
3✔
1588
            output_shape = []
3✔
1589
        elif ndim_a == 2 and ndim_b == 2:
3✔
1590
            output_shape = [real_shape_a[0], real_shape_b[1]]
3✔
1591
        elif ndim_a == 2 and ndim_b == 1:
3✔
1592
            output_shape = [real_shape_a[0]]
3✔
1593
        elif ndim_a == 1 and ndim_b == 2:
3✔
1594
            output_shape = [real_shape_b[1]]
×
1595
        elif ndim_a > 2 or ndim_b > 2:
3✔
1596
            if ndim_a == ndim_b:
3✔
1597
                output_shape = list(real_shape_a[:-2]) + [
3✔
1598
                    real_shape_a[-2],
1599
                    real_shape_b[-1],
1600
                ]
1601
            else:
1602
                raise NotImplementedError(
×
1603
                    "Broadcasting with different ranks not fully supported yet"
1604
                )
1605
        else:
1606
            raise NotImplementedError(
×
1607
                f"Matmul with ranks {ndim_a} and {ndim_b} not supported"
1608
            )
1609

1610
        dtype = Scalar(PrimitiveType.Double)
3✔
1611

1612
        if is_scalar:
3✔
1613
            tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1614
            self.builder.add_container(tmp_name, dtype, False)
3✔
1615
            self.symbol_table[tmp_name] = dtype
3✔
1616
        else:
1617
            tmp_name = self._create_array_temp(output_shape, dtype)
3✔
1618

1619
        if ndim_a > 2 or ndim_b > 2:
3✔
1620
            # Generate loops for broadcasting
1621
            batch_dims = ndim_a - 2
3✔
1622
            loop_vars = []
3✔
1623

1624
            for i in range(batch_dims):
3✔
1625
                loop_var = f"_i{self._get_unique_id()}"
3✔
1626
                self.builder.add_container(loop_var, Scalar(PrimitiveType.Int64), False)
3✔
1627
                loop_vars.append(loop_var)
3✔
1628
                dim_size = real_shape_a[i]
3✔
1629
                self.builder.begin_for(loop_var, "0", str(dim_size), "1")
3✔
1630

1631
            def make_slice(name, indices):
3✔
1632
                elts = []
3✔
1633
                for idx in indices:
3✔
1634
                    if idx == ":":
3✔
1635
                        elts.append(ast.Slice())
3✔
1636
                    else:
1637
                        elts.append(ast.Name(id=idx))
3✔
1638

1639
                return ast.Subscript(
3✔
1640
                    value=ast.Name(id=name), slice=ast.Tuple(elts=elts), ctx=ast.Load()
1641
                )
1642

1643
            indices = loop_vars + [":", ":"]
3✔
1644
            slice_a = make_slice(name_a, indices)
3✔
1645
            slice_b = make_slice(name_b, indices)
3✔
1646
            slice_c = make_slice(tmp_name, indices)
3✔
1647

1648
            self.la_handler.handle_gemm(
3✔
1649
                slice_c, ast.BinOp(left=slice_a, op=ast.MatMult(), right=slice_b)
1650
            )
1651

1652
            for _ in range(batch_dims):
3✔
1653
                self.builder.end_for()
3✔
1654
        else:
1655
            if is_scalar:
3✔
1656
                self.la_handler.handle_dot(
3✔
1657
                    tmp_name,
1658
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
1659
                )
1660
            else:
1661
                self.la_handler.handle_gemm(
3✔
1662
                    tmp_name,
1663
                    ast.BinOp(left=left_node, op=ast.MatMult(), right=right_node),
1664
                )
1665

1666
        return tmp_name
3✔
1667

1668
    def _handle_numpy_unary_op(self, node, func_name):
3✔
1669
        args = [self.visit(arg) for arg in node.args]
3✔
1670
        if len(args) != 1:
3✔
1671
            raise NotImplementedError(f"Numpy function {func_name} requires 1 argument")
×
1672

1673
        op_name = func_name
3✔
1674
        if op_name == "absolute":
3✔
1675
            op_name = "abs"
×
1676

1677
        return self._handle_array_unary_op(op_name, args[0])
3✔
1678

1679
    def _handle_numpy_reduce(self, node, func_name):
3✔
1680
        args = node.args
3✔
1681
        keywords = {kw.arg: kw.value for kw in node.keywords}
3✔
1682

1683
        array_node = args[0]
3✔
1684
        array_name = self.visit(array_node)
3✔
1685

1686
        if array_name not in self.array_info:
3✔
1687
            raise ValueError(f"Reduction input must be an array, got {array_name}")
×
1688

1689
        input_shape = self.array_info[array_name]["shapes"]
3✔
1690
        ndim = len(input_shape)
3✔
1691

1692
        axis = None
3✔
1693
        if len(args) > 1:
3✔
1694
            axis = args[1]
×
1695
        elif "axis" in keywords:
3✔
1696
            axis = keywords["axis"]
3✔
1697

1698
        keepdims = False
3✔
1699
        if "keepdims" in keywords:
3✔
1700
            keepdims_node = keywords["keepdims"]
3✔
1701
            if isinstance(keepdims_node, ast.Constant):
3✔
1702
                keepdims = bool(keepdims_node.value)
3✔
1703

1704
        axes = []
3✔
1705
        if axis is None:
3✔
1706
            axes = list(range(ndim))
3✔
1707
        elif isinstance(axis, ast.Constant):  # Single axis
3✔
1708
            val = axis.value
3✔
1709
            if val < 0:
3✔
1710
                val += ndim
×
1711
            axes = [val]
3✔
1712
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
1713
            for elt in axis.elts:
×
1714
                if isinstance(elt, ast.Constant):
×
1715
                    val = elt.value
×
1716
                    if val < 0:
×
1717
                        val += ndim
×
1718
                    axes.append(val)
×
1719
        elif (
×
1720
            isinstance(axis, ast.UnaryOp)
1721
            and isinstance(axis.op, ast.USub)
1722
            and isinstance(axis.operand, ast.Constant)
1723
        ):
1724
            val = -axis.operand.value
×
1725
            if val < 0:
×
1726
                val += ndim
×
1727
            axes = [val]
×
1728
        else:
1729
            # Try to evaluate simple expression
1730
            try:
×
1731
                val = int(self.visit(axis))
×
1732
                if val < 0:
×
1733
                    val += ndim
×
1734
                axes = [val]
×
1735
            except:
×
1736
                raise NotImplementedError("Dynamic axis not supported")
×
1737

1738
        # Calculate output shape
1739
        output_shape = []
3✔
1740
        for i in range(ndim):
3✔
1741
            if i in axes:
3✔
1742
                if keepdims:
3✔
1743
                    output_shape.append("1")
3✔
1744
            else:
1745
                output_shape.append(input_shape[i])
3✔
1746

1747
        dtype = self._get_dtype(array_name)
3✔
1748

1749
        if not output_shape:
3✔
1750
            tmp_name = f"_tmp_{self._get_unique_id()}"
3✔
1751
            self.builder.add_container(tmp_name, dtype, False)
3✔
1752
            self.symbol_table[tmp_name] = dtype
3✔
1753
            self.array_info[tmp_name] = {"ndim": 0, "shapes": []}
3✔
1754
        else:
1755
            tmp_name = self._create_array_temp(output_shape, dtype)
3✔
1756

1757
        self.builder.add_reduce_op(
3✔
1758
            func_name, array_name, tmp_name, input_shape, axes, keepdims
1759
        )
1760

1761
        return tmp_name
3✔
1762

1763
    def _handle_numpy_astype(self, node, array_name):
3✔
1764
        """Handle numpy array.astype(dtype) method calls."""
1765
        if len(node.args) < 1:
3✔
1766
            raise ValueError("astype requires at least one argument (dtype)")
×
1767

1768
        dtype_arg = node.args[0]
3✔
1769
        target_dtype = self._map_numpy_dtype(dtype_arg)
3✔
1770

1771
        # Get input array shape
1772
        if array_name not in self.array_info:
3✔
1773
            raise ValueError(f"Array {array_name} not found in array_info")
×
1774

1775
        input_shape = self.array_info[array_name]["shapes"]
3✔
1776

1777
        # Create output array with target dtype
1778
        tmp_name = self._create_array_temp(input_shape, target_dtype)
3✔
1779

1780
        # Add cast operation
1781
        self.builder.add_cast_op(
3✔
1782
            array_name, tmp_name, input_shape, target_dtype.primitive_type
1783
        )
1784

1785
        return tmp_name
3✔
1786

1787
    def _handle_scipy_softmax(self, node, func_name):
3✔
1788
        args = node.args
3✔
1789
        keywords = {kw.arg: kw.value for kw in node.keywords}
3✔
1790

1791
        array_node = args[0]
3✔
1792
        array_name = self.visit(array_node)
3✔
1793

1794
        if array_name not in self.array_info:
3✔
1795
            raise ValueError(f"Softmax input must be an array, got {array_name}")
×
1796

1797
        input_shape = self.array_info[array_name]["shapes"]
3✔
1798
        ndim = len(input_shape)
3✔
1799

1800
        axis = None
3✔
1801
        if len(args) > 1:
3✔
1802
            axis = args[1]
×
1803
        elif "axis" in keywords:
3✔
1804
            axis = keywords["axis"]
3✔
1805

1806
        axes = []
3✔
1807
        if axis is None:
3✔
1808
            axes = list(range(ndim))
3✔
1809
        elif isinstance(axis, ast.Constant):  # Single axis
3✔
1810
            val = axis.value
3✔
1811
            if val < 0:
3✔
1812
                val += ndim
×
1813
            axes = [val]
3✔
1814
        elif isinstance(axis, ast.Tuple):  # Multiple axes
×
1815
            for elt in axis.elts:
×
1816
                if isinstance(elt, ast.Constant):
×
1817
                    val = elt.value
×
1818
                    if val < 0:
×
1819
                        val += ndim
×
1820
                    axes.append(val)
×
1821
        elif (
×
1822
            isinstance(axis, ast.UnaryOp)
1823
            and isinstance(axis.op, ast.USub)
1824
            and isinstance(axis.operand, ast.Constant)
1825
        ):
1826
            val = -axis.operand.value
×
1827
            if val < 0:
×
1828
                val += ndim
×
1829
            axes = [val]
×
1830
        else:
1831
            # Try to evaluate simple expression
1832
            try:
×
1833
                val = int(self.visit(axis))
×
1834
                if val < 0:
×
1835
                    val += ndim
×
1836
                axes = [val]
×
1837
            except:
×
1838
                raise NotImplementedError("Dynamic axis not supported")
×
1839

1840
        # Create output array
1841
        # Assume double for now, or infer from input
1842
        dtype = Scalar(PrimitiveType.Double)  # TODO: infer
3✔
1843

1844
        tmp_name = self._create_array_temp(input_shape, dtype)
3✔
1845

1846
        self.builder.add_reduce_op(
3✔
1847
            func_name, array_name, tmp_name, input_shape, axes, False
1848
        )
1849

1850
        return tmp_name
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc