• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27902492495

21 Jun 2026 11:13AM UTC coverage: 61.856% (+0.008%) from 61.848%
27902492495

Pull #781

github

web-flow
Merge 0b9f1add0 into 9ce8041e7
Pull Request #781: Extend Segformer benchmarks setup

99 of 120 new or added lines in 8 files covered. (82.5%)

126 existing lines in 6 files now uncovered.

37062 of 59917 relevant lines covered (61.86%)

1017.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.6
/python/docc/compiler/compiled_sdfg.py
1
import ctypes
4✔
2
import re
4✔
3
import time
4✔
4
from docc.sdfg import Scalar, Array, Pointer, Structure, PrimitiveType
4✔
5

6
import numpy as np
4✔
7
import ml_dtypes
4✔
8

9

10
def idiv(a, b):
    """Return ``int(a) // int(b)``.

    Both operands are converted with ``int()`` first (so floats are truncated
    toward zero). The division itself is Python floor division, which rounds
    toward negative infinity when the quotient is negative.
    """
    numerator = int(a)
    denominator = int(b)
    return numerator // denominator
13

14

15
# Globals mapping passed to eval() for shape/stride expressions. It only adds
# the ``idiv`` helper; eval() inserts ``__builtins__`` itself when missing, so
# builtins such as int/min/max stay available.
_EVAL_GLOBALS = dict(idiv=idiv)
17

18
# Regexes for rewriting symbolic expressions, compiled once at import time.

# ``name(args)`` where ``args`` contains no parentheses, i.e. an innermost
# call. Group 1 is the callee name, group 2 the raw argument text.
_FUNC_CALL_PATTERN = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\(([^()]+)\)")

# Marker of the form ``@@@FUNC@@@name@@@args@@@END@@@``.
# Group 1 is the name, group 2 the (shortest) argument text.
_PLACEHOLDER_PATTERN = re.compile(r"@@@FUNC@@@([a-zA-Z_][a-zA-Z0-9_]*)@@@(.+?)@@@END@@@")

# Callee names treated as genuine function calls in shape expressions.
_KNOWN_FUNCTIONS = frozenset(
    ("int", "float", "abs", "min", "max", "sum", "len", "idiv")
)
26

27
# Integer tags for the kinds of arguments CompiledSDFG distinguishes; plain
# ints keep the per-call dispatch comparisons cheap.
(
    _ARG_TYPE_OUTPUT_ARRAY,
    _ARG_TYPE_OUTPUT_SCALAR,
    _ARG_TYPE_SHAPE,
    _ARG_TYPE_USER_ARRAY,
    _ARG_TYPE_USER_STRUCT,
    _ARG_TYPE_USER_SCALAR,
) = range(6)

# ctypes callables bound to module-level names once, so hot call paths skip
# the attribute lookup on the ctypes module.
_c_int64, _ctypes_cast, _ctypes_addressof = (
    ctypes.c_int64,
    ctypes.cast,
    ctypes.addressof,
)
_ctypes_byref, _ctypes_pointer, _ctypes_c_void_p = (
    ctypes.byref,
    ctypes.pointer,
    ctypes.c_void_p,
)
42

43
# Map SDFG primitive types to the numpy dtypes used for output-buffer
# allocation and for viewing memory returned by the native function.
# Half and BFloat use numpy's float16 and ml_dtypes' bfloat16 respectively.
_PRIMITIVE_TO_NP_DTYPE = {
    PrimitiveType.Float: np.float32,
    PrimitiveType.Double: np.float64,
    PrimitiveType.Int8: np.int8,
    PrimitiveType.Int16: np.int16,
    PrimitiveType.Int32: np.int32,
    PrimitiveType.Int64: np.int64,
    PrimitiveType.UInt8: np.uint8,
    PrimitiveType.UInt16: np.uint16,
    PrimitiveType.UInt32: np.uint32,
    PrimitiveType.UInt64: np.uint64,
    PrimitiveType.Bool: np.bool_,
    PrimitiveType.Half: np.float16,
    PrimitiveType.BFloat: ml_dtypes.bfloat16,
}

# Second name for the same mapping, kept for existing references. It used to
# be a hand-maintained duplicate of the table above; deriving it from that
# table means the two can no longer drift apart.
_NUMPY_DTYPE_MAP = dict(_PRIMITIVE_TO_NP_DTYPE)
76

77
# ctypes type used to pass each SDFG primitive type across the C boundary.
# ctypes has no 16-bit float types, so Half and BFloat (both 2 bytes wide)
# are carried as raw c_uint16 storage.
_CTYPES_MAP = dict(
    [
        (PrimitiveType.Bool, ctypes.c_bool),
        (PrimitiveType.Int8, ctypes.c_int8),
        (PrimitiveType.Int16, ctypes.c_int16),
        (PrimitiveType.Int32, ctypes.c_int32),
        (PrimitiveType.Int64, ctypes.c_int64),
        (PrimitiveType.UInt8, ctypes.c_uint8),
        (PrimitiveType.UInt16, ctypes.c_uint16),
        (PrimitiveType.UInt32, ctypes.c_uint32),
        (PrimitiveType.UInt64, ctypes.c_uint64),
        (PrimitiveType.Float, ctypes.c_float),
        (PrimitiveType.Double, ctypes.c_double),
        (PrimitiveType.Half, ctypes.c_uint16),
        (PrimitiveType.BFloat, ctypes.c_uint16),
    ]
)
93

94

95
class CompiledSDFG:
    """ctypes-based callable wrapper around a compiled SDFG shared library.

    The constructor loads ``lib_path``, looks up the symbol named after
    ``sdfg.name`` and derives ctypes argument/return types from the SDFG's
    type information. Calling the instance does four things:

    * marshals the user arguments (numpy arrays, structure-like objects,
      scalars);
    * synthesizes the shape-symbol arguments (``_s<i>``);
    * allocates the output buffers and invokes the native function;
    * returns the outputs.
    """

    def __init__(
        self,
        lib_path,
        sdfg,
        shape_sources=None,
        structure_member_info=None,
        output_args=None,
        output_shapes=None,
        output_strides=None,
    ):
        """Load the library and prepare the native function's signature.

        Args:
            lib_path: Path of the shared library, passed to ``ctypes.CDLL``.
            sdfg: SDFG handle. Its ``name``, ``arguments``,
                ``type(arg_name)``, ``return_type`` and (if present)
                ``metadata(key)`` are used.
            shape_sources: Sequence of ``(user_arg_index, dim_index)`` pairs.
                Entry ``i`` defines shape symbol ``_s<i>`` as
                ``args[user_arg_index].shape[dim_index]`` at call time.
            structure_member_info: ``{struct_name: {member_name: (index, type)}}``
                describing the field layout of structure arguments.
            output_args: Names of arguments that are allocated here and
                returned, rather than supplied by the caller. When empty, the
                comma-separated ``output_args`` metadata entry is used.
            output_shapes: ``{arg_name: [dim_expr, ...]}`` for array outputs.
            output_strides: ``{arg_name: [stride_expr, ...]}`` in elements.
        """
        self.lib_path = lib_path
        self.sdfg = sdfg
        self.shape_sources = shape_sources or []
        self.structure_member_info = structure_member_info or {}
        self.lib = ctypes.CDLL(lib_path)
        self.func = getattr(self.lib, sdfg.name)

        # Output arguments: explicit list first, otherwise SDFG metadata.
        self.output_args = output_args or []
        if not self.output_args and hasattr(sdfg, "metadata"):
            out_args_str = sdfg.metadata("output_args")
            if out_args_str:
                # Tolerate "a, b" spacing and trailing commas: a name with
                # stray whitespace would never match an SDFG argument name.
                self.output_args = [
                    name.strip() for name in out_args_str.split(",") if name.strip()
                ]

        self.output_shapes = output_shapes or {}
        self.output_strides = output_strides or {}

        # Cache of generated ctypes Structure classes, keyed by struct name.
        # It must exist before _get_ctypes_type runs below.
        self._ctypes_structures = {}

        # Native argument types, plus the original SDFG types for later use.
        self.arg_names = sdfg.arguments
        self.arg_types = []
        self.arg_sdfg_types = []
        for arg_name in sdfg.arguments:
            arg_type = sdfg.type(arg_name)
            self.arg_sdfg_types.append(arg_type)
            self.arg_types.append(self._get_ctypes_type(arg_type))

        self.func.argtypes = self.arg_types
        self.func.restype = self._get_ctypes_type(sdfg.return_type)

        # Classify arguments once so __call__ can dispatch cheaply.
        self._precompute_arg_metadata()

    def _precompute_arg_metadata(self):
        """Classify every SDFG argument once for fast ``__call__`` dispatch.

        Populates:
          * ``_shape_sources_list``: ``(s_idx, user_idx, dim_idx, "_s<s_idx>")``.
          * ``_arg_info``: one tuple per SDFG argument. The first element is an
            ``_ARG_TYPE_*`` tag; the per-kind layouts are noted inline below.
          * ``_num_user_args``: number of arguments the caller supplies.
          * ``_output_order`` / ``_output_pos_map``: result positions of the
            outputs, sorted by the integer suffix of their names.
        """
        output_args_set = set(self.output_args)

        self._shape_sources_list = [
            (i, u_idx, dim_idx, f"_s{i}")
            for i, (u_idx, dim_idx) in enumerate(self.shape_sources)
        ]

        self._arg_info = []
        user_arg_counter = 0
        output_order = []  # [(numeric name suffix, _arg_info index), ...]

        for i, arg_name in enumerate(self.arg_names):
            if arg_name in output_args_set:
                target_type = self.arg_types[i]
                base_type = target_type._type_
                sdfg_type = self.arg_sdfg_types[i]

                # Element primitive type, used for allocation and decoding.
                primitive_type = None
                if isinstance(sdfg_type, Pointer) and sdfg_type.has_pointee_type():
                    pointee = sdfg_type.pointee_type
                    if isinstance(pointee, Scalar):
                        primitive_type = pointee.primitive_type

                info_idx = len(self._arg_info)
                if arg_name in self.output_shapes:
                    # Always compiled: dims may depend on runtime values.
                    compiled_dims = [
                        compile(self._convert_to_python_syntax(str(d)), "<shape>", "eval")
                        for d in self.output_shapes[arg_name]
                    ]
                    compiled_strides = None
                    if arg_name in self.output_strides:
                        compiled_strides = [
                            compile(self._convert_to_python_syntax(str(s)), "<stride>", "eval")
                            for s in self.output_strides[arg_name]
                        ]
                    # A None key simply falls back to float64. Unlike the old
                    # truthiness check, this cannot misroute an enum value
                    # that happens to be falsy.
                    np_dtype = _PRIMITIVE_TO_NP_DTYPE.get(primitive_type, np.float64)
                    # (type, name, base_type, target_type, compiled_dims,
                    #  compiled_strides, primitive_type, np_dtype)
                    self._arg_info.append(
                        (
                            _ARG_TYPE_OUTPUT_ARRAY,
                            arg_name,
                            base_type,
                            target_type,
                            compiled_dims,
                            compiled_strides,
                            primitive_type,
                            np_dtype,
                        )
                    )
                else:
                    # (type, name, base_type, primitive_type)
                    self._arg_info.append(
                        (_ARG_TYPE_OUTPUT_SCALAR, arg_name, base_type, primitive_type)
                    )
                # Output names are expected to end in "_<int>". Any other
                # suffix raises ValueError here.
                output_order.append((int(arg_name.split("_")[-1]), info_idx))

            elif arg_name.startswith("_s") and arg_name[2:].isdigit():
                # (type, s_idx, key_str)
                s_idx = int(arg_name[2:])
                self._arg_info.append((_ARG_TYPE_SHAPE, s_idx, f"_s{s_idx}"))

            else:
                sdfg_type = self.arg_sdfg_types[i]
                target_type = self.arg_types[i]
                if (
                    isinstance(sdfg_type, Pointer)
                    and sdfg_type.has_pointee_type()
                    and isinstance(sdfg_type.pointee_type, Structure)
                ):
                    struct_name = sdfg_type.pointee_type.name
                    struct_class = self._create_ctypes_structure(struct_name)
                    members = self.structure_member_info[struct_name]
                    sorted_members = tuple(
                        sorted(members.items(), key=lambda x: x[1][0])
                    )
                    # (type, user_idx, name, struct_class, sorted_members)
                    self._arg_info.append(
                        (
                            _ARG_TYPE_USER_STRUCT,
                            user_arg_counter,
                            arg_name,
                            struct_class,
                            sorted_members,
                        )
                    )
                elif hasattr(target_type, "contents"):
                    # Pointer-typed argument: (type, user_idx, name, target_type)
                    self._arg_info.append(
                        (_ARG_TYPE_USER_ARRAY, user_arg_counter, arg_name, target_type)
                    )
                else:
                    # Scalar argument: (type, user_idx, name, target_type)
                    self._arg_info.append(
                        (_ARG_TYPE_USER_SCALAR, user_arg_counter, arg_name, target_type)
                    )
                user_arg_counter += 1

        self._num_user_args = user_arg_counter

        output_order.sort(key=lambda x: x[0])
        self._output_order = tuple(idx for _, idx in output_order)
        # _arg_info index -> position in the returned tuple.
        self._output_pos_map = {idx: pos for pos, idx in enumerate(self._output_order)}

    def _convert_to_python_syntax(self, expr_str):
        """Rewrite a symbolic expression string into evaluable Python.

        Calls to names in ``_KNOWN_FUNCTIONS`` (compared lower-cased) stay
        calls. Any other ``name(args)`` is treated as element access and
        becomes ``name[args]``.

        Calls are rewritten innermost-first. A recognized call's parentheses
        are swapped for sentinel characters, so an enclosing call is matched
        on the next pass. The sentinels are turned back into parentheses at
        the end, which makes nested calls such as ``max(idiv(_s0, 2), 1)``
        round-trip correctly.
        """
        # Private-use code points, assumed never to appear in
        # compiler-emitted expressions.
        open_mark, close_mark = "\ue000", "\ue001"
        result = expr_str
        while True:
            match = _FUNC_CALL_PATTERN.search(result)
            if match is None:
                break
            name, inner = match.group(1), match.group(2)
            if name.lower() in _KNOWN_FUNCTIONS:
                replacement = f"{name}{open_mark}{inner}{close_mark}"
            else:
                replacement = f"{name}[{inner}]"
            # Each pass removes one "(", so the loop terminates.
            result = result[: match.start()] + replacement + result[match.end():]
        return result.replace(open_mark, "(").replace(close_mark, ")")

    def _create_ctypes_structure(self, struct_name):
        """Return the cached (or newly built) ctypes Structure for ``struct_name``.

        Fields are laid out in ascending member index taken from
        ``structure_member_info``.

        Raises:
            ValueError: if ``struct_name`` has no entry in
                ``structure_member_info``.
        """
        cached = self._ctypes_structures.get(struct_name)
        if cached is not None:
            return cached

        if struct_name not in self.structure_member_info:
            raise ValueError(f"Structure '{struct_name}' not found in member info")

        # {member_name: (index, type)}, ordered by index.
        members = self.structure_member_info[struct_name]
        sorted_members = sorted(members.items(), key=lambda x: x[1][0])

        class CStructure(ctypes.Structure):
            pass

        # Register the (still incomplete) class before resolving member types.
        # A member that points back to this struct, directly or through
        # another struct, then resolves to POINTER(CStructure) instead of
        # recursing forever. ctypes allows assigning _fields_ afterwards.
        self._ctypes_structures[struct_name] = CStructure
        try:
            CStructure._fields_ = [
                (member_name, self._get_ctypes_type(member_type))
                for member_name, (_index, member_type) in sorted_members
            ]
        except Exception:
            # Do not leave a field-less class behind in the cache.
            del self._ctypes_structures[struct_name]
            raise
        return CStructure

    def _get_ctypes_type(self, sdfg_type):
        """Map an SDFG type to the ctypes type used in the native signature.

        The mapping is:
          * Scalars map through ``_CTYPES_MAP``.
          * Arrays and pointers-to-scalar become ``POINTER(element)``.
          * Pointers-to-structure become ``POINTER(<generated Structure>)``.
          * Anything else, including pointers without a pointee type, is
            ``ctypes.c_void_p``.
        """
        if isinstance(sdfg_type, Scalar):
            return _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
        if isinstance(sdfg_type, Array):
            # Arrays are passed as pointers to their element type.
            elem_type = _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
            return ctypes.POINTER(elem_type)
        if isinstance(sdfg_type, Pointer) and sdfg_type.has_pointee_type():
            pointee = sdfg_type.pointee_type
            if isinstance(pointee, Structure):
                return ctypes.POINTER(self._create_ctypes_structure(pointee.name))
            if isinstance(pointee, Scalar):
                elem_type = _CTYPES_MAP.get(pointee.primitive_type, ctypes.c_void_p)
                return ctypes.POINTER(elem_type)
        return ctypes.c_void_p

    @staticmethod
    def _decode_raw_16bit(raw_bits, primitive_type):
        """Reinterpret raw 16-bit storage as a Half/BFloat value.

        Half and BFloat travel through ctypes as ``c_uint16``. This helper
        reinterprets the bits as a Python float.
        """
        bits = np.array([raw_bits], dtype=np.uint16)
        return float(bits.view(_PRIMITIVE_TO_NP_DTYPE[primitive_type])[0])

    @staticmethod
    def _pointer_to_ndarray(ptr, shape, primitive_type):
        """Wrap ``prod(shape)`` elements at ``ptr`` as a numpy array of ``shape``.

        Half/BFloat data is copied out of native memory. Every other type is
        wrapped via ``np.ctypeslib.as_array`` without copying.
        """
        dtype = _NUMPY_DTYPE_MAP.get(primitive_type, np.float64)
        total_size = 1
        for dim in shape:
            total_size *= dim

        if primitive_type in (PrimitiveType.Half, PrimitiveType.BFloat):
            # np.ctypeslib.as_array cannot describe these types (PEP 3118
            # buffer-format limitation), so read the raw bytes instead.
            byte_size = total_size * 2  # Half and BFloat are 2 bytes
            address = ctypes.cast(ptr, ctypes.c_void_p).value
            arr = np.frombuffer(
                (ctypes.c_char * byte_size).from_address(address), dtype=dtype
            ).copy()
        else:
            ct_type = _CTYPES_MAP.get(primitive_type, ctypes.c_double)
            arr_type = ct_type * total_size
            arr = np.ctypeslib.as_array(
                ctypes.cast(ptr, ctypes.POINTER(arr_type)).contents
            )
        return arr.reshape(shape)

    def _convert_return_value(self, func_result, shape_symbol_values):
        """Convert the native function's direct return value.

        The conversion depends on the return type:
          * Pointer-to-scalar: becomes a numpy array when a shape can be
            determined, first from the ``return_shape`` metadata, else from the
            ``_s<i>`` symbols. Otherwise the raw ctypes pointer is returned.
          * Half/BFloat scalar (16-bit raw storage): decoded to a Python float.
          * Anything else: returned unchanged.
        """
        return_type = self.sdfg.return_type

        if isinstance(return_type, Pointer):
            if not return_type.has_pointee_type():
                return func_result
            pointee = return_type.pointee_type
            if not isinstance(pointee, Scalar):
                return func_result

            return_shape_str = self.sdfg.metadata("return_shape")
            if return_shape_str:
                # Metadata may come as "[10,10]".
                shape = []
                for dim_str in return_shape_str.strip("[]").split(","):
                    try:
                        expr = self._convert_to_python_syntax(str(dim_str))
                        # eval of compiler-produced metadata, not user input.
                        val = eval(expr, _EVAL_GLOBALS, shape_symbol_values)
                        shape.append(int(val))
                    except Exception:
                        # Shape cannot be evaluated: hand back the raw pointer.
                        return func_result
                return self._pointer_to_ndarray(func_result, shape, pointee.primitive_type)

            # No return_shape metadata. NOTE(review): this fallback uses every
            # available _s<i> symbol (across all inputs) as a dimension, which
            # only matches the output for identity-like kernels. Confirm this
            # is intended.
            if self.shape_sources and shape_symbol_values:
                shape = [
                    shape_symbol_values[f"_s{i}"]
                    for i in range(len(self.shape_sources))
                    if f"_s{i}" in shape_symbol_values
                ]
                if shape:
                    return self._pointer_to_ndarray(
                        func_result, shape, pointee.primitive_type
                    )
            return func_result

        if isinstance(return_type, Scalar):
            primitive_type = return_type.primitive_type
            if primitive_type in (PrimitiveType.Half, PrimitiveType.BFloat):
                # restype is c_uint16 for these types, so func_result holds raw bits.
                return self._decode_raw_16bit(func_result, primitive_type)
            return func_result

        return func_result

    def __call__(self, *args):
        """Invoke the compiled function with the user arguments ``args``.

        ``args`` are indexed by the user-argument positions recorded in
        ``_arg_info``. Shape-symbol and output arguments are synthesized here.

        Returns:
            * With output arguments: the single output, or a tuple of outputs
              ordered by the numeric suffix of their names.
            * Without output arguments: ``None`` when the native call returns
              ``None``, otherwise the value from ``_convert_return_value``.
        """
        # Local aliases for the hot path.
        _eval = eval
        _GLOBALS = _EVAL_GLOBALS
        _np_empty = np.empty

        # 1. Shape symbols _s<i> come from dimensions of user array arguments.
        shape_symbol_values = {}
        num_args = len(args)
        for _s_idx, u_idx, dim_idx, key in self._shape_sources_list:
            if u_idx < num_args:
                shape_symbol_values[key] = args[u_idx].shape[dim_idx]

        # 2. Marshal arguments.
        converted_args = []
        keep_alive = []  # buffers/structs the native call reads or writes
        return_buffers = []  # [(buffer, size, dims, compiled_strides, primitive_type), ...]

        for info in self._arg_info:
            arg_type = info[0]

            if arg_type == _ARG_TYPE_OUTPUT_ARRAY:
                # info = (type, name, base_type, target_type, compiled_dims,
                #         compiled_strides, primitive_type, np_dtype)
                dims = []
                size = 1
                for code in info[4]:
                    d = int(_eval(code, _GLOBALS, shape_symbol_values))
                    dims.append(d)
                    size *= d
                buf_arr = _np_empty(size, dtype=info[7])
                keep_alive.append(buf_arr)
                return_buffers.append((buf_arr, size, dims, info[5], info[6]))
                converted_args.append(_ctypes_cast(buf_arr.ctypes.data, info[3]))

            elif arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                # info = (type, name, base_type, primitive_type)
                buf = info[2]()
                keep_alive.append(buf)
                return_buffers.append((buf, 1, None, None, info[3]))
                converted_args.append(_ctypes_byref(buf))

            elif arg_type == _ARG_TYPE_SHAPE:
                # info = (type, s_idx, key_str)
                converted_args.append(_c_int64(shape_symbol_values.get(info[2], 0)))

            elif arg_type == _ARG_TYPE_USER_ARRAY:
                # info = (type, user_idx, name, target_type)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg  # for indirect access in exprs
                converted_args.append(_ctypes_cast(arg.ctypes.data, info[3]))

            elif arg_type == _ARG_TYPE_USER_STRUCT:
                # info = (type, user_idx, name, struct_class, sorted_members)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                struct_values = {
                    m[0]: getattr(arg, m[0]) for m in info[4] if hasattr(arg, m[0])
                }
                c_struct = info[3](**struct_values)
                keep_alive.append(c_struct)
                converted_args.append(_ctypes_pointer(c_struct))

            else:  # _ARG_TYPE_USER_SCALAR
                # info = (type, user_idx, name, target_type)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                converted_args.append(info[3](arg))

        # 3. Call the native function.
        func_result = self.func(*converted_args)

        # 4. Collect results.
        if not return_buffers:
            if func_result is not None:
                return self._convert_return_value(func_result, shape_symbol_values)
            return None

        results = [None] * len(return_buffers)
        buffers = iter(return_buffers)  # same order as the output entries in _arg_info
        for i, info in enumerate(self._arg_info):
            arg_type = info[0]
            if arg_type != _ARG_TYPE_OUTPUT_ARRAY and arg_type != _ARG_TYPE_OUTPUT_SCALAR:
                continue

            result_pos = self._output_pos_map[i]
            buf, _size, dims, compiled_strides, primitive_type = next(buffers)

            if arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                value = buf.value
                if primitive_type in (PrimitiveType.Half, PrimitiveType.BFloat):
                    # Stored as c_uint16: decode the bits the same way arrays are decoded.
                    value = self._decode_raw_16bit(value, primitive_type)
                results[result_pos] = value
                continue

            arr = buf  # flat numpy buffer allocated above
            if dims and len(dims) > 1:
                if compiled_strides:
                    try:
                        itemsize = arr.itemsize
                        # Use the same globals as the dims so helpers such as
                        # idiv are available to stride expressions too.
                        byte_strides = tuple(
                            int(_eval(s, _GLOBALS, shape_symbol_values)) * itemsize
                            for s in compiled_strides
                        )
                        arr = np.lib.stride_tricks.as_strided(
                            arr, shape=dims, strides=byte_strides
                        )
                    except Exception:
                        # Strides cannot be evaluated: assume C-contiguous layout.
                        arr = arr.reshape(dims)
                else:
                    arr = arr.reshape(dims)
            results[result_pos] = arr

        if len(results) == 1:
            return results[0]
        return tuple(results)

    def get_return_shape(self, *args):
        """Evaluate the ``return_shape`` metadata for the given user arguments.

        Shape symbols ``_s<i>`` are taken from numpy-array arguments as in
        ``__call__``. User arguments are bound by name using the same
        positional mapping as ``__call__``; integer values are normalized to
        ``int``.

        Returns:
            A tuple of ints, or ``None`` when there is no ``return_shape``
            metadata or any dimension cannot be evaluated.
        """
        shape_str = self.sdfg.metadata("return_shape")
        if not shape_str:
            return None

        # Same normalization as _convert_return_value: metadata may be "[10,10]".
        shape_exprs = shape_str.strip("[]").split(",")

        shape_values = {}
        num_args = len(args)
        for i, (arg_idx, dim_idx) in enumerate(self.shape_sources):
            if arg_idx < num_args and isinstance(args[arg_idx], np.ndarray):
                shape_values[f"_s{i}"] = args[arg_idx].shape[dim_idx]

        # Bind user arguments by their SDFG names, using the recorded
        # user-argument positions. Zipping against sdfg.arguments would
        # misalign whenever output or shape arguments come first.
        user_kinds = (_ARG_TYPE_USER_ARRAY, _ARG_TYPE_USER_STRUCT, _ARG_TYPE_USER_SCALAR)
        for info in self._arg_info:
            if info[0] not in user_kinds:
                continue
            user_idx, arg_name = info[1], info[2]
            if user_idx >= num_args:
                continue
            value = args[user_idx]
            if isinstance(value, (int, np.integer)):
                value = int(value)
            shape_values[arg_name] = value

        evaluated_shape = []
        for expr in shape_exprs:
            # eval is only acceptable because the expressions come from the
            # compiler's own metadata, not from user input.
            try:
                val = eval(
                    self._convert_to_python_syntax(expr), _EVAL_GLOBALS, shape_values
                )
                evaluated_shape.append(int(val))
            except Exception:
                return None

        return tuple(evaluated_shape)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc