• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

33.72
/python/docc/python/functions/scipy.py
1
import ast
4✔
2
from docc.sdfg import Scalar, PrimitiveType, Pointer
4✔
3
from docc.python.ast_utils import get_debug_info
4✔
4

5

6
class SciPyHandler:
    """Handler for SciPy functions (scipy.special, scipy.signal, etc.).

    The handler keeps no parser state of its own besides the dispatch table:
    array metadata, the builder and the symbol table are all read from the
    expression visitor handed to the constructor.
    """

    # Byte size assumed for one PrimitiveType.Double element when sizing the
    # malloc of a convolution result.
    _DOUBLE_SIZE_BYTES = 8

    # Values of the correlate2d ``mode`` argument that can be lowered.
    _SUPPORTED_CONV_MODES = ("valid", "full", "same")

    def __init__(self, expression_visitor):
        """Bind the handler to the expression visitor that owns parser state.

        Args:
            expression_visitor: Object providing ``array_info``, ``builder``,
                ``symbol_table``, ``visit``, ``_get_unique_id`` and
                ``numpy_visitor``; it is only stored here, not accessed.
        """
        self._ev = expression_visitor
        # Nested structure: submodule -> {func_name -> handler}
        self.function_handlers = {
            "special": {
                "softmax": self._handle_softmax,
            },
            "signal": {
                "correlate2d": self._handle_correlate2d_expr,
            },
        }

    def has_handler(self, submodule, func_name):
        """Check if this handler can handle the given submodule.func_name."""
        return (
            submodule in self.function_handlers
            and func_name in self.function_handlers[submodule]
        )

    def handle_scipy_call(self, node, submodule, func_name):
        """Dispatch a call to ``scipy.<submodule>.<func_name>``.

        Args:
            node: The ``ast.Call`` node of the SciPy call.
            submodule: SciPy submodule name, e.g. ``"special"``.
            func_name: Function name inside the submodule.

        Returns:
            Whatever the registered handler returns (the registered handlers
            return the name of the container holding the result).

        Raises:
            NotImplementedError: If no handler is registered for the function.
        """
        if not self.has_handler(submodule, func_name):
            raise NotImplementedError(
                f"SciPy function scipy.{submodule}.{func_name} not supported"
            )
        return self.function_handlers[submodule][func_name](node, func_name)

    # Expose parent properties for convenience
    @property
    def array_info(self):
        """Array metadata mapping of the expression visitor."""
        return self._ev.array_info

    @property
    def builder(self):
        """Builder object of the expression visitor."""
        return self._ev.builder

    @property
    def symbol_table(self):
        """Symbol table mapping of the expression visitor."""
        return self._ev.symbol_table

    def _get_unique_id(self):
        """Forward to the expression visitor's ``_get_unique_id``."""
        return self._ev._get_unique_id()

    def visit(self, node):
        """Visit ``node`` with the expression visitor and return its result."""
        return self._ev.visit(node)

    def _create_array_temp(self, shape, dtype):
        """Create a temporary array with the given shape and dtype."""
        return self._ev.numpy_visitor._create_array_temp(shape, dtype)

    # ========== scipy.special Functions ==========

    @staticmethod
    def _static_int(node):
        """Return the integer literal encoded by ``node``, or None.

        Recognizes a plain integer constant and a negated integer constant.
        The latter matters because the parser represents ``-1`` as
        ``UnaryOp(USub, Constant(1))`` rather than as a negative constant.
        """
        sign = 1
        if (
            isinstance(node, ast.UnaryOp)
            and isinstance(node.op, ast.USub)
            and isinstance(node.operand, ast.Constant)
        ):
            sign = -1
            node = node.operand
        if not isinstance(node, ast.Constant):
            return None
        value = node.value
        # bool is a subclass of int but is not a meaningful axis literal.
        if isinstance(value, bool) or not isinstance(value, int):
            return None
        return sign * value

    def _resolve_axis(self, node, ndim):
        """Turn one axis expression into a non-negative axis index.

        Args:
            node: AST node of a single axis value.
            ndim: Number of dimensions of the input array.

        Returns:
            The axis as an int in ``range(ndim)``.

        Raises:
            NotImplementedError: If the axis is not a literal and visiting it
                does not produce something convertible to ``int``.
            ValueError: If the axis is out of bounds for ``ndim``.
        """
        val = self._static_int(node)
        if val is None:
            # Best effort for non-literal axes: whatever the visitor returns
            # must be convertible to int, otherwise the axis is dynamic.
            try:
                val = int(self.visit(node))
            except Exception as err:
                raise NotImplementedError("Dynamic axis not supported") from err

        normalized = val + ndim if val < 0 else val
        if not 0 <= normalized < ndim:
            raise ValueError(
                f"Softmax axis {val} is out of bounds for array of dimension {ndim}"
            )
        return normalized

    def _resolve_axes(self, axis, ndim):
        """Turn the softmax ``axis`` argument into a list of axis indices.

        Args:
            axis: AST node of the ``axis`` argument, or None when the call
                does not pass one.
            ndim: Number of dimensions of the input array.

        Returns:
            List of non-negative axis indices. A missing axis or a literal
            ``None`` selects every axis.
        """
        explicit_none = isinstance(axis, ast.Constant) and axis.value is None
        if axis is None or explicit_none:
            return list(range(ndim))
        if isinstance(axis, ast.Tuple):
            return [self._resolve_axis(elt, ndim) for elt in axis.elts]
        return [self._resolve_axis(axis, ndim)]

    def _handle_softmax(self, node, func_name):
        """Handle scipy.special.softmax.

        Args:
            node: The ``ast.Call`` node. The input array is the first
                positional argument (or the ``x`` keyword); the axis is the
                second positional argument or the ``axis`` keyword.
            func_name: Name forwarded to the builder's reduce operation.

        Returns:
            Name of the temporary array that receives the result.

        Raises:
            ValueError: If no input is given, the input is not a known array,
                or an axis is out of bounds.
            NotImplementedError: If the axis cannot be determined statically.
        """
        args = node.args
        keywords = {kw.arg: kw.value for kw in node.keywords}

        if args:
            array_node = args[0]
        elif "x" in keywords:
            array_node = keywords["x"]
        else:
            raise ValueError("Softmax requires an input array argument")

        array_name = self.visit(array_node)
        if array_name not in self.array_info:
            raise ValueError(f"Softmax input must be an array, got {array_name}")

        input_shape = self.array_info[array_name]["shapes"]

        # A positional axis takes precedence over the keyword form.
        axis = args[1] if len(args) > 1 else keywords.get("axis")
        axes = self._resolve_axes(axis, len(input_shape))

        # The result is always created as a Double array with the input shape.
        dtype = Scalar(PrimitiveType.Double)
        tmp_name = self._create_array_temp(input_shape, dtype)

        self.builder.add_reduce_op(
            func_name, array_name, tmp_name, input_shape, axes, False
        )

        return tmp_name

    # ========== scipy.signal Functions ==========

    def is_correlate2d(self, node):
        """Check if a node represents a scipy.signal.correlate2d call.

        Only the final name is inspected: any call whose attribute name or
        bare function name is ``correlate2d`` matches.
        """
        if not isinstance(node, ast.Call):
            return False

        func = node.func
        if isinstance(func, ast.Attribute):
            return func.attr == "correlate2d"
        if isinstance(func, ast.Name):
            return func.id == "correlate2d"
        return False

    def _parse_conv_mode(self, value_node):
        """Extract the ``mode`` argument of a correlate2d call.

        The mode is read from the ``mode`` keyword or from the third
        positional argument and defaults to ``"full"``.

        Raises:
            NotImplementedError: If the mode is not a literal or is not one
                of the supported modes.
        """
        mode_node = None
        for keyword in value_node.keywords:
            if keyword.arg == "mode":
                mode_node = keyword.value

        # Also check positional args for mode
        if len(value_node.args) > 2:
            mode_node = value_node.args[2]

        if mode_node is None:
            return "full"
        if not isinstance(mode_node, ast.Constant):
            # Falling back to the default here would silently lower the call
            # with a mode the user did not ask for.
            raise NotImplementedError("Dynamic convolution mode not supported")

        mode = mode_node.value
        if mode not in self._SUPPORTED_CONV_MODES:
            raise NotImplementedError(f"Unsupported convolution mode: {mode}")
        return mode

    @staticmethod
    def _conv_pads(mode, kernel_shape_strs):
        """Return the four padding expressions handed to the builder."""
        h_k, w_k = kernel_shape_strs[0], kernel_shape_strs[1]
        if mode == "full":
            # Padding is kernel_size - 1 on both sides
            pad_h = f"({h_k} - 1)"
            pad_w = f"({w_k} - 1)"
        elif mode == "same":
            # Padding is kernel_size // 2
            # NOTE(review): with an even kernel size this symmetric padding
            # does not appear to match the [H1, W1] output shape declared for
            # "same" mode -- confirm against the builder's conv semantics.
            pad_h = f"idiv({h_k}, 2)"
            pad_w = f"idiv({w_k}, 2)"
        else:
            return ["0", "0", "0", "0"]
        return [pad_h, pad_w, pad_h, pad_w]

    @staticmethod
    def _conv_output_shape(mode, in1_shape, in2_shape):
        """Return the symbolic 2D output shape for the given mode."""
        H1, W1 = str(in1_shape[0]), str(in1_shape[1])
        H2, W2 = str(in2_shape[0]), str(in2_shape[1])

        if mode == "valid":
            return [f"({H1} - {H2} + 1)", f"({W1} - {W2} + 1)"]
        if mode == "same":
            return [H1, W1]
        if mode == "full":
            return [f"({H1} + {H2} - 1)", f"({W1} + {W2} - 1)"]
        return []

    def _allocate_conv_output(self, target_name, out_shape):
        """Declare ``target_name`` as a Double pointer and malloc its storage.

        Registers the container with the builder, records it in the symbol
        table and array info, and adds a block that mallocs the buffer.
        """
        # Use Double type (float)
        dtype = Scalar(PrimitiveType.Double)
        ptr_type = Pointer(dtype)

        self.builder.add_container(target_name, ptr_type, False)

        # Update parser state
        self.symbol_table[target_name] = ptr_type
        self.array_info[target_name] = {"ndim": 2, "shapes": out_shape}

        # Allocate memory for the result
        block_alloc = self.builder.add_block()

        # Calculate size: shape[0] * shape[1] * sizeof(double)
        size_expr = f"(({out_shape[0]}) * ({out_shape[1]}))"
        total_size_expr = f"({size_expr} * {self._DOUBLE_SIZE_BYTES})"

        t_malloc = self.builder.add_malloc(block_alloc, total_size_expr)
        t_ptr = self.builder.add_access(block_alloc, target_name)
        self.builder.add_memlet(
            block_alloc, t_malloc, "_ret", t_ptr, "void", "", ptr_type
        )

    def handle_correlate2d(self, target, value_node):
        """Handle scipy.signal.correlate2d (2D correlation/convolution).

        The call is lowered to a single builder ``add_conv`` with batch size,
        input channels and output channels all set to 1. The ``boundary`` and
        ``fillvalue`` arguments of the call are not inspected.

        Args:
            target: The assignment target (ast.Name or string)
            value_node: The correlate2d call node

        Returns:
            True if handled successfully, False otherwise (not a correlate2d
            call, fewer than two positional arguments, an input that is not a
            known array, or an unusable target).

        Raises:
            NotImplementedError: If an input is not 2D, or the mode is not a
                literal or not one of "valid", "full", "same".
        """
        if not self.is_correlate2d(value_node):
            return False

        args = value_node.args
        if len(args) < 2:
            return False

        in1_name = self.visit(args[0])
        in2_name = self.visit(args[1])

        if in1_name not in self.array_info:
            return False
        if in2_name not in self.array_info:
            return False

        in1_info = self.array_info[in1_name]
        in2_info = self.array_info[in2_name]

        # Check dimensions
        if in1_info["ndim"] != 2 or in2_info["ndim"] != 2:
            raise NotImplementedError(
                "Only 2D convolution is currently supported via scipy.signal mapping"
            )

        in1_shape = in1_info["shapes"]
        in2_shape = in2_info["shapes"]

        # Default mode is 'full'
        mode = self._parse_conv_mode(value_node)

        # Map to ConvNode
        # Treat as N=1, C_in=1, C_out=1
        shape_strs = ["1", "1"] + [str(s) for s in in1_shape]
        kernel_shape_strs = [str(s) for s in in2_shape]

        # Default strides 1
        strides = ["1", "1"]
        dilations = ["1", "1"]
        group = "1"
        output_channels = "1"

        pads = self._conv_pads(mode, kernel_shape_strs)

        if isinstance(target, ast.Name):
            target_name = target.id
        elif isinstance(target, str):
            target_name = target
        else:
            target_name = ""

        if not target_name:
            return False

        # An already existing target is used as is; its shape is not checked.
        if not self.builder.exists(target_name):
            out_shape = self._conv_output_shape(mode, in1_shape, in2_shape)
            self._allocate_conv_output(target_name, out_shape)

        debug_info = get_debug_info(
            value_node, getattr(self.builder, "filename", ""), ""
        )

        self.builder.add_conv(
            in1_name,
            in2_name,
            target_name,
            shape_strs,
            kernel_shape_strs,
            strides,
            pads,
            dilations,
            output_channels,
            group,
            debug_info,
        )
        return True

    def _handle_correlate2d_expr(self, node, func_name):
        """Handle scipy.signal.correlate2d as an expression (creates temp array).

        This wrapper is used when correlate2d appears in an expression context
        rather than a direct assignment.

        Returns:
            Name of the temporary container that receives the result.

        Raises:
            NotImplementedError: If the call could not be handled.
        """
        # Create a temporary name for the result
        tmp_name = self.builder.find_new_name("_corr2d_")
        # Delegate to the main handler
        if not self.handle_correlate2d(tmp_name, node):
            raise NotImplementedError("Failed to handle correlate2d expression")
        return tmp_name
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc