• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21113623600

18 Jan 2026 02:50PM UTC coverage: 64.425% (+0.3%) from 64.154%
21113623600

Pull #462

github

web-flow
Merge d503e5691 into 92e9cbdc3
Pull Request #462: adds syntax support for multi-assignments and np.empty_like

221 of 258 new or added lines in 5 files covered. (85.66%)

21 existing lines in 4 files now uncovered.

19678 of 30544 relevant lines covered (64.43%)

385.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.54
/python/docc/compiled_sdfg.py
1
import ctypes
3✔
2
from ._sdfg import Scalar, Array, Pointer, Structure, PrimitiveType
3✔
3

4
try:
3✔
5
    import numpy as np
3✔
6
except ImportError:
×
7
    np = None
×
8

9
# Lookup table from sdfg PrimitiveType members to the ctypes scalar type used
# for that primitive. _get_ctypes_type reads it with
# `.get(..., ctypes.c_void_p)`, so primitives missing from this table are
# treated as `void *`.
_CTYPES_MAP = {
3✔
10
    PrimitiveType.Bool: ctypes.c_bool,
11
    PrimitiveType.Int8: ctypes.c_int8,
12
    PrimitiveType.Int16: ctypes.c_int16,
13
    PrimitiveType.Int32: ctypes.c_int32,
14
    PrimitiveType.Int64: ctypes.c_int64,
15
    PrimitiveType.UInt8: ctypes.c_uint8,
16
    PrimitiveType.UInt16: ctypes.c_uint16,
17
    PrimitiveType.UInt32: ctypes.c_uint32,
18
    PrimitiveType.UInt64: ctypes.c_uint64,
19
    PrimitiveType.Float: ctypes.c_float,
20
    PrimitiveType.Double: ctypes.c_double,
21
}
22

23

24
class CompiledSDFG:
    """Callable wrapper around a compiled SDFG entry point in a shared library.

    The constructor loads ``lib_path`` with ``ctypes.CDLL``, looks up the
    symbol named ``sdfg.name`` and derives its ``argtypes``/``restype`` from
    the SDFG argument and return types.  Calling the instance converts the
    user arguments to ctypes values, fills in implicit ``_s<N>`` shape
    symbols, allocates buffers for output arguments, invokes the native
    function and returns the contents of the output buffers.
    """

    def __init__(
        self,
        lib_path,
        sdfg,
        shape_sources=None,
        structure_member_info=None,
        output_args=None,
        output_shapes=None,
    ):
        """Load the library and prepare the ctypes signature.

        Args:
            lib_path: Path of the shared library, passed to ``ctypes.CDLL``.
            sdfg: SDFG object; ``name``, ``arguments``, ``type(name)`` and
                ``return_type`` are read here, ``metadata(key)`` if present.
            shape_sources: Sequence of ``(user_arg_index, dim_index)`` pairs;
                entry ``i`` supplies the value of the shape symbol ``_s<i>``.
            structure_member_info: ``{struct_name: {member_name: (index,
                type)}}`` describing structure layouts.
            output_args: Names of arguments that are outputs.  When empty and
                ``sdfg`` has ``metadata``, the comma-separated
                ``"output_args"`` metadata entry is used instead.
            output_shapes: ``{output_name: iterable of dimension
                expressions}`` for outputs that are arrays.
        """
        self.lib_path = lib_path
        self.sdfg = sdfg
        self.shape_sources = shape_sources or []
        self.structure_member_info = structure_member_info or {}
        self.lib = ctypes.CDLL(lib_path)
        self.func = getattr(self.lib, sdfg.name)

        # Check for output args
        self.output_args = output_args or []
        if not self.output_args and hasattr(sdfg, "metadata"):
            out_args_str = sdfg.metadata("output_args")
            if out_args_str:
                self.output_args = out_args_str.split(",")

        self.output_shapes = output_shapes or {}

        # Cache for ctypes structure definitions
        self._ctypes_structures = {}

        # Set up argument types
        self.arg_names = sdfg.arguments
        self.arg_types = []
        self.arg_sdfg_types = []  # Keep track of original sdfg types
        for arg_name in sdfg.arguments:
            arg_type = sdfg.type(arg_name)
            self.arg_sdfg_types.append(arg_type)
            self.arg_types.append(self._get_ctypes_type(arg_type))

        self.func.argtypes = self.arg_types

        # Set up return type
        self.func.restype = self._get_ctypes_type(sdfg.return_type)

    def _create_ctypes_structure(self, struct_name):
        """Return the ctypes Structure class for ``struct_name``.

        The class is built once from ``structure_member_info`` (members
        ordered by their index) and cached in ``_ctypes_structures``.

        Raises:
            ValueError: If ``struct_name`` has no entry in
                ``structure_member_info``.
        """
        if struct_name in self._ctypes_structures:
            return self._ctypes_structures[struct_name]

        if struct_name not in self.structure_member_info:
            raise ValueError(f"Structure '{struct_name}' not found in member info")

        # Get member info: {member_name: (index, type)}
        members = self.structure_member_info[struct_name]
        # Sort by index to get correct field order
        sorted_members = sorted(members.items(), key=lambda x: x[1][0])

        # Build _fields_ for ctypes.Structure
        fields = [
            (member_name, self._get_ctypes_type(member_type))
            for member_name, (_index, member_type) in sorted_members
        ]

        # Create the ctypes Structure class dynamically
        class CStructure(ctypes.Structure):
            _fields_ = fields

        self._ctypes_structures[struct_name] = CStructure
        return CStructure

    def _get_ctypes_type(self, sdfg_type):
        """Map an sdfg type object to the ctypes type used in the signature.

        Scalars map through ``_CTYPES_MAP``; arrays and pointers to scalars
        map to a pointer to the element type; pointers to structures map to a
        pointer to the generated structure class.  Everything else, including
        unknown primitives, falls back to ``ctypes.c_void_p``.
        """
        if isinstance(sdfg_type, Scalar):
            return _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
        elif isinstance(sdfg_type, Array):
            # Arrays are passed as pointers
            elem_type = _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
            return ctypes.POINTER(elem_type)
        elif isinstance(sdfg_type, Pointer):
            # Check if pointee is a Structure
            # Note: has_pointee_type() is guaranteed to exist on Pointer instances from C++ bindings
            if sdfg_type.has_pointee_type():
                pointee = sdfg_type.pointee_type
                if isinstance(pointee, Structure):
                    # Create ctypes structure and return pointer to it
                    struct_class = self._create_ctypes_structure(pointee.name)
                    return ctypes.POINTER(struct_class)
                elif isinstance(pointee, Scalar):
                    elem_type = _CTYPES_MAP.get(pointee.primitive_type, ctypes.c_void_p)
                    return ctypes.POINTER(elem_type)
            return ctypes.c_void_p
        return ctypes.c_void_p

    @staticmethod
    def _scalar_types(integers_only=False):
        """Return the scalar types to test with ``isinstance``.

        NumPy scalar types are only included when numpy was importable, so
        the check also works when the module-level ``np`` is ``None``.
        """
        if integers_only:
            types = (int,)
            if np is not None:
                types += (np.integer,)
        else:
            types = (int, float)
            if np is not None:
                types += (np.integer, np.floating)
        return types

    @staticmethod
    def _is_shape_symbol(name):
        """Return True if ``name`` has the implicit shape-symbol form ``_s<digits>``."""
        return name.startswith("_s") and name[2:].isdigit()

    def _collect_symbol_values(self, args):
        """Build the ``{symbol_name: value}`` mapping used to evaluate shapes.

        Contains ``_s<i>`` for every entry ``i`` of ``shape_sources`` whose
        user argument was supplied, plus every numeric user argument under
        its SDFG argument name.
        """
        symbol_values = {}
        # enumerate() gives each entry its own index; looking the index up
        # with list.index() would map duplicate (arg, dim) pairs to the first
        # occurrence and leave the later symbols undefined.
        for s_idx, (u_idx, dim_idx) in enumerate(self.shape_sources):
            if u_idx < len(args):
                symbol_values[f"_s{s_idx}"] = args[u_idx].shape[dim_idx]

        numeric_types = self._scalar_types()
        user_names = (
            name
            for name in self.arg_names
            if name not in self.output_args and not self._is_shape_symbol(name)
        )
        # zip() stops at the shorter side, i.e. only supplied user args are used.
        for name, val in zip(user_names, args):
            if isinstance(val, numeric_types):
                symbol_values[name] = val
        return symbol_values

    def _convert_user_arg(self, arg, target_type, sdfg_type, structure_refs):
        """Convert one user argument to the value passed to the native call.

        ctypes structures created here are appended to ``structure_refs`` so
        the caller can keep them alive for the duration of the call.
        """
        if np is not None and isinstance(arg, np.ndarray):
            if hasattr(target_type, "contents"):
                return arg.ctypes.data_as(target_type)
            return arg

        if (
            sdfg_type
            and isinstance(sdfg_type, Pointer)
            and sdfg_type.has_pointee_type()
            and isinstance(sdfg_type.pointee_type, Structure)
        ):
            struct_name = sdfg_type.pointee_type.name
            # Returns the cached class, or builds it if it is not cached yet
            # (a plain cache lookup could yield None).
            struct_class = self._create_ctypes_structure(struct_name)
            members = self.structure_member_info[struct_name]
            # Only members present on the Python object are initialised.
            struct_values = {
                member_name: getattr(arg, member_name)
                for member_name in members
                if hasattr(arg, member_name)
            }
            c_struct = struct_class(**struct_values)
            structure_refs.append(c_struct)
            return ctypes.pointer(c_struct)

        # Explicit cast to ensure int stays int
        return target_type(arg)

    def __call__(self, *args):
        """Invoke the compiled function with the given user arguments.

        User arguments are consumed positionally for every SDFG argument that
        is neither an output argument nor an ``_s<N>`` shape symbol.

        Returns:
            ``None`` if there are no output arguments, the single output
            value if there is one, otherwise a tuple of output values ordered
            by the integer after the last underscore of the output name.
            Scalar outputs are returned as Python values; array outputs as
            numpy arrays (or lists when numpy is unavailable).

        Raises:
            ValueError: If fewer user arguments than required are provided.
            RuntimeError: If an output shape expression cannot be evaluated.
        """
        symbol_values = self._collect_symbol_values(args)

        converted_args = []
        structure_refs = []  # keeps ctypes structures alive during the call
        return_buffers = {}  # name -> (buffer, shape); shape is None for scalars
        next_user_arg_idx = 0

        for i, arg_name in enumerate(self.arg_names):
            target_type = self.arg_types[i]

            if arg_name in self.output_args:
                base_type = target_type._type_

                # If not in output_shapes, assume scalar return
                # (pointer to a single value).
                if arg_name not in self.output_shapes:
                    buf = base_type()
                    return_buffers[arg_name] = (buf, None)
                    converted_args.append(ctypes.byref(buf))
                    continue

                # Array output: evaluate the shape once and allocate a flat
                # buffer; the same shape is reused when building the result.
                shape = []
                for dim_str in self.output_shapes[arg_name]:
                    try:
                        # NOTE(security): eval() runs arbitrary expressions;
                        # this is only acceptable because the shape
                        # expressions are expected to be compiler-generated.
                        shape.append(int(eval(str(dim_str), {}, symbol_values)))
                    except Exception as e:
                        raise RuntimeError(
                            f"Could not evaluate shape {dim_str} for {arg_name}: {e}"
                        ) from e

                size = 1
                for extent in shape:
                    size *= extent

                buf = (base_type * size)()
                return_buffers[arg_name] = (buf, tuple(shape))
                converted_args.append(
                    ctypes.cast(ctypes.addressof(buf), target_type)
                )
                continue

            if self._is_shape_symbol(arg_name):
                # Symbols without a known source are passed as 0.
                key = f"_s{int(arg_name[2:])}"
                converted_args.append(ctypes.c_int64(symbol_values.get(key, 0)))
                continue

            # User Argument
            if next_user_arg_idx >= len(args):
                raise ValueError("Not enough arguments provided")

            arg = args[next_user_arg_idx]
            next_user_arg_idx += 1
            converted_args.append(
                self._convert_user_arg(
                    arg, target_type, self.arg_sdfg_types[i], structure_refs
                )
            )

        self.func(*converted_args)

        # Process returns
        results = []
        sorted_ret_names = sorted(
            return_buffers.keys(), key=lambda x: int(x.split("_")[-1])
        )

        for name in sorted_ret_names:
            buf, shape = return_buffers[name]
            if shape is None:
                # Scalar: buf is a ctypes scalar instance
                results.append(buf.value)
            elif np is None:
                # fallback list
                results.append(list(buf))
            else:
                # buf is a flat ctypes array; view it as numpy and reshape.
                arr = np.ctypeslib.as_array(buf)  # 1D
                if shape:
                    arr = arr.reshape(shape)
                results.append(arr)

        if len(results) == 1:
            return results[0]
        elif len(results) > 1:
            return tuple(results)

        return None

    def get_return_shape(self, *args):
        """Evaluate the SDFG's ``"return_shape"`` metadata for ``args``.

        Returns:
            A tuple of ints, or ``None`` when the metadata entry is empty or
            any of its comma-separated expressions cannot be evaluated.
        """
        shape_str = self.sdfg.metadata("return_shape")
        if not shape_str:
            return None

        shape_exprs = shape_str.split(",")

        # Reconstruct shape values
        shape_values = {}
        for i, (arg_idx, dim_idx) in enumerate(self.shape_sources):
            arg = args[arg_idx]
            if np is not None and isinstance(arg, np.ndarray):
                shape_values[f"_s{i}"] = arg.shape[dim_idx]

        # Add scalar arguments to shape_values
        # We assume the first len(args) arguments in sdfg.arguments correspond to the user arguments
        if hasattr(self.sdfg, "arguments"):
            int_types = self._scalar_types(integers_only=True)
            for arg_name, arg_val in zip(self.sdfg.arguments, args):
                if isinstance(arg_val, int_types):
                    shape_values[arg_name] = int(arg_val)

        evaluated_shape = []
        for expr in shape_exprs:
            # Simple evaluation using eval with shape_values
            # Warning: eval is unsafe, but here expressions come from our compiler
            try:
                evaluated_shape.append(int(eval(expr, {}, shape_values)))
            except Exception:
                return None

        return tuple(evaluated_shape)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc