• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27856839861

20 Jun 2026 02:00AM UTC coverage: 61.736%. First build
27856839861

Pull #789

github

web-flow
Merge 22229aa8f into b7103c21a
Pull Request #789:

53 of 174 new or added lines in 7 files covered. (30.46%)

36995 of 59925 relevant lines covered (61.74%)

1016.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

59.45
/python/docc/compiler/compiled_sdfg.py
1
import ctypes
4✔
2
import re
4✔
3
import time
4✔
4
import warnings
4✔
5
from docc.sdfg import Scalar, Array, Pointer, Structure, PrimitiveType
4✔
6

7
import numpy as np
4✔
8
import ml_dtypes
4✔
9

10

11
class DoccPerformanceWarning(UserWarning):
    """Warning category for avoidable slow execution paths.

    Issued through :mod:`warnings` when docc must take a slower route than
    necessary, e.g. extra host<->device copies at a call boundary.
    """
13

14

15
def warn_cpu_fallback(backend):
    """Warn that GPU inputs are being staged through host memory.

    Issued when device residency could not be enabled for a program, so GPU
    inputs are copied device->host before execution. Whether the warning is
    shown only once depends on the active warnings filter (Python's default
    filter reports each call site once).

    Args:
        backend: GPU backend name (e.g. "cuda", "rocm"), interpolated into the
            message for context.
    """
    message = (
        "GPU inputs are being copied to host memory before execution because "
        f"device residency could not be enabled for this {backend} program "
        "(some argument is touched by host code). This adds host<->device "
        "transfer overhead. Run on fully-offloadable code to keep data on the "
        "device."
    )
    warnings.warn(message, category=DoccPerformanceWarning, stacklevel=3)
30

31

32
def warn_host_to_device(backend):
    """Warn that host inputs are being uploaded for a device-resident program.

    Mirror image of :func:`warn_cpu_fallback`: issued when a device-resident
    program receives host inputs (numpy arrays or CPU tensors) that must be
    copied to the device before the call.

    Args:
        backend: GPU backend name (e.g. "cuda", "rocm"), interpolated into the
            message for context.
    """
    message = (
        "Host inputs are being copied to device memory before execution because "
        f"this device-resident {backend} program expects device pointers. This "
        "adds host->device transfer overhead. Pass GPU arrays (cupy arrays or "
        "CUDA tensors) to keep the data on the device."
    )
    warnings.warn(message, category=DoccPerformanceWarning, stacklevel=3)
50

51

52
def idiv(a, b):
    """Return ``a // b`` after coercing both operands to ``int``.

    Python's ``//`` floors, so the result differs from C-style truncating
    division when the operands have opposite signs and the division is inexact
    (e.g. ``idiv(-7, 2) == -4``). For non-negative sizes the two agree.
    """
    numerator, denominator = int(a), int(b)
    return numerator // denominator
4✔
55

56

57
def _is_device_array(arg):
    """Return True if ``arg`` already lives in device memory.

    cupy arrays and CUDA torch tensors expose ``__cuda_array_interface__``; a
    torch tensor also reports its location via ``is_cuda``. Host arrays
    (numpy, CPU torch tensors) return False.

    torch raises ``RuntimeError`` (not ``AttributeError``) when
    ``__cuda_array_interface__`` is read on a CUDA tensor that requires grad,
    so ``getattr`` with a default does not absorb it. In that case the
    ``is_cuda`` check below decides.
    """
    try:
        cai = getattr(arg, "__cuda_array_interface__", None)
    except RuntimeError:
        cai = None
    if cai is not None:
        return True
    is_cuda = getattr(arg, "is_cuda", None)
    if is_cuda is not None:
        return bool(is_cuda)
    return False
×
70

71

72
def _device_array_ptr(arg):
    """Extract the raw device pointer from a GPU array (cupy or torch.cuda).

    Prefers ``__cuda_array_interface__["data"][0]`` (cupy arrays, CUDA torch
    tensors) and falls back to ``data_ptr()`` (torch tensors). The fallback
    also covers CUDA torch tensors that require grad: torch raises
    ``RuntimeError`` when their ``__cuda_array_interface__`` is read.

    Only the base address is returned; any strides in the interface are not
    inspected. NOTE(review): callers presumably pass contiguous arrays —
    confirm.

    Args:
        arg: A device array.

    Returns:
        The device address as an integer.

    Raises:
        TypeError: If ``arg`` exposes neither interface.
    """
    try:
        cai = getattr(arg, "__cuda_array_interface__", None)
    except RuntimeError:
        cai = None
    if cai is not None:
        return cai["data"][0]
    data_ptr = getattr(arg, "data_ptr", None)
    if callable(data_ptr):
        return data_ptr()
    raise TypeError(
        f"Device-resident execution requires a GPU array exposing "
        f"__cuda_array_interface__ or data_ptr(), got {type(arg).__name__}"
    )
88

89

90
# Globals visible to eval() when evaluating compiled shape/stride expressions.
_EVAL_GLOBALS = dict(idiv=idiv)

# Regexes used by CompiledSDFG._convert_to_python_syntax, compiled once.
# An innermost call ``name(args)`` whose argument text has no parentheses:
_FUNC_CALL_PATTERN = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\(([^()]+)\)")
# Temporary stand-in for a known-function call, restored to ``name(args)``:
_PLACEHOLDER_PATTERN = re.compile(
    r"@@@FUNC@@@([a-zA-Z_][a-zA-Z0-9_]*)@@@(.+?)@@@END@@@"
)
# Names (compared lower-cased) that stay calls; other ``name(...)`` is indexing.
_KNOWN_FUNCTIONS = frozenset(
    ("abs", "float", "idiv", "int", "len", "max", "min", "sum")
)
101

102
# Argument kinds: the first element of every _arg_info tuple, so the call
# paths can dispatch with cheap integer comparisons.
(
    _ARG_TYPE_OUTPUT_ARRAY,
    _ARG_TYPE_OUTPUT_SCALAR,
    _ARG_TYPE_SHAPE,
    _ARG_TYPE_USER_ARRAY,
    _ARG_TYPE_USER_STRUCT,
    _ARG_TYPE_USER_SCALAR,
) = range(6)

# ctypes callables bound to module names once, avoiding repeated attribute
# lookups on the hot call path.
_c_int64 = ctypes.c_int64
_ctypes_cast = ctypes.cast
_ctypes_addressof = ctypes.addressof
_ctypes_byref = ctypes.byref
_ctypes_pointer = ctypes.pointer
_ctypes_c_void_p = ctypes.c_void_p
117

118
# Map primitive types to numpy dtypes for fast buffer allocation.
# Half/BFloat map to float16 / ml_dtypes.bfloat16 (numpy has no bfloat16).
_PRIMITIVE_TO_NP_DTYPE = {
    PrimitiveType.Float: np.float32,
    PrimitiveType.Double: np.float64,
    PrimitiveType.Int8: np.int8,
    PrimitiveType.Int16: np.int16,
    PrimitiveType.Int32: np.int32,
    PrimitiveType.Int64: np.int64,
    PrimitiveType.UInt8: np.uint8,
    PrimitiveType.UInt16: np.uint16,
    PrimitiveType.UInt32: np.uint32,
    PrimitiveType.UInt64: np.uint64,
    PrimitiveType.Bool: np.bool_,
    PrimitiveType.Half: np.float16,
    PrimitiveType.BFloat: ml_dtypes.bfloat16,
}

# Dtype map used when wrapping returned pointers as numpy arrays. It was a
# verbatim duplicate of the table above; it is now derived from it, so the two
# cannot drift apart. It remains a separate dict object under its own name.
_NUMPY_DTYPE_MAP = dict(_PRIMITIVE_TO_NP_DTYPE)
151

152
# Primitive type -> ctypes scalar used to marshal arguments and return values.
_CTYPES_MAP = {
    # Boolean and signed integers
    PrimitiveType.Bool: ctypes.c_bool,
    PrimitiveType.Int8: ctypes.c_int8,
    PrimitiveType.Int16: ctypes.c_int16,
    PrimitiveType.Int32: ctypes.c_int32,
    PrimitiveType.Int64: ctypes.c_int64,
    # Unsigned integers
    PrimitiveType.UInt8: ctypes.c_uint8,
    PrimitiveType.UInt16: ctypes.c_uint16,
    PrimitiveType.UInt32: ctypes.c_uint32,
    PrimitiveType.UInt64: ctypes.c_uint64,
    # Native floating point
    PrimitiveType.Float: ctypes.c_float,
    PrimitiveType.Double: ctypes.c_double,
    # 16-bit floats have no ctypes counterpart: their 2-byte payload is
    # carried as a raw unsigned 16-bit integer.
    PrimitiveType.Half: ctypes.c_uint16,
    PrimitiveType.BFloat: ctypes.c_uint16,
}
168

169

170
class CompiledSDFG:
4✔
171
    def __init__(
        self,
        lib_path,
        sdfg,
        shape_sources=None,
        structure_member_info=None,
        output_args=None,
        output_shapes=None,
        output_strides=None,
        device_resident=False,
        device_backend=None,
    ):
        """Load a compiled SDFG library and prepare its entry point for calls.

        Args:
            lib_path: Path of the shared library, loaded with ``ctypes.CDLL``.
            sdfg: The SDFG. ``sdfg.name`` is the exported symbol to call, and
                ``sdfg.arguments`` / ``sdfg.type()`` / ``sdfg.return_type``
                define the C signature.
            shape_sources: Sequence of ``(user_arg_index, dim_index)`` pairs.
                Entry ``i`` supplies shape symbol ``_s{i}`` at call time.
            structure_member_info: ``{struct_name: {member_name: (index,
                sdfg_type)}}`` describing structure layouts.
            output_args: Names of output arguments. When empty, the
                comma-separated ``"output_args"`` entry of ``sdfg.metadata``
                is used if present.
            output_shapes: ``{arg_name: [dim_expr, ...]}`` for array outputs.
            output_strides: ``{arg_name: [stride_expr, ...]}`` for array
                outputs.
            device_resident: When true, calls take the device path and
                exchange device pointers instead of host memory.
            device_backend: GPU backend name used in warnings. Defaults to
                ``"cuda"`` for device-resident artifacts, otherwise ``None``.
        """
        self.lib_path = lib_path
        self.sdfg = sdfg
        self.shape_sources = shape_sources or []
        self.structure_member_info = structure_member_info or {}
        self.lib = ctypes.CDLL(lib_path)
        self.func = getattr(self.lib, sdfg.name)

        # Output arguments: the explicit list wins; otherwise fall back to the
        # comma-separated names stored in the SDFG metadata.
        self.output_args = output_args or []
        if not self.output_args and hasattr(sdfg, "metadata"):
            encoded_outputs = sdfg.metadata("output_args")
            if encoded_outputs:
                self.output_args = encoded_outputs.split(",")

        self.output_shapes = output_shapes or {}
        self.output_strides = output_strides or {}

        # Device residency is decided by the DeviceResidentArgPromotion pass and
        # handed in through the constructor (not metadata). When set, the
        # compiled function takes device pointers and produces device-resident
        # outputs, with no host<->device copies at the boundary.
        self.device_resident = bool(device_resident)
        if device_backend:
            backend = device_backend
        elif self.device_resident:
            backend = "cuda"
        else:
            backend = None
        self.device_backend = backend

        # Ensures the host->device copy warning fires at most once per artifact.
        self._warned_host_to_device = False

        # struct_name -> generated ctypes.Structure subclass. This must exist
        # before _get_ctypes_type runs below.
        self._ctypes_structures = {}

        # Resolve each argument's SDFG type and its ctypes counterpart in order.
        self.arg_names = sdfg.arguments
        self.arg_types = []
        self.arg_sdfg_types = []  # original SDFG types, index-aligned
        for sdfg_type in map(sdfg.type, sdfg.arguments):
            self.arg_sdfg_types.append(sdfg_type)
            self.arg_types.append(self._get_ctypes_type(sdfg_type))

        self.func.argtypes = self.arg_types
        self.func.restype = self._get_ctypes_type(sdfg.return_type)

        # Classify arguments once so each call dispatches cheaply.
        self._precompute_arg_metadata()
4✔
232

233
    def _precompute_arg_metadata(self):
        """Classify every SDFG argument once so calls can dispatch cheaply.

        Populates:
          * ``_shape_sources_list``: ``(s_idx, u_idx, dim_idx, "_s{s_idx}")``,
            one per entry of ``shape_sources``.
          * ``_arg_info``: one tuple per SDFG argument, in argument order. The
            first element is an ``_ARG_TYPE_*`` tag; layouts:

            - OUTPUT_ARRAY: (tag, name, base_type, target_type, compiled_dims,
              compiled_strides, primitive_type, np_dtype)
            - OUTPUT_SCALAR: (tag, name, base_type, primitive_type)
            - SHAPE: (tag, s_idx, "_s{s_idx}")
            - USER_STRUCT: (tag, user_idx, name, struct_class, sorted_members)
            - USER_ARRAY / USER_SCALAR: (tag, user_idx, name, target_type)
          * ``_num_user_args``: number of arguments the caller supplies.
          * ``_output_order`` / ``_output_pos_map``: output ``_arg_info``
            indices ordered by the integer suffix of their names, plus the
            inverse map from ``_arg_info`` index to result position.

        Raises:
            ValueError: If an output argument name does not end in
                ``_<integer>``, which is required to order the results.
        """
        output_args_set = set(self.output_args)

        self._shape_sources_list = [
            (i, u_idx, dim_idx, f"_s{i}")
            for i, (u_idx, dim_idx) in enumerate(self.shape_sources)
        ]

        self._arg_info = []
        user_arg_counter = 0
        # (output index parsed from the name, _arg_info index); sorted once
        # here so calls never have to sort.
        output_order = []

        for i, arg_name in enumerate(self.arg_names):
            if arg_name in output_args_set:
                info_idx = len(self._arg_info)
                self._arg_info.append(self._build_output_arg_info(i, arg_name))
                output_order.append((self._output_index(arg_name), info_idx))
            elif arg_name.startswith("_s") and arg_name[2:].isdigit():
                # Shape symbol argument, supplied from shape_symbol_values.
                s_idx = int(arg_name[2:])
                self._arg_info.append((_ARG_TYPE_SHAPE, s_idx, f"_s{s_idx}"))
            else:
                self._arg_info.append(
                    self._build_user_arg_info(i, arg_name, user_arg_counter)
                )
                user_arg_counter += 1

        self._num_user_args = user_arg_counter

        output_order.sort(key=lambda entry: entry[0])
        self._output_order = tuple(idx for _, idx in output_order)
        # Map from _arg_info index to result position (for O(1) lookup)
        self._output_pos_map = {idx: pos for pos, idx in enumerate(self._output_order)}

    @staticmethod
    def _output_index(arg_name):
        """Return the integer after the last ``_`` of an output argument name.

        Raises:
            ValueError: If that suffix is not an integer.
        """
        suffix = arg_name.rsplit("_", 1)[-1]
        try:
            return int(suffix)
        except ValueError:
            raise ValueError(
                f"Output argument '{arg_name}' must end in '_<index>' so "
                f"results can be ordered"
            ) from None

    def _build_output_arg_info(self, i, arg_name):
        """Build the ``_arg_info`` tuple for output argument ``i``.

        Outputs with an entry in ``output_shapes`` become OUTPUT_ARRAY. Their
        shape and stride expressions are compiled here and evaluated per call.
        All other outputs become OUTPUT_SCALAR.
        """
        target_type = self.arg_types[i]
        base_type = target_type._type_
        sdfg_type = self.arg_sdfg_types[i]

        # Element type of a pointer-to-scalar output, if known.
        primitive_type = None
        if isinstance(sdfg_type, Pointer) and sdfg_type.has_pointee_type():
            pointee = sdfg_type.pointee_type
            if isinstance(pointee, Scalar):
                primitive_type = pointee.primitive_type

        if arg_name not in self.output_shapes:
            return (_ARG_TYPE_OUTPUT_SCALAR, arg_name, base_type, primitive_type)

        # Always compile shape expressions - they may depend on runtime values.
        compiled_dims = [
            compile(self._convert_to_python_syntax(str(d)), "<shape>", "eval")
            for d in self.output_shapes[arg_name]
        ]
        compiled_strides = None
        if arg_name in self.output_strides:
            compiled_strides = [
                compile(self._convert_to_python_syntax(str(s)), "<stride>", "eval")
                for s in self.output_strides[arg_name]
            ]

        # None (unknown element type) and unmapped types both get float64.
        np_dtype = _PRIMITIVE_TO_NP_DTYPE.get(primitive_type, np.float64)

        return (
            _ARG_TYPE_OUTPUT_ARRAY,
            arg_name,
            base_type,
            target_type,
            compiled_dims,
            compiled_strides,
            primitive_type,
            np_dtype,
        )

    def _build_user_arg_info(self, i, arg_name, user_idx):
        """Build the ``_arg_info`` tuple for caller-supplied argument ``i``.

        ``user_idx`` is the argument's position in the caller's ``*args``.
        Pointers to structures become USER_STRUCT. ctypes types exposing
        ``contents`` (typed POINTERs) become USER_ARRAY. Everything else
        becomes USER_SCALAR.
        """
        sdfg_type = self.arg_sdfg_types[i]
        target_type = self.arg_types[i]
        is_struct_ptr = (
            isinstance(sdfg_type, Pointer)
            and sdfg_type.has_pointee_type()
            and isinstance(sdfg_type.pointee_type, Structure)
        )

        if is_struct_ptr:
            struct_name = sdfg_type.pointee_type.name
            struct_class = self._create_ctypes_structure(struct_name)
            members = self.structure_member_info[struct_name]
            # Members in declaration (index) order.
            sorted_members = tuple(sorted(members.items(), key=lambda x: x[1][0]))
            return (
                _ARG_TYPE_USER_STRUCT,
                user_idx,
                arg_name,
                struct_class,
                sorted_members,
            )
        if hasattr(target_type, "contents"):
            return (_ARG_TYPE_USER_ARRAY, user_idx, arg_name, target_type)
        return (_ARG_TYPE_USER_SCALAR, user_idx, arg_name, target_type)
4✔
363

364
    def _convert_to_python_syntax(self, expr_str):
        """Rewrite a symbolic shape expression into evaluable Python source.

        Calls to names in ``_KNOWN_FUNCTIONS`` (matched case-insensitively)
        stay calls. Any other ``name(args)`` is treated as indexing and becomes
        ``name[args]``. Innermost calls are rewritten first, so nesting is
        handled bottom-up, e.g. ``A(idiv(N, 2))`` -> ``A[idiv(N, 2)]``.

        Known calls are hidden behind ``@@@FUNC@@@...@@@END@@@`` placeholders
        while rewriting, so the call regex does not match them again.
        """
        result = expr_str

        while True:
            match = _FUNC_CALL_PATTERN.search(result)
            if not match:
                break
            name, index = match.group(1), match.group(2)
            if name.lower() in _KNOWN_FUNCTIONS:
                # Use unique delimiters that won't appear in expressions
                replacement = f"@@@FUNC@@@{name}@@@{index}@@@END@@@"
            else:
                replacement = f"{name}[{index}]"
            result = result[: match.start()] + replacement + result[match.end() :]

        # Placeholders nest when a known call's arguments contain another known
        # call, e.g. ``max(idiv(N, 2), 1)``. A single ``sub`` pass pairs the
        # outer opener with the inner ``@@@END@@@`` and leaves markers behind,
        # producing invalid Python. Each pass still turns one opener into
        # ``name(`` and one terminator into ``)``, and openers/terminators
        # nest like parentheses. Repeating until no placeholder remains
        # therefore yields a correctly balanced expression.
        substitutions = 1
        while substitutions:
            result, substitutions = _PLACEHOLDER_PATTERN.subn(r"\1(\2)", result)

        return result
4✔
387

388
    def _create_ctypes_structure(self, struct_name):
        """Return the ctypes.Structure subclass for ``struct_name`` (memoized).

        Members come from ``structure_member_info[struct_name]``, which maps
        ``member_name -> (index, sdfg_type)``. They are laid out in ascending
        index order using their ctypes equivalents.

        The class is cached *before* its members are resolved, and
        ``_fields_`` is assigned afterwards (the ctypes "incomplete type"
        pattern). A member pointing back to the structure being built,
        directly or through another structure, then resolves to
        ``POINTER(<this class>)`` instead of recursing without bound.

        Raises:
            ValueError: If ``struct_name`` is not in ``structure_member_info``.
        """
        if struct_name in self._ctypes_structures:
            return self._ctypes_structures[struct_name]

        if struct_name not in self.structure_member_info:
            raise ValueError(f"Structure '{struct_name}' not found in member info")

        members = self.structure_member_info[struct_name]
        # Sort by index to get correct order
        sorted_members = sorted(members.items(), key=lambda x: x[1][0])

        class CStructure(ctypes.Structure):
            pass

        # Register first so self-references resolve to this (incomplete) class.
        self._ctypes_structures[struct_name] = CStructure
        try:
            CStructure._fields_ = [
                (member_name, self._get_ctypes_type(member_type))
                for member_name, (_index, member_type) in sorted_members
            ]
        except BaseException:
            # Don't leave a field-less class cached if member resolution fails.
            self._ctypes_structures.pop(struct_name, None)
            raise

        return CStructure
4✔
413

414
    def _get_ctypes_type(self, sdfg_type):
        """Map an SDFG type to the ctypes type used across the C boundary.

        * Scalar -> its primitive ctypes type (``c_void_p`` if unmapped).
        * Array -> ``POINTER(element ctypes type)``; arrays travel as pointers.
        * Pointer to Structure -> ``POINTER(generated ctypes.Structure)``.
        * Pointer to Scalar -> ``POINTER(element ctypes type)``.
        * Anything else (untyped/opaque pointers, other types) -> ``c_void_p``.
        """
        if isinstance(sdfg_type, Scalar):
            return _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)

        if isinstance(sdfg_type, Array):
            element = _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
            return ctypes.POINTER(element)

        # has_pointee_type() is guaranteed on Pointer instances from the C++
        # bindings.
        if isinstance(sdfg_type, Pointer) and sdfg_type.has_pointee_type():
            pointee = sdfg_type.pointee_type
            if isinstance(pointee, Structure):
                return ctypes.POINTER(self._create_ctypes_structure(pointee.name))
            if isinstance(pointee, Scalar):
                element = _CTYPES_MAP.get(pointee.primitive_type, ctypes.c_void_p)
                return ctypes.POINTER(element)

        return ctypes.c_void_p
×
435

436
    def _convert_return_value(self, func_result, shape_symbol_values):
        """Convert the compiled function's raw return value into a Python object.

        Scalar (and any non-pointer) returns are passed through unchanged. A
        pointer-to-scalar return is wrapped as a numpy array when a shape can
        be determined:

        1. from the ``return_shape`` metadata (e.g. ``"[N,M]"``), each
           dimension evaluated against ``shape_symbol_values``. If any
           dimension fails to evaluate, the raw pointer is returned.
        2. otherwise, from the ``_s{i}`` shape symbols of the first shaped
           input argument (a heuristic for identity-like operations).

        When no shape can be determined, the raw ctypes value is returned.
        """
        return_type = self.sdfg.return_type

        if isinstance(return_type, Pointer):
            if not return_type.has_pointee_type():
                return func_result
            pointee = return_type.pointee_type
            if not isinstance(pointee, Scalar):
                return func_result

            return_shape_str = self.sdfg.metadata("return_shape")
            if return_shape_str:
                # Strip brackets (metadata may be "[10,10]" format)
                shape = []
                for dim_str in return_shape_str.strip("[]").split(","):
                    try:
                        eval_str = self._convert_to_python_syntax(str(dim_str))
                        val = eval(eval_str, _EVAL_GLOBALS, shape_symbol_values)
                        shape.append(int(val))
                    except Exception:
                        # Can't evaluate shape, return raw pointer
                        return func_result
            else:
                shape = self._fallback_return_shape(shape_symbol_values)
                if not shape:
                    # Can't determine shape, return raw pointer
                    return func_result

            return self._pointer_to_ndarray(func_result, pointee.primitive_type, shape)

        return func_result

    def _fallback_return_shape(self, shape_symbol_values):
        """Guess a return shape from the first shaped input argument.

        Collects, in ``shape_sources`` order, the ``_s{i}`` values of every
        shape source belonging to the same user argument as
        ``shape_sources[0]``. Symbols from other inputs are excluded: mixing
        them in could produce a larger shape than the returned buffer holds.
        Returns an empty list when nothing is available.
        """
        if not self.shape_sources or not shape_symbol_values:
            return []
        first_arg = self.shape_sources[0][0]
        return [
            shape_symbol_values[f"_s{i}"]
            for i, (u_idx, _dim_idx) in enumerate(self.shape_sources)
            if u_idx == first_arg and f"_s{i}" in shape_symbol_values
        ]

    @staticmethod
    def _pointer_to_ndarray(ptr, primitive_type, shape):
        """Wrap the ``prod(shape)`` elements at ``ptr`` as a numpy array of ``shape``.

        Half/BFloat data is read with ``np.frombuffer`` and copied, because
        ``np.ctypeslib.as_array`` doesn't support these types (PEP 3118
        buffer format limitation). All other types are returned as a view of
        the memory behind ``ptr``, without copying.
        """
        dtype = _NUMPY_DTYPE_MAP.get(primitive_type, np.float64)
        total_size = 1
        for dim in shape:
            total_size *= dim

        if primitive_type in (PrimitiveType.Half, PrimitiveType.BFloat):
            byte_size = total_size * 2  # Half and BFloat are 2 bytes
            address = ctypes.cast(ptr, ctypes.c_void_p).value
            raw = (ctypes.c_char * byte_size).from_address(address)
            arr = np.frombuffer(raw, dtype=dtype).copy()
        else:
            ct_type = _CTYPES_MAP.get(primitive_type, ctypes.c_double)
            arr_type = ct_type * total_size
            arr = np.ctypeslib.as_array(
                ctypes.cast(ptr, ctypes.POINTER(arr_type)).contents
            )
        return arr.reshape(shape)
×
546

547
    def _ensure_device_array(self, arg, keepalive, writebacks):
        """Return ``arg`` as a device array, staging host data onto the GPU.

        Device arrays (per ``_is_device_array``) pass through untouched. Host
        inputs are uploaded with ``cupy.asarray``; CPU torch tensors are first
        turned into numpy via ``detach().numpy()``. For each upload:

        * the device copy is appended to ``keepalive`` so it outlives the call;
        * ``(arg, device_copy)`` is appended to ``writebacks`` so the caller can
          mirror device-side writes back into the original host object.

        The first upload made by this instance emits a one-time
        ``warn_host_to_device`` performance warning.
        """
        if _is_device_array(arg):
            return arg

        import cupy

        looks_like_torch = hasattr(arg, "detach") and hasattr(arg, "numpy")
        host_data = arg.detach().numpy() if looks_like_torch else arg
        staged = cupy.asarray(host_data)
        keepalive.append(staged)
        writebacks.append((arg, staged))

        if not self._warned_host_to_device:
            self._warned_host_to_device = True
            warn_host_to_device(self.device_backend or "cuda")

        return staged
×
578

579
    def _call_device(self, *args):
        """Execute with device-resident arguments (cupy / torch.cuda).

        Array inputs are passed as raw device pointers. Host inputs are staged
        onto the device first (see ``_ensure_device_array``) and mirrored back
        afterwards. Output arrays are allocated on the device with cupy, so no
        host<->device copies happen at the call boundary for device inputs.

        Returns:
            ``None`` if the SDFG has no output arguments, the single output if
            there is one, otherwise a tuple ordered by output index. Array
            outputs are cupy arrays (zero-copy interoperable with torch via
            DLPack / ``__cuda_array_interface__``). Scalar outputs are the
            ``.value`` of their ctypes buffer.

        Raises:
            NotImplementedError: If the SDFG takes structure arguments.
        """
        import cupy

        _eval = eval
        _GLOBALS = _EVAL_GLOBALS

        # 1. Build shape_symbol_values from shape sources
        shape_symbol_values = {}
        for _s_idx, u_idx, dim_idx, key in self._shape_sources_list:
            if u_idx < len(args):
                shape_symbol_values[key] = args[u_idx].shape[dim_idx]

        # 2. Process arguments in SDFG argument order
        converted_args = []
        keepalive = []  # keep device buffers / ctypes scalars alive
        writebacks = []  # (host_arg, device_arg) for host inputs copied to device
        return_buffers = []  # (buf, size, dims, compiled_strides, primitive_type)

        for info in self._arg_info:
            arg_type = info[0]

            if arg_type == _ARG_TYPE_OUTPUT_ARRAY:
                target_type = info[3]
                compiled_dims = info[4]
                compiled_strides = info[5]
                np_dtype = info[7]

                size = 1
                dims = []
                for code in compiled_dims:
                    d = int(_eval(code, _GLOBALS, shape_symbol_values))
                    dims.append(d)
                    size *= d

                # NOTE(review): compiled_strides are carried along but never
                # applied here, so outputs come back C-contiguous. Confirm this
                # matches the host path for outputs with non-default strides.
                buf = cupy.empty(size, dtype=np_dtype)
                keepalive.append(buf)
                return_buffers.append((buf, size, dims, compiled_strides, info[6]))
                converted_args.append(_ctypes_cast(int(buf.data.ptr), target_type))

            elif arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                buf = info[2]()  # base ctypes scalar, written by the callee
                keepalive.append(buf)
                return_buffers.append((buf, 1, None, None, info[3]))
                converted_args.append(_ctypes_byref(buf))

            elif arg_type == _ARG_TYPE_SHAPE:
                converted_args.append(_c_int64(shape_symbol_values.get(info[2], 0)))

            elif arg_type == _ARG_TYPE_USER_ARRAY:
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                arg = self._ensure_device_array(arg, keepalive, writebacks)
                converted_args.append(_ctypes_cast(_device_array_ptr(arg), info[3]))

            elif arg_type == _ARG_TYPE_USER_STRUCT:
                raise NotImplementedError(
                    "Structure arguments are not supported for device-resident "
                    "execution."
                )

            else:  # _ARG_TYPE_USER_SCALAR
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                converted_args.append(info[3](arg))

        # 3. Call the function
        self.func(*converted_args)

        # 3b. Mirror device results into host inputs that were staged on the
        # device, so in-place writes are visible to the caller.
        self._copy_back_to_host(writebacks)

        # 4. Process returns using the pre-computed output positions. Output
        # infos and return_buffers were produced in the same order.
        if not return_buffers:
            return None

        results = [None] * len(return_buffers)
        buffers = iter(return_buffers)
        for i, info in enumerate(self._arg_info):
            arg_type = info[0]
            if arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                buf = next(buffers)[0]
                results[self._output_pos_map[i]] = buf.value
            elif arg_type == _ARG_TYPE_OUTPUT_ARRAY:
                buf, _size, dims, _strides, _ptype = next(buffers)
                if dims and len(dims) > 1:
                    buf = buf.reshape(dims)
                results[self._output_pos_map[i]] = buf

        return results[0] if len(results) == 1 else tuple(results)

    @staticmethod
    def _copy_back_to_host(writebacks):
        """Copy each staged device array back into the host object it came from.

        ``writebacks`` holds ``(host_arg, device_arg)`` pairs.
        * numpy arrays with ``writeable=False`` are skipped. They cannot
          receive writes, and ``np.copyto`` would raise after the kernel has
          already run successfully.
        * torch tensors are updated under ``torch.no_grad()``, so leaf tensors
          that require grad accept the in-place copy.
        * Other host objects are left unchanged.
        """
        if not writebacks:
            return

        import cupy

        for host_arg, device_arg in writebacks:
            if isinstance(host_arg, np.ndarray):
                if not host_arg.flags.writeable:
                    continue
                host_view = cupy.asnumpy(device_arg)
                np.copyto(host_arg, host_view.reshape(host_arg.shape))
            elif hasattr(host_arg, "copy_"):  # torch CPU tensor
                import torch

                host_view = cupy.asnumpy(device_arg)
                with torch.no_grad():
                    host_arg.copy_(
                        torch.from_numpy(host_view.reshape(tuple(host_arg.shape)))
                    )
×
697

698
    def __call__(self, *args):
        """Invoke the compiled SDFG with the given user arguments.

        Device-resident artifacts are forwarded to ``_call_device``. Otherwise
        every entry of ``self._arg_info`` is turned into a ctypes argument:
        output arrays/scalars are allocated here, shape symbols are resolved
        from the ``.shape`` of user arguments, and user arrays, structures and
        scalars are passed by pointer or by value.

        Args:
            *args: User arguments, positionally matching the SDFG signature.

        Returns:
            When the SDFG has no output buffers, the native return value
            converted by ``_convert_return_value`` (or ``None``). Otherwise the
            single output if there is exactly one, else a tuple of outputs
            placed according to ``self._output_pos_map``.
        """
        # Device-resident artifacts consume/produce device pointers directly.
        if self.device_resident:
            return self._call_device(*args)

        # Hot path: bind frequently used globals to locals once per call.
        _eval = eval
        _GLOBALS = _EVAL_GLOBALS
        _np_empty = np.empty

        # 1. Resolve shape symbols from the .shape of the user arguments.
        shape_symbol_values = {}
        for _, u_idx, dim_idx, key in self._shape_sources_list:
            if u_idx < len(args):
                shape_symbol_values[key] = args[u_idx].shape[dim_idx]

        # 2. Convert every SDFG argument into its ctypes representation.
        converted_args = []
        # Keeps output buffers and ctypes structs alive until the call returns.
        structure_refs = []
        # One (buffer, size, dims, compiled_strides, primitive_type) per output.
        return_buffers = []

        for info in self._arg_info:
            arg_type = info[0]

            if arg_type == _ARG_TYPE_OUTPUT_ARRAY:
                # info = (type, name, base_type, target_type, compiled_dims,
                #         compiled_strides, primitive_type, np_dtype)
                target_type = info[3]
                compiled_dims = info[4]
                compiled_strides = info[5]
                np_dtype = info[7]

                # Evaluate the output extent from the pre-compiled dim expressions.
                size = 1
                dims = []
                for code in compiled_dims:
                    d = int(_eval(code, _GLOBALS, shape_symbol_values))
                    dims.append(d)
                    size *= d

                # A flat numpy buffer is cheaper to allocate than a ctypes array.
                buf_arr = _np_empty(size, dtype=np_dtype)
                structure_refs.append(buf_arr)
                return_buffers.append((buf_arr, size, dims, compiled_strides, info[6]))
                converted_args.append(_ctypes_cast(buf_arr.ctypes.data, target_type))

            elif arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                # info = (type, name, base_type, primitive_type)
                buf = info[2]()
                structure_refs.append(buf)
                return_buffers.append((buf, 1, None, None, info[3]))
                converted_args.append(_ctypes_byref(buf))

            elif arg_type == _ARG_TYPE_SHAPE:
                # info = (type, s_idx, key_str); unresolved symbols default to 0.
                converted_args.append(_c_int64(shape_symbol_values.get(info[2], 0)))

            elif arg_type == _ARG_TYPE_USER_ARRAY:
                # info = (type, user_idx, name, target_type)
                arg = args[info[1]]
                # Bind by name so later symbol evaluations can reference it
                # (indirect access).
                shape_symbol_values[info[2]] = arg
                # Raw data pointer; faster than ndarray.ctypes.data_as().
                converted_args.append(_ctypes_cast(arg.ctypes.data, info[3]))

            elif arg_type == _ARG_TYPE_USER_STRUCT:
                # info = (type, user_idx, name, struct_class, sorted_members)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                # Copy only the members the Python object actually provides.
                struct_values = {
                    m[0]: getattr(arg, m[0]) for m in info[4] if hasattr(arg, m[0])
                }
                c_struct = info[3](**struct_values)
                structure_refs.append(c_struct)
                converted_args.append(_ctypes_pointer(c_struct))

            else:  # _ARG_TYPE_USER_SCALAR
                # info = (type, user_idx, name, target_type)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                converted_args.append(info[3](arg))

        # 3. Call the native function.
        func_result = self.func(*converted_args)

        # 4. Without output buffers, convert the native return value (if any).
        if not return_buffers:
            if func_result is not None:
                return self._convert_return_value(func_result, shape_symbol_values)
            return None

        # 5. Collect outputs in signature order; each lands at its result position.
        results = [None] * len(return_buffers)
        buf_idx = 0
        for i, info in enumerate(self._arg_info):
            arg_type = info[0]
            if arg_type not in (_ARG_TYPE_OUTPUT_ARRAY, _ARG_TYPE_OUTPUT_SCALAR):
                continue

            result_pos = self._output_pos_map[i]
            buf, size, dims, compiled_strides, _ = return_buffers[buf_idx]
            buf_idx += 1

            if arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                # Scalar outputs are ctypes scalars.
                results[result_pos] = buf.value
                continue

            # Array outputs are flat numpy buffers of `size` elements; 1-D
            # outputs are already correctly shaped.
            arr = buf
            if dims and len(dims) > 1:
                view = None
                if compiled_strides:
                    try:
                        # Same namespace as the dim expressions above.
                        elem_strides = [
                            int(_eval(s, _GLOBALS, shape_symbol_values))
                            for s in compiled_strides
                        ]
                    except Exception:
                        # Unevaluable stride expression: use the C-order fallback.
                        elem_strides = None
                    # Only build a strided view that stays inside the buffer;
                    # as_strided performs no bounds checking of its own.
                    if (
                        elem_strides is not None
                        and len(elem_strides) == len(dims)
                        and all(st >= 0 for st in elem_strides)
                        and (
                            0 in dims
                            or sum(
                                (d - 1) * st for d, st in zip(dims, elem_strides)
                            )
                            < size
                        )
                    ):
                        itemsize = arr.itemsize
                        view = np.lib.stride_tricks.as_strided(
                            arr,
                            shape=dims,
                            strides=tuple(st * itemsize for st in elem_strides),
                        )
                arr = view if view is not None else arr.reshape(dims)
            results[result_pos] = arr

        if len(results) == 1:
            return results[0]
        return tuple(results)
4✔
841

842
    def get_return_shape(self, *args):
        """Evaluate the SDFG's ``return_shape`` metadata for concrete arguments.

        The metadata is a comma-separated list of shape expressions. They are
        evaluated against ``_s{i}`` symbols read from the ``.shape`` of the
        arguments listed in ``self.shape_sources`` (``(arg_idx, dim_idx)``
        pairs), and against integer user arguments bound by name from
        ``self.sdfg.arguments``.

        Args:
            *args: User arguments, positionally matching the SDFG signature.

        Returns:
            A tuple of ints, or ``None`` when no return shape is recorded or an
            expression cannot be evaluated with the available symbols.
        """
        shape_str = self.sdfg.metadata("return_shape")
        if not shape_str:
            return None

        # Ignore empty entries (e.g. a trailing comma) instead of failing on them.
        shape_exprs = [e.strip() for e in shape_str.split(",") if e.strip()]
        if not shape_exprs:
            return None

        # Reconstruct shape symbols. Accept any argument exposing `.shape`
        # (numpy, torch, cupy, ...), consistent with __call__. Missing or
        # shapeless arguments leave the symbol undefined, so evaluation below
        # returns None instead of raising.
        shape_values = {}
        for i, (arg_idx, dim_idx) in enumerate(self.shape_sources):
            if arg_idx >= len(args):
                continue
            try:
                shape_values[f"_s{i}"] = args[arg_idx].shape[dim_idx]
            except (AttributeError, IndexError, TypeError):
                continue

        # Integer user arguments can appear directly in shape expressions.
        # Assumes the leading sdfg.arguments correspond to the user arguments.
        if hasattr(self.sdfg, "arguments"):
            for arg_name, arg_val in zip(self.sdfg.arguments, args):
                if isinstance(arg_val, (int, np.integer)):
                    shape_values[arg_name] = int(arg_val)

        evaluated_shape = []
        for expr in shape_exprs:
            # eval is only acceptable because these expressions come from the
            # docc compiler's metadata, not from user input.
            try:
                evaluated_shape.append(int(eval(expr, _EVAL_GLOBALS, shape_values)))
            except Exception:
                return None

        return tuple(evaluated_shape)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc