• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.93
/python/docc/python/functions/python.py
1
from docc.sdfg import (
4✔
2
    Scalar,
3
    PrimitiveType,
4
    Pointer,
5
    TaskletCode,
6
    CMathFunction,
7
)
8

9

10
class PythonHandler:
    """Lower calls to Python built-ins (``min``, ``max``, ``int``, ``float``,
    ``bool``) into operations on the expression visitor's SDFG builder.

    The handler holds no state of its own beyond its dispatch table: it
    borrows ``builder``, ``symbol_table``, ``visit``, ``_add_read`` and
    ``_is_int`` from the expression visitor passed to the constructor.
    """

    def __init__(self, expression_visitor):
        self._ev = expression_visitor
        # Built-in name -> bound lowering method, invoked as
        # ``handler(node, func_name)``.
        self.function_handlers = {
            "min": self._handle_min_max,
            "max": self._handle_min_max,
            "int": self._handle_python_cast,
            "float": self._handle_python_cast,
            "bool": self._handle_python_cast,
        }

    def has_handler(self, func_name):
        """Return True if ``func_name`` is a built-in this handler lowers."""
        return func_name in self.function_handlers

    def handle_python_call(self, node, func_name):
        """Lower a call to a supported built-in.

        Args:
            node: The call node; the selected handler visits ``node.args``.
            func_name: Name of the called built-in.

        Returns:
            The selected handler's result (for the handlers in this class,
            the name of a new temporary container holding the result).

        Raises:
            NotImplementedError: If no handler is registered for
                ``func_name``, or the handler rejects the call's arity.
        """
        handler = self.function_handlers.get(func_name)
        if handler is None:
            raise NotImplementedError(f"Python function {func_name} not supported")
        return handler(node, func_name)

    # Expose parent properties for convenience
    @property
    def builder(self):
        """The builder owned by the wrapped expression visitor."""
        return self._ev.builder

    @property
    def symbol_table(self):
        """The wrapped expression visitor's name -> type mapping."""
        return self._ev.symbol_table

    def _add_read(self, block, expr_str, debug_info=None):
        """Delegate to the expression visitor's ``_add_read``."""
        return self._ev._add_read(block, expr_str, debug_info)

    def _is_int(self, operand):
        """Delegate to the expression visitor's ``_is_int``."""
        return self._ev._is_int(operand)

    def visit(self, node):
        """Visit ``node`` with the wrapped expression visitor."""
        return self._ev.visit(node)

    def _infer_primitive_type(self, operand):
        """Best-effort ``PrimitiveType`` of a visited operand string.

        Names found in the symbol table report their primitive type (pointers
        are unwrapped to their base type). Anything else is classified as a
        literal: integer -> Int64, ``true``/``false`` -> Bool, and everything
        remaining is assumed to be a floating-point constant (Double).
        """
        name = operand
        if "(" in operand and operand.endswith(")"):
            # Operands shaped like "name(...)" are looked up by "name";
            # presumably subscripted container accesses — confirm.
            name = operand.split("(")[0]

        if name in self.symbol_table:
            dtype = self.symbol_table[name]
            if isinstance(dtype, Pointer):
                dtype = dtype.base_type
            return dtype.primitive_type
        if self._is_int(operand):
            return PrimitiveType.Int64
        if operand in ("true", "false"):
            return PrimitiveType.Bool
        return PrimitiveType.Double

    def _cast_to_double(self, operand):
        """Assign ``operand`` into a fresh Double temporary; return its name."""
        tmp_cast = self.builder.find_new_name("_cast_")
        self.builder.add_container(tmp_cast, Scalar(PrimitiveType.Double), False)
        self.symbol_table[tmp_cast] = Scalar(PrimitiveType.Double)
        # Assign int to double (implicit cast)
        self.builder.add_assignment(tmp_cast, operand)
        return tmp_cast

    def _handle_min_max(self, node, func_name):
        """Lower a two-argument ``min(a, b)`` / ``max(a, b)`` call.

        If any operand is classified as Double, the result is a Double
        temporary computed with ``CMathFunction.fmin``/``fmax``. In that case
        non-Double operands are first assigned into Double temporaries.
        Otherwise (integer or boolean operands) the result is an Int64
        temporary computed with the ``int_smin``/``int_smax`` tasklet.

        Returns:
            Name of the new temporary container holding the result.

        Raises:
            NotImplementedError: If the call does not have exactly two
                arguments.
        """
        args = [self.visit(arg) for arg in node.args]
        if len(args) != 2:
            raise NotImplementedError(f"{func_name} only supported with 2 arguments")

        # Like Python, only promote to floating point when an operand is a
        # float. Booleans stay on the integer path (bool is an int subtype).
        is_double = [
            self._infer_primitive_type(arg) == PrimitiveType.Double for arg in args
        ]
        is_float = any(is_double)
        dtype = Scalar(PrimitiveType.Double if is_float else PrimitiveType.Int64)

        tmp_name = self.builder.find_new_name("_tmp_")
        self.builder.add_container(tmp_name, dtype, False)
        self.symbol_table[tmp_name] = dtype

        operands = args
        if is_float:
            # fmin/fmax need Double inputs; widen the other operands first.
            # These assignments are emitted before the min/max block below.
            operands = []
            for arg, double in zip(args, is_double):
                operands.append(arg if double else self._cast_to_double(arg))

        block = self.builder.add_block()
        t_out = self.builder.add_access(block, tmp_name)
        if is_float:
            intrinsic = CMathFunction.fmax if func_name == "max" else CMathFunction.fmin
            t_task = self.builder.add_cmath(block, intrinsic)
        else:
            opcode = TaskletCode.int_smax if func_name == "max" else TaskletCode.int_smin
            t_task = self.builder.add_tasklet(block, opcode, ["_in1", "_in2"], ["_out"])

        for i, operand in enumerate(operands, start=1):
            t_arg, arg_sub = self._add_read(block, operand)
            self.builder.add_memlet(block, t_arg, "void", t_task, f"_in{i}", arg_sub)

        self.builder.add_memlet(block, t_task, "_out", t_out, "void", "")
        return tmp_name

    def _handle_python_cast(self, node, func_name):
        """Lower a Python type cast: ``int(x)``, ``float(x)`` or ``bool(x)``.

        The visited operand is copied into a new temporary of the target type
        (Int64, Double or Bool) by a single ``TaskletCode.assign`` tasklet.
        That assignment is relied on to perform the conversion, so the
        operand's source type is not inspected.

        Returns:
            Name of the new temporary container holding the result.

        Raises:
            NotImplementedError: If the call does not have exactly one
                argument, or ``func_name`` is not ``int``/``float``/``bool``.
        """
        if len(node.args) != 1:
            raise NotImplementedError(f"{func_name}() cast requires exactly 1 argument")

        arg = self.visit(node.args[0])

        target_types = {
            "int": PrimitiveType.Int64,
            "float": PrimitiveType.Double,
            "bool": PrimitiveType.Bool,
        }
        if func_name not in target_types:
            raise NotImplementedError(f"Cast to {func_name} not supported")
        target_dtype = Scalar(target_types[func_name])

        # Create temporary variable for result
        tmp_name = self.builder.find_new_name("_tmp_")
        self.builder.add_container(tmp_name, target_dtype, False)
        self.symbol_table[tmp_name] = target_dtype

        block = self.builder.add_block()
        t_src, src_sub = self._add_read(block, arg)
        t_dst = self.builder.add_access(block, tmp_name)
        t_task = self.builder.add_tasklet(block, TaskletCode.assign, ["_in"], ["_out"])
        self.builder.add_memlet(block, t_src, "void", t_task, "_in", src_sub)
        self.builder.add_memlet(block, t_task, "_out", t_dst, "void", "")

        return tmp_name
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc