• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27899971587

21 Jun 2026 09:24AM UTC coverage: 61.753% (-0.09%) from 61.843%
27899971587

Pull #797

github

web-flow
Merge 7f602d37c into 9a8945878
Pull Request #797: Adds Infrastructure to Forward Device Pointer to CompiledSDFG

52 of 176 new or added lines in 4 files covered. (29.55%)

16 existing lines in 1 file now uncovered.

37092 of 60065 relevant lines covered (61.75%)

1015.1 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.46
/python/docc/compiler/compiled_sdfg.py
1
import ctypes
4✔
2
import re
4✔
3
import time
4✔
4
import warnings
4✔
5
from docc.sdfg import Scalar, Array, Pointer, Structure, PrimitiveType
4✔
6

7
import numpy as np
4✔
8
import ml_dtypes
4✔
9

10

11
class DoccPerformanceWarning(UserWarning):
    """Warning category for execution paths that are slower than necessary.

    It derives from ``UserWarning``, so it can be silenced, shown or turned
    into an error through the standard ``warnings`` filters.
    """
13

14

15
def warn_device_residency_failed(backend):
    """Warn that device-residency promotion did not apply to a function.

    The message is issued through ``warnings.warn`` with the category
    ``DoccPerformanceWarning``. ``stacklevel=3`` attributes the warning to the
    frame two levels above this helper. The helper keeps no state of its own,
    so limiting the warning to a single emission is left to the caller and to
    the active warning filters.

    Args:
        backend: The GPU backend name (e.g. "cuda", "rocm") for context.
    """
    message = (
        f"Device residency could not be enabled for this function on backend {backend}. "
        "The function is only partially offloaded; function arguments stay in host memory before "
        "and after execution."
    )
    warnings.warn(message, category=DoccPerformanceWarning, stacklevel=3)
28

29

30
def warn_host_to_device_copies(backend):
    """Warn that host arguments are being copied to the device on every call.

    The message is issued through ``warnings.warn`` with the category
    ``DoccPerformanceWarning``. ``stacklevel=3`` attributes the warning to the
    frame two levels above this helper. The helper keeps no state of its own,
    so limiting the warning to a single emission is left to the caller and to
    the active warning filters.

    Args:
        backend: The GPU backend name (e.g. "cuda", "rocm") for context.
    """
    message = (
        f"Device residency is enabled for this function on backend {backend}. "
        "Function arguments are passed from host memory slowing down execution. "
        "Provide device-resident arrays to avoid unnecessary copies and improve performance."
    )
    warnings.warn(message, category=DoccPerformanceWarning, stacklevel=3)
43

44

45
def idiv(a, b):
    """Floor-divide ``a`` by ``b`` after converting both operands with ``int()``.

    For positive operands this is plain integer division; negative results
    are rounded towards negative infinity, as ``//`` does.
    """
    numerator = int(a)
    denominator = int(b)
    return numerator // denominator
48

49

50
def _is_device_array(arg):
    """Return True if ``arg`` already lives in device memory.

    A torch tensor reports its location via ``is_cuda``; cupy arrays (and other
    GPU array types) expose ``__cuda_array_interface__``. Host arrays (numpy,
    CPU torch tensors) return False.

    Args:
        arg: Any object; values without either attribute are treated as host
            data.

    Returns:
        bool: True when ``is_cuda`` is truthy or a non-None
        ``__cuda_array_interface__`` is present.
    """
    # Ask ``is_cuda`` first. torch raises RuntimeError (not AttributeError) when
    # ``__cuda_array_interface__`` is read from a CUDA tensor that requires
    # grad, which would make this predicate raise instead of answering. The
    # result is the same as checking the interface first for every input that
    # does not raise.
    if getattr(arg, "is_cuda", None):
        return True
    return getattr(arg, "__cuda_array_interface__", None) is not None
63

64

65
def _device_array_ptr(arg):
    """Extract the raw device pointer from a GPU array (cupy or torch.cuda).

    ``__cuda_array_interface__`` is consulted first; ``data_ptr()`` is the
    fallback when the interface is absent or cannot be produced.

    Args:
        arg: A GPU array exposing ``__cuda_array_interface__`` or ``data_ptr()``.

    Returns:
        The address taken from ``__cuda_array_interface__["data"][0]`` or
        returned by ``data_ptr()``.

    Raises:
        TypeError: If ``arg`` offers neither way of obtaining a pointer.
    """
    interface_error = None
    try:
        cai = getattr(arg, "__cuda_array_interface__", None)
    except RuntimeError as exc:
        # torch raises RuntimeError here for CUDA tensors that require grad;
        # the tensor still has a valid ``data_ptr()``, so fall through to it.
        cai = None
        interface_error = exc
    if cai is not None:
        return cai["data"][0]
    data_ptr = getattr(arg, "data_ptr", None)
    if callable(data_ptr):
        return data_ptr()
    raise TypeError(
        f"Device-resident execution requires a GPU array exposing "
        f"__cuda_array_interface__ or data_ptr(), got {type(arg).__name__}"
    ) from interface_error
81

82

83
# Evaluation context for shape expressions: the globals mapping handed to
# ``eval`` so that expressions can call ``idiv``.
_EVAL_GLOBALS = {"idiv": idiv}

# Pre-compiled regex for _convert_to_python_syntax: an identifier immediately
# followed by a parenthesised, non-empty argument text that contains no further
# parentheses. Group 1 is the identifier, group 2 the argument text.
_FUNC_CALL_PATTERN = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\(([^()]+)\)")
# Matches a ``@@@FUNC@@@name@@@args@@@END@@@`` placeholder; group 1 is the
# name, group 2 the (non-greedy) argument text.
_PLACEHOLDER_PATTERN = re.compile(
    r"@@@FUNC@@@([a-zA-Z_][a-zA-Z0-9_]*)@@@(.+?)@@@END@@@"
)
# Names that _convert_to_python_syntax keeps as function calls (compared
# lower-cased); any other ``name(args)`` is rewritten to ``name[args]``.
_KNOWN_FUNCTIONS = frozenset(
    {"int", "float", "abs", "min", "max", "sum", "len", "idiv"}
)
94

95
# Argument type constants for fast dispatch. Each value is stored as the first
# element of the per-argument tuples kept in ``CompiledSDFG._arg_info``.
_ARG_TYPE_OUTPUT_ARRAY = 0
_ARG_TYPE_OUTPUT_SCALAR = 1
_ARG_TYPE_SHAPE = 2
_ARG_TYPE_USER_ARRAY = 3
_ARG_TYPE_USER_STRUCT = 4
_ARG_TYPE_USER_SCALAR = 5

# Call modes. A single call uses exactly one mode (mixing array kinds is not
# allowed). The mode, combined with whether the artifact is device-resident,
# fully determines the execution path:
#
#   mode      | non-device-resident        | device-resident
#   ----------|----------------------------|-------------------------------
#   NumPyCPU  | host execution (default)   | host->device copy + warning
#   NumPyGPU  | rejected (TypeError)       | zero-copy device execution
#   TorchCPU  | host execution             | host->device copy + warning
#   TorchGPU  | rejected (TypeError)       | zero-copy device execution
_CALL_MODE_NUMPY_CPU = "NumPyCPU"
_CALL_MODE_NUMPY_GPU = "NumPyGPU"
_CALL_MODE_TORCH_CPU = "TorchCPU"
_CALL_MODE_TORCH_GPU = "TorchGPU"

# Modes whose data lives in device memory; these require a device-resident
# artifact, otherwise the call is rejected.
_GPU_CALL_MODES = frozenset({_CALL_MODE_NUMPY_GPU, _CALL_MODE_TORCH_GPU})
121

122

123
def _is_torch_tensor(arg):
    """Tell whether the type of ``arg`` is defined inside the ``torch`` package.

    Only the top-level package of ``type(arg).__module__`` is inspected, so
    torch itself is never imported.
    """
    top_level_package, _, _ = type(arg).__module__.partition(".")
    return top_level_package == "torch"
126

127

128
def _classify_array_kind(arg):
    """Map one argument onto a call mode, or None when it is mode-agnostic.

    The decision uses the top-level package that defines ``type(arg)`` plus the
    argument's own attributes, so neither torch nor cupy has to be imported:

    * ``torch`` types: ``TorchGPU`` when ``is_cuda`` is truthy, else ``TorchCPU``.
    * ``cupy`` types: ``NumPyGPU``.
    * ``numpy.ndarray`` instances: ``NumPyCPU``.
    * anything else with a non-None ``__cuda_array_interface__``: ``NumPyGPU``.

    Returns:
        One of the ``_CALL_MODE_*`` constants, or None for values matching none
        of the rules above (Python/numpy scalars, structures, ...).
    """
    package = type(arg).__module__.partition(".")[0]

    if package == "torch":
        if getattr(arg, "is_cuda", False):
            return _CALL_MODE_TORCH_GPU
        return _CALL_MODE_TORCH_CPU

    if package == "cupy":
        return _CALL_MODE_NUMPY_GPU

    if isinstance(arg, np.ndarray):
        return _CALL_MODE_NUMPY_CPU

    # Any other object exposing the CUDA array interface is a device array.
    exposes_cuda_interface = (
        getattr(arg, "__cuda_array_interface__", None) is not None
    )
    return _CALL_MODE_NUMPY_GPU if exposes_cuda_interface else None
151

152

153
# Module-level aliases for frequently used ctypes callables, so call sites use
# a plain global name instead of an attribute lookup on ``ctypes``.
_c_int64 = ctypes.c_int64
_ctypes_cast = ctypes.cast
_ctypes_addressof = ctypes.addressof
_ctypes_byref = ctypes.byref
_ctypes_pointer = ctypes.pointer
_ctypes_c_void_p = ctypes.c_void_p
160

161
# Map primitive types to numpy dtypes for fast buffer allocation
# (used when pre-computing the dtype of output array arguments).
_PRIMITIVE_TO_NP_DTYPE = {
    PrimitiveType.Float: np.float32,
    PrimitiveType.Double: np.float64,
    PrimitiveType.Int8: np.int8,
    PrimitiveType.Int16: np.int16,
    PrimitiveType.Int32: np.int32,
    PrimitiveType.Int64: np.int64,
    PrimitiveType.UInt8: np.uint8,
    PrimitiveType.UInt16: np.uint16,
    PrimitiveType.UInt32: np.uint32,
    PrimitiveType.UInt64: np.uint64,
    PrimitiveType.Bool: np.bool_,
    PrimitiveType.Half: np.float16,
    PrimitiveType.BFloat: ml_dtypes.bfloat16,
}

# Pre-computed dtype map for numpy conversion of pointer return values.
# NOTE(review): the entries are identical to _PRIMITIVE_TO_NP_DTYPE above; the
# two tables have to be kept in sync by hand.
_NUMPY_DTYPE_MAP = {
    PrimitiveType.Float: np.float32,
    PrimitiveType.Double: np.float64,
    PrimitiveType.Int8: np.int8,
    PrimitiveType.Int16: np.int16,
    PrimitiveType.Int32: np.int32,
    PrimitiveType.Int64: np.int64,
    PrimitiveType.UInt8: np.uint8,
    PrimitiveType.UInt16: np.uint16,
    PrimitiveType.UInt32: np.uint32,
    PrimitiveType.UInt64: np.uint64,
    PrimitiveType.Bool: np.bool_,
    PrimitiveType.Half: np.float16,
    PrimitiveType.BFloat: ml_dtypes.bfloat16,
}

# Map primitive types to the ctypes type used when building the C call
# signature and when reinterpreting returned pointers.
_CTYPES_MAP = {
    PrimitiveType.Bool: ctypes.c_bool,
    PrimitiveType.Int8: ctypes.c_int8,
    PrimitiveType.Int16: ctypes.c_int16,
    PrimitiveType.Int32: ctypes.c_int32,
    PrimitiveType.Int64: ctypes.c_int64,
    PrimitiveType.UInt8: ctypes.c_uint8,
    PrimitiveType.UInt16: ctypes.c_uint16,
    PrimitiveType.UInt32: ctypes.c_uint32,
    PrimitiveType.UInt64: ctypes.c_uint64,
    PrimitiveType.Float: ctypes.c_float,
    PrimitiveType.Double: ctypes.c_double,
    # Half and BFloat are 2 bytes, use c_uint16 for raw storage
    PrimitiveType.Half: ctypes.c_uint16,
    PrimitiveType.BFloat: ctypes.c_uint16,
}
211

212

213
class CompiledSDFG:
4✔
214
    def __init__(
        self,
        lib_path,
        sdfg,
        shape_sources=None,
        structure_member_info=None,
        output_args=None,
        output_shapes=None,
        output_strides=None,
        device_resident=False,
        device_backend=None,
        target=None,
    ):
        """Load the compiled library and prepare the ctypes call signature.

        Args:
            lib_path: Path of the shared library, passed to ``ctypes.CDLL``.
            sdfg: SDFG object; its ``name``, ``arguments``, ``type()``,
                ``return_type`` and (when present) ``metadata()`` are read.
            shape_sources: Sequence of ``(user_arg_index, dim_index)`` pairs;
                entry ``i`` supplies the value of the shape symbol ``_s{i}``.
            structure_member_info: ``{struct_name: {member_name: (index, type)}}``.
            output_args: Names of output arguments. When empty, the
                comma-separated ``output_args`` metadata entry is used instead.
            output_shapes: Maps an output argument name to its dimension
                expressions.
            output_strides: Maps an output argument name to its stride
                expressions.
            device_resident: Whether the compiled function expects device
                pointers at the call boundary.
            device_backend: Backend name; defaults to ``"cuda"`` when
                ``device_resident`` is set and to None otherwise.
            target: Compilation target (e.g. "cuda"/"rocm"/"sequential").
        """
        self.lib_path = lib_path
        self.sdfg = sdfg
        self.shape_sources = shape_sources if shape_sources else []
        self.structure_member_info = (
            structure_member_info if structure_member_info else {}
        )
        self.lib = ctypes.CDLL(lib_path)
        self.func = getattr(self.lib, sdfg.name)

        # Output arguments: an explicit list wins, otherwise fall back to the
        # comma-separated list recorded in the SDFG metadata.
        resolved_outputs = output_args if output_args else []
        if not resolved_outputs and hasattr(sdfg, "metadata"):
            recorded_outputs = sdfg.metadata("output_args")
            if recorded_outputs:
                resolved_outputs = recorded_outputs.split(",")
        self.output_args = resolved_outputs

        self.output_shapes = output_shapes if output_shapes else {}
        self.output_strides = output_strides if output_strides else {}

        # Device residency: set by the DeviceResidentArgPromotion pass when all
        # pointer arguments were promoted to device-resident storage. When
        # active, the compiled function expects device pointers (no
        # host<->device copies at the boundary) and produces device-resident
        # outputs. Communicated explicitly via the constructor (pass return
        # value), not via metadata.
        self.device_resident = bool(device_resident)
        if device_backend:
            self.device_backend = device_backend
        elif self.device_resident:
            self.device_backend = "cuda"
        else:
            self.device_backend = None
        # Warn at most once per artifact when host inputs must be copied to device.
        self._warned_host_to_device = False

        # Compilation target (e.g. "cuda"/"rocm"/"sequential"). Used to inform,
        # once, when a GPU target ran on the host because device-residency
        # promotion did not apply to this artifact.
        self.target = target
        self._warned_residency_failed = False

        # Cache for ctypes structure definitions; must exist before the
        # signature is built because struct pointers are resolved through it.
        self._ctypes_structures = {}

        # Build the ctypes signature while remembering the original sdfg types.
        self.arg_names = sdfg.arguments
        declared_types = []
        signature = []
        for name in sdfg.arguments:
            declared = sdfg.type(name)
            declared_types.append(declared)
            signature.append(self._get_ctypes_type(declared))
        self.arg_sdfg_types = declared_types
        self.arg_types = signature

        self.func.argtypes = self.arg_types
        self.func.restype = self._get_ctypes_type(sdfg.return_type)

        # Pre-compute argument classification for fast __call__
        self._precompute_arg_metadata()
282

283
    def _precompute_arg_metadata(self):
        """Classify every SDFG argument once so calls can dispatch on tuples.

        Populates:
            ``_shape_sources_list``: ``(s_idx, u_idx, dim_idx, "_s{s_idx}")``
                per entry of ``shape_sources``.
            ``_arg_info``: one tuple per SDFG argument whose first element is an
                ``_ARG_TYPE_*`` constant:
                  * output array:  ``(type, name, base_type, target_type,
                    compiled_dims, compiled_strides, primitive_type, np_dtype)``
                  * output scalar: ``(type, name, base_type, primitive_type)``
                  * shape symbol:  ``(type, s_idx, "_s{s_idx}")``
                  * user struct:   ``(type, user_idx, name, struct_class,
                    sorted_members)``
                  * user array / user scalar: ``(type, user_idx, name,
                    target_type)``
            ``_num_user_args``: number of user-supplied arguments.
            ``_output_order``: ``_arg_info`` indices of the outputs, ordered by
                the integer after the last ``_`` in each output name.
            ``_output_pos_map``: ``_arg_info`` index -> position in
                ``_output_order``.
        """
        declared_outputs = set(self.output_args)

        # Shape source lookup with the symbol key pre-formatted.
        self._shape_sources_list = [
            (s_idx, u_idx, dim_idx, f"_s{s_idx}")
            for s_idx, (u_idx, dim_idx) in enumerate(self.shape_sources)
        ]

        self._arg_info = []
        next_user_idx = 0
        # (numeric suffix of the output name, index into _arg_info); sorted
        # once below so no sorting is needed at call time.
        ordered_outputs = []

        for position, arg_name in enumerate(self.arg_names):
            if arg_name in declared_outputs:
                target_type = self.arg_types[position]
                base_type = target_type._type_
                sdfg_type = self.arg_sdfg_types[position]

                # Element type of the output, known only for pointer-to-scalar.
                primitive_type = None
                if isinstance(sdfg_type, Pointer) and sdfg_type.has_pointee_type():
                    pointee = sdfg_type.pointee_type
                    if isinstance(pointee, Scalar):
                        primitive_type = pointee.primitive_type

                if arg_name in self.output_shapes:
                    # Always compile shape expressions - they may depend on
                    # runtime values.
                    compiled_dims = [
                        compile(
                            self._convert_to_python_syntax(str(dim)),
                            "<shape>",
                            "eval",
                        )
                        for dim in self.output_shapes[arg_name]
                    ]

                    # Stride expressions are optional.
                    compiled_strides = None
                    if arg_name in self.output_strides:
                        compiled_strides = [
                            compile(
                                self._convert_to_python_syntax(str(stride)),
                                "<stride>",
                                "eval",
                            )
                            for stride in self.output_strides[arg_name]
                        ]

                    # numpy dtype for fast allocation; float64 when unknown.
                    np_dtype = np.float64
                    if primitive_type:
                        np_dtype = _PRIMITIVE_TO_NP_DTYPE.get(
                            primitive_type, np.float64
                        )

                    entry = (
                        _ARG_TYPE_OUTPUT_ARRAY,
                        arg_name,
                        base_type,
                        target_type,
                        compiled_dims,
                        compiled_strides,
                        primitive_type,
                        np_dtype,
                    )
                else:
                    # Scalar return
                    entry = (
                        _ARG_TYPE_OUTPUT_SCALAR,
                        arg_name,
                        base_type,
                        primitive_type,
                    )

                info_idx = len(self._arg_info)
                self._arg_info.append(entry)
                ordered_outputs.append((int(arg_name.rsplit("_", 1)[-1]), info_idx))
                continue

            if arg_name.startswith("_s") and arg_name[2:].isdigit():
                # Shape symbol argument
                s_idx = int(arg_name[2:])
                self._arg_info.append((_ARG_TYPE_SHAPE, s_idx, f"_s{s_idx}"))
                continue

            # User argument
            sdfg_type = self.arg_sdfg_types[position]
            target_type = self.arg_types[position]
            is_struct_ptr = (
                sdfg_type
                and isinstance(sdfg_type, Pointer)
                and sdfg_type.has_pointee_type()
                and isinstance(sdfg_type.pointee_type, Structure)
            )

            if is_struct_ptr:
                struct_name = sdfg_type.pointee_type.name
                struct_class = self._create_ctypes_structure(struct_name)
                member_info = self.structure_member_info[struct_name]
                # Members ordered by their declared index.
                sorted_members = tuple(
                    sorted(member_info.items(), key=lambda item: item[1][0])
                )
                entry = (
                    _ARG_TYPE_USER_STRUCT,
                    next_user_idx,
                    arg_name,
                    struct_class,
                    sorted_members,
                )
            elif hasattr(target_type, "contents"):
                # ctypes pointer types expose ``contents``: array user argument.
                entry = (_ARG_TYPE_USER_ARRAY, next_user_idx, arg_name, target_type)
            else:
                entry = (_ARG_TYPE_USER_SCALAR, next_user_idx, arg_name, target_type)

            self._arg_info.append(entry)
            next_user_idx += 1

        self._num_user_args = next_user_idx

        # Outputs ordered by numeric suffix, plus the O(1) reverse lookup.
        self._output_order = tuple(
            info_idx
            for _, info_idx in sorted(ordered_outputs, key=lambda pair: pair[0])
        )
        self._output_pos_map = {
            info_idx: pos for pos, info_idx in enumerate(self._output_order)
        }
413

414
    def _convert_to_python_syntax(self, expr_str):
4✔
415
        result = expr_str
4✔
416

417
        while True:
4✔
418
            match = _FUNC_CALL_PATTERN.search(result)
4✔
419
            if not match:
4✔
420
                break
4✔
421

422
            name = match.group(1)
4✔
423
            index = match.group(2)
4✔
424

425
            if name.lower() in _KNOWN_FUNCTIONS:
4✔
426
                # Use unique delimiters that won't appear in expressions
427
                placeholder = f"@@@FUNC@@@{name}@@@{index}@@@END@@@"
4✔
428
                result = result[: match.start()] + placeholder + result[match.end() :]
4✔
429
            else:
430
                result = (
×
431
                    result[: match.start()] + f"{name}[{index}]" + result[match.end() :]
432
                )
433

434
        result = _PLACEHOLDER_PATTERN.sub(r"\1(\2)", result)
4✔
435

436
        return result
4✔
437

438
    def _create_ctypes_structure(self, struct_name):
4✔
439
        """Create a ctypes Structure class for the given structure name."""
440
        if struct_name in self._ctypes_structures:
4✔
441
            return self._ctypes_structures[struct_name]
4✔
442

443
        if struct_name not in self.structure_member_info:
4✔
444
            raise ValueError(f"Structure '{struct_name}' not found in member info")
×
445

446
        # Get member info: {member_name: (index, type)}
447
        members = self.structure_member_info[struct_name]
4✔
448
        # Sort by index to get correct order
449
        sorted_members = sorted(members.items(), key=lambda x: x[1][0])
4✔
450

451
        # Build _fields_ for ctypes.Structure
452
        fields = []
4✔
453
        for member_name, (index, member_type) in sorted_members:
4✔
454
            ct_type = self._get_ctypes_type(member_type)
4✔
455
            fields.append((member_name, ct_type))
4✔
456

457
        # Create the ctypes Structure class dynamically
458
        class CStructure(ctypes.Structure):
4✔
459
            _fields_ = fields
4✔
460

461
        self._ctypes_structures[struct_name] = CStructure
4✔
462
        return CStructure
4✔
463

464
    def _get_ctypes_type(self, sdfg_type):
        """Translate an sdfg type into the ctypes type used in the C signature.

        * ``Scalar``: the ``_CTYPES_MAP`` entry for its primitive type.
        * ``Array``: a pointer to the mapped element type (arrays are passed as
          pointers).
        * ``Pointer`` to a ``Structure``: pointer to the generated ctypes
          structure class.
        * ``Pointer`` to a ``Scalar``: pointer to the mapped element type.
        * anything else, including unmapped primitive types and pointers
          without a usable pointee: ``ctypes.c_void_p``.
        """
        if isinstance(sdfg_type, Scalar):
            return _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)

        if isinstance(sdfg_type, Array):
            element = _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
            return ctypes.POINTER(element)

        # Note: has_pointee_type() is guaranteed to exist on Pointer instances
        # from C++ bindings
        if isinstance(sdfg_type, Pointer) and sdfg_type.has_pointee_type():
            pointee = sdfg_type.pointee_type
            if isinstance(pointee, Structure):
                return ctypes.POINTER(self._create_ctypes_structure(pointee.name))
            if isinstance(pointee, Scalar):
                element = _CTYPES_MAP.get(pointee.primitive_type, ctypes.c_void_p)
                return ctypes.POINTER(element)

        # Opaque fallback for everything not handled above.
        return ctypes.c_void_p
485

486
    def _convert_return_value(self, func_result, shape_symbol_values):
        """Turn the raw ctypes return value into a Python-level result.

        Only pointer-to-scalar return types are converted; every other return
        type is handed back unchanged. For a pointer to scalars the shape is
        taken from the ``return_shape`` metadata entry when present, otherwise
        from the ``_s{i}`` values in ``shape_symbol_values``. If no shape can
        be determined, the raw pointer is returned.

        Args:
            func_result: Value returned by the compiled function.
            shape_symbol_values: Mapping used as the local namespace when the
                shape expressions are evaluated; also the source of the
                fallback shape.

        Returns:
            A numpy array reshaped to the resolved shape, or ``func_result``
            itself when no conversion applies.
        """
        return_type = self.sdfg.return_type
        if not isinstance(return_type, Pointer):
            return func_result
        if not return_type.has_pointee_type():
            return func_result
        pointee = return_type.pointee_type
        if not isinstance(pointee, Scalar):
            return func_result

        def as_ndarray(shape):
            """Build the numpy array for ``func_result`` with the given shape."""
            dtype = _NUMPY_DTYPE_MAP.get(pointee.primitive_type, np.float64)

            total_size = 1
            for dim in shape:
                total_size *= dim

            ct_type = _CTYPES_MAP.get(pointee.primitive_type, ctypes.c_double)
            arr_type = ct_type * total_size
            # For Half/BFloat, use np.frombuffer since np.ctypeslib.as_array
            # doesn't support these types (PEP 3118 buffer format limitation)
            if pointee.primitive_type in (PrimitiveType.Half, PrimitiveType.BFloat):
                byte_size = total_size * 2  # Half and BFloat are 2 bytes
                raw_type = ctypes.c_char * byte_size
                raw = raw_type.from_address(
                    ctypes.cast(func_result, ctypes.c_void_p).value
                )
                arr = np.frombuffer(raw, dtype=dtype).copy()
            else:
                # No copy is requested on this path, unlike the branch above.
                typed = ctypes.cast(func_result, ctypes.POINTER(arr_type))
                arr = np.ctypeslib.as_array(typed.contents)
            return arr.reshape(shape)

        # Get return shape from metadata if available
        return_shape_str = self.sdfg.metadata("return_shape")
        if return_shape_str:
            # Strip brackets (metadata may be "[10,10]" format).
            # NOTE(review): the text is split on every comma, so a dimension
            # expression that itself contains a comma (e.g. a two-argument
            # idiv(...)) is cut apart, fails to evaluate and ends in the
            # raw-pointer fallback below - confirm whether such shapes occur.
            shape = []
            for dim_str in return_shape_str.strip("[]").split(","):
                try:
                    eval_str = self._convert_to_python_syntax(str(dim_str))
                    shape.append(
                        int(eval(eval_str, _EVAL_GLOBALS, shape_symbol_values))
                    )
                except Exception:
                    # Can't evaluate shape, return raw pointer
                    return func_result
            return as_ndarray(shape)

        # No shape info - try to infer from input shapes. For identity-like
        # operations, the output shape matches the input.
        if len(self.shape_sources) > 0 and len(shape_symbol_values) > 0:
            shape = [
                shape_symbol_values[f"_s{i}"]
                for i in range(len(self.shape_sources))
                if f"_s{i}" in shape_symbol_values
            ]
            if shape:
                return as_ndarray(shape)

        # Can't determine shape, return raw pointer
        return func_result
596

597
    def _ensure_device_array(self, arg, keepalive, writebacks):
        """Return ``arg`` as a device array, copying host data when necessary.

        Arguments already recognised as device arrays are returned unchanged.
        Host inputs (numpy arrays or CPU torch tensors) are still accepted for
        backwards compatibility: they are copied to the device with
        ``cupy.asarray`` and a performance warning is emitted the first time
        this happens on this artifact.

        Args:
            arg: The user-supplied array argument.
            keepalive: List that receives the device copy so it outlives the
                call.
            writebacks: List that receives the ``(host, device)`` pair so
                in-place writes (e.g. output arguments) can be mirrored back to
                the original host array after execution.

        Returns:
            ``arg`` itself for device arrays, otherwise the new cupy array.
        """
        if _is_device_array(arg):
            return arg

        # Imported lazily: cupy is only needed once a host copy is required.
        import cupy

        # Objects with ``detach`` and ``numpy`` are handled as CPU torch
        # tensors and converted to numpy first; cupy.asarray handles numpy.
        looks_like_torch = hasattr(arg, "detach") and hasattr(arg, "numpy")
        host_data = arg.detach().numpy() if looks_like_torch else arg

        device_copy = cupy.asarray(host_data)
        keepalive.append(device_copy)
        writebacks.append((arg, device_copy))

        if self._warned_host_to_device:
            return device_copy

        self._warned_host_to_device = True
        warn_host_to_device_copies(self.device_backend or "cuda")
        return device_copy
628

629
    def _call_device(self, *args):
        """Execute the compiled function with arrays passed as device pointers.

        Every user array is routed through ``_ensure_device_array`` and handed
        to the compiled function as a raw device pointer; output arrays are
        allocated with ``cupy.empty``. Host inputs recorded in ``writebacks``
        are refreshed from their device copies after the call so in-place
        writes become visible to the caller.

        Args:
            *args: Positional user arguments, looked up through the user index
                stored in each ``_arg_info`` entry.

        Returns:
            ``None`` when the artifact has no output arguments, the single
            output when there is exactly one, otherwise a tuple ordered by
            ``_output_pos_map``. Output scalars are read from their ctypes
            buffers. Output arrays are cupy arrays when at least one input was
            a device array, and numpy arrays otherwise.

        Raises:
            NotImplementedError: If the artifact takes a structure argument.
        """
        import cupy

        _eval = eval
        _GLOBALS = _EVAL_GLOBALS

        # 1. Build shape_symbol_values from shape sources
        shape_symbol_values = {}
        for _s_idx, u_idx, dim_idx, key in self._shape_sources_list:
            if u_idx < len(args):
                shape_symbol_values[key] = args[u_idx].shape[dim_idx]

        # 2. Process arguments
        converted_args = []
        keepalive = []  # keep device buffers / ctypes scalars alive
        writebacks = []  # (host_arg, device_arg) for host inputs copied to device
        # (arg_index, buf, dims); dims is None for scalar outputs.
        outputs = []
        # Track whether the caller supplied device arrays. When all array inputs
        # are host arrays (numpy / CPU torch), the caller works in host space and
        # expects host outputs, so device output buffers are copied back to host.
        any_device_input = False

        for arg_index, info in enumerate(self._arg_info):
            arg_type = info[0]

            if arg_type == _ARG_TYPE_OUTPUT_ARRAY:
                target_type = info[3]
                np_dtype = info[7]

                # Evaluate every dimension expression against the symbols
                # collected so far; the flat buffer holds their product.
                size = 1
                dims = []
                for code in info[4]:
                    d = int(_eval(code, _GLOBALS, shape_symbol_values))
                    dims.append(d)
                    size *= d

                # The compiled strides (info[5]) are not applied on this path:
                # the output is a flat buffer that is reshaped to dims below.
                buf = cupy.empty(size, dtype=np_dtype)
                keepalive.append(buf)
                outputs.append((arg_index, buf, dims))
                converted_args.append(_ctypes_cast(int(buf.data.ptr), target_type))

            elif arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                buf = info[2]()
                keepalive.append(buf)
                outputs.append((arg_index, buf, None))
                converted_args.append(_ctypes_byref(buf))

            elif arg_type == _ARG_TYPE_SHAPE:
                converted_args.append(_c_int64(shape_symbol_values.get(info[2], 0)))

            elif arg_type == _ARG_TYPE_USER_ARRAY:
                arg = args[info[1]]
                # Expose the argument by name to later dimension expressions.
                shape_symbol_values[info[2]] = arg
                if _is_device_array(arg):
                    any_device_input = True
                arg = self._ensure_device_array(arg, keepalive, writebacks)
                converted_args.append(_ctypes_cast(_device_array_ptr(arg), info[3]))

            elif arg_type == _ARG_TYPE_USER_STRUCT:
                raise NotImplementedError(
                    "Structure arguments are not supported for device-resident "
                    "execution."
                )

            else:  # _ARG_TYPE_USER_SCALAR
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                converted_args.append(info[3](arg))

        # 3. Call the function
        # NOTE(review): the direct return value is discarded here, whereas
        # _call_host forwards a non-None value through _convert_return_value.
        # Confirm that artifacts taking this path never return by value.
        self.func(*converted_args)

        # 3b. Mirror device results back into host inputs that were copied to the
        # device, so in-place writes (e.g. output arguments) are visible to the
        # caller. Read-only inputs are unchanged and copy back identically.
        for host_arg, device_arg in writebacks:
            if isinstance(host_arg, np.ndarray):
                # np.copyto raises ValueError on a non-writeable destination.
                # Such an input cannot receive in-place results, so skip it
                # (and the device-to-host copy) instead of failing the call.
                if host_arg.flags.writeable:
                    host_view = cupy.asnumpy(device_arg)
                    np.copyto(host_arg, host_view.reshape(host_arg.shape))
            elif hasattr(host_arg, "copy_"):  # torch CPU tensor
                import torch

                host_view = cupy.asnumpy(device_arg)
                host_arg.copy_(
                    torch.from_numpy(host_view.reshape(tuple(host_arg.shape)))
                )

        # 4. Place outputs at their pre-sorted result positions
        if not outputs:
            return None

        # Host callers get host (numpy) outputs; device callers get cupy outputs.
        host_mode = not any_device_input

        results = [None] * len(outputs)
        for arg_index, buf, dims in outputs:
            if dims is None:
                value = buf.value
            else:
                value = buf.reshape(dims) if len(dims) > 1 else buf
                if host_mode:
                    value = cupy.asnumpy(value)
            results[self._output_pos_map[arg_index]] = value

        if len(results) == 1:
            return results[0]
        return tuple(results)
758

759
    def _resolve_call_mode(self, args):
        """Return the single call mode shared by all array arguments of a call.

        Each user array or structure argument is classified with
        ``_classify_array_kind``; arguments classified as ``None`` are ignored.

        Args:
            args: The positional user arguments of the call.

        Returns:
            The one array kind found, or ``_CALL_MODE_NUMPY_CPU`` when no
            argument was classified.

        Raises:
            TypeError: If more than one distinct array kind is present.
        """
        kinds = set()
        for entry in self._arg_info:
            # Only user arrays and user structures take part in the decision.
            if entry[0] not in (_ARG_TYPE_USER_ARRAY, _ARG_TYPE_USER_STRUCT):
                continue
            detected = _classify_array_kind(args[entry[1]])
            if detected is not None:
                kinds.add(detected)

        if not kinds:
            return _CALL_MODE_NUMPY_CPU

        if len(kinds) == 1:
            return next(iter(kinds))

        found = sorted(kinds)
        raise TypeError(
            "Mixed array kinds are not allowed in a single call; every array "
            f"argument must be the same kind, but got {found}. "
            "Provide all inputs (and outputs) as one of: numpy arrays, cupy "
            "arrays, CPU torch tensors, or CUDA torch tensors."
        )
784

785
    def __call__(self, *args):
        """Run the compiled artifact, choosing the device or the host path.

        The call mode comes from ``_resolve_call_mode``. A device-resident
        artifact always goes through ``_call_device``; any other artifact goes
        through ``_call_host``, after CPU torch tensors have been converted to
        numpy arrays.

        Raises:
            TypeError: If GPU arrays are passed to an artifact that is not
                device-resident (or array kinds are mixed, raised while
                resolving the call mode).
        """
        mode = self._resolve_call_mode(args)

        # Device-resident artifacts take every call mode; host inputs are
        # handled inside _call_device.
        if self.device_resident:
            return self._call_device(*args)

        # From here on the artifact runs on the host, so GPU arrays are invalid.
        if mode in _GPU_CALL_MODES:
            raise TypeError(
                f"{mode} arguments were provided, but this artifact is not "
                "device-resident. GPU arrays can only be used with a "
                "device-resident artifact (a fully-offloadable kernel compiled "
                "for a GPU target). Provide host arrays (numpy arrays or CPU "
                "torch tensors) instead."
            )

        # Warn a single time when a GPU target ends up on the host path. The
        # flag is set before warning so the message is never repeated.
        needs_notice = self.target in ("cuda", "rocm")
        if needs_notice and not self._warned_residency_failed:
            self._warned_residency_failed = True
            warn_device_residency_failed(self.target)

        # The host path reads raw data pointers from numpy arrays, so CPU torch
        # tensors are converted first; every other argument passes through.
        if mode == _CALL_MODE_TORCH_CPU:
            host_args = []
            for item in args:
                if _is_torch_tensor(item):
                    item = item.detach().cpu().contiguous().numpy()
                host_args.append(item)
            args = tuple(host_args)

        return self._call_host(*args)
824

825
    def _call_host(self, *args):
        """Execute the compiled function on host memory and collect its outputs.

        Arguments are marshalled to ctypes values following the pre-computed
        ``_arg_info`` tuples: output arrays are allocated with ``np.empty``,
        output scalars as ctypes instances, user arrays are passed through
        their ``ctypes.data`` pointer and user structures are copied into
        ctypes structure instances.

        Args:
            *args: Positional user arguments, looked up through the user index
                stored in each ``_arg_info`` entry. Array arguments must expose
                ``.shape`` and ``.ctypes.data``.

        Returns:
            Without output arguments: the function's direct return value passed
            through ``_convert_return_value``, or ``None`` if it returned
            ``None``. With output arguments: the single output when there is
            exactly one, otherwise a tuple ordered by ``_output_pos_map``.
        """
        # Bind frequently used names to locals; this is the per-call fast path.
        _eval = eval
        _GLOBALS = _EVAL_GLOBALS
        _np_empty = np.empty

        # 1. Build shape_symbol_values from shape sources (pre-computed list)
        shape_symbol_values = {}
        for _s_idx, u_idx, dim_idx, key in self._shape_sources_list:
            if u_idx < len(args):
                shape_symbol_values[key] = args[u_idx].shape[dim_idx]

        # 2. Process arguments using tuple-based dispatch
        converted_args = []
        # Keeps output buffers and ctypes objects alive for the whole call.
        structure_refs = []
        # One (buffer, dims, compiled_strides) entry per output argument, in
        # _arg_info order; dims and compiled_strides are None for scalars.
        return_buffers = []

        for info in self._arg_info:
            arg_type = info[0]

            if arg_type == _ARG_TYPE_OUTPUT_ARRAY:
                # info = (type, name, base_type, target_type, compiled_dims, compiled_strides, primitive_type, np_dtype)
                target_type = info[3]
                compiled_dims = info[4]
                compiled_strides = info[5]
                np_dtype = info[7]

                # Evaluate size from compiled code objects
                size = 1
                dims = []
                for code in compiled_dims:
                    d = int(_eval(code, _GLOBALS, shape_symbol_values))
                    dims.append(d)
                    size *= d

                # Use numpy for fast allocation (much faster than ctypes)
                buf_arr = _np_empty(size, dtype=np_dtype)
                structure_refs.append(buf_arr)  # Keep alive
                return_buffers.append((buf_arr, dims, compiled_strides))
                # Get pointer directly from numpy array interface
                converted_args.append(_ctypes_cast(buf_arr.ctypes.data, target_type))

            elif arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                # info = (type, name, base_type, primitive_type)
                buf = info[2]()
                structure_refs.append(buf)
                return_buffers.append((buf, None, None))
                converted_args.append(_ctypes_byref(buf))

            elif arg_type == _ARG_TYPE_SHAPE:
                # info = (type, s_idx, key_str)
                converted_args.append(_c_int64(shape_symbol_values.get(info[2], 0)))

            elif arg_type == _ARG_TYPE_USER_ARRAY:
                # info = (type, user_idx, name, target_type)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg  # For indirect access
                # Direct pointer access - faster than data_as()
                converted_args.append(_ctypes_cast(arg.ctypes.data, info[3]))

            elif arg_type == _ARG_TYPE_USER_STRUCT:
                # info = (type, user_idx, name, struct_class, sorted_members)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                struct_class = info[3]
                # Members the Python object does not provide keep the ctypes
                # default value.
                struct_values = {
                    m[0]: getattr(arg, m[0]) for m in info[4] if hasattr(arg, m[0])
                }
                c_struct = struct_class(**struct_values)
                structure_refs.append(c_struct)
                converted_args.append(_ctypes_pointer(c_struct))

            else:  # _ARG_TYPE_USER_SCALAR
                # info = (type, user_idx, name, target_type)
                arg = args[info[1]]
                shape_symbol_values[info[2]] = arg
                converted_args.append(info[3](arg))

        # 3. Call the function
        func_result = self.func(*converted_args)

        # 4. Process returns using pre-sorted order
        if not return_buffers:
            if func_result is not None:
                return self._convert_return_value(func_result, shape_symbol_values)
            return None

        results = [None] * len(return_buffers)

        buf_idx = 0
        for i, info in enumerate(self._arg_info):
            arg_type = info[0]
            if arg_type not in (_ARG_TYPE_OUTPUT_ARRAY, _ARG_TYPE_OUTPUT_SCALAR):
                continue

            result_pos = self._output_pos_map[i]
            buf, dims, compiled_strides = return_buffers[buf_idx]
            buf_idx += 1

            if arg_type == _ARG_TYPE_OUTPUT_SCALAR:
                # Scalar - buf is a ctypes scalar
                results[result_pos] = buf.value
                continue

            # Array - buf is a flat numpy array; 0-d and 1-d outputs are
            # returned as allocated.
            arr = buf
            if dims and len(dims) > 1:
                if compiled_strides:
                    try:
                        itemsize = arr.itemsize
                        # Stride expressions are evaluated with empty globals
                        # (unlike the dimension expressions above).
                        byte_strides = tuple(
                            int(_eval(s, {}, shape_symbol_values)) * itemsize
                            for s in compiled_strides
                        )
                        arr = np.lib.stride_tricks.as_strided(
                            arr, shape=dims, strides=byte_strides
                        )
                    except Exception:
                        # Best effort: fall back to a contiguous reshape when
                        # the strided view cannot be built. Only Exception is
                        # caught so KeyboardInterrupt / SystemExit propagate.
                        arr = arr.reshape(dims)
                else:
                    arr = arr.reshape(dims)
            results[result_pos] = arr

        if len(results) == 1:
            return results[0]
        return tuple(results)
964

965
    def get_return_shape(self, *args):
        """Evaluate the SDFG's ``return_shape`` metadata for the given arguments.

        Args:
            *args: The positional user arguments of a call.

        Returns:
            A tuple of ints with one entry per comma-separated expression in
            the metadata, or ``None`` when the metadata is empty or any
            expression cannot be evaluated and converted to ``int``.
        """
        encoded = self.sdfg.metadata("return_shape")
        if not encoded:
            return None

        expressions = encoded.split(",")

        # Shape symbols "_s<i>" are taken from numpy array arguments only.
        symbols = {}
        for position, (arg_idx, dim_idx) in enumerate(self.shape_sources):
            candidate = args[arg_idx]
            if isinstance(candidate, np.ndarray):
                symbols[f"_s{position}"] = candidate.shape[dim_idx]

        # Integer arguments are exposed under their SDFG argument names. This
        # assumes the leading entries of sdfg.arguments line up with the user
        # arguments.
        for name, value in zip(getattr(self.sdfg, "arguments", ()), args):
            if isinstance(value, (int, np.integer)):
                symbols[name] = int(value)

        # eval is unsafe in general; the expressions come from our compiler.
        try:
            extents = [int(eval(text, _EVAL_GLOBALS, symbols)) for text in expressions]
        except Exception:
            return None
        return tuple(extents)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc