• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27856839861

20 Jun 2026 02:00AM UTC coverage: 61.736%. First build
27856839861

Pull #789

github

web-flow
Merge 22229aa8f into b7103c21a
Pull Request #789:

53 of 174 new or added lines in 7 files covered. (30.46%)

36995 of 59925 relevant lines covered (61.74%)

1016.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.93
/python/docc/python/python_program.py
1
import inspect
4✔
2
import shutil
4✔
3
import textwrap
4✔
4
import ast
4✔
5
import os
4✔
6
import getpass
4✔
7
import hashlib
4✔
8
import ml_dtypes
4✔
9
import numpy as np
4✔
10
from typing import Annotated, get_origin, get_args, Any, Optional
4✔
11

12
from docc.sdfg import (
4✔
13
    Scalar,
14
    PrimitiveType,
15
    Pointer,
16
    Structure,
17
    Array,
18
    Type,
19
    Tensor,
20
    StructuredSDFG,
21
    StructuredSDFGBuilder,
22
)
23
from docc.compiler.docc_program import DoccProgram
4✔
24
from docc.compiler.compiled_sdfg import CompiledSDFG
4✔
25
from docc.python.ast_parser import ASTParser
4✔
26
from docc.python.types import element_type_from_sdfg_type
4✔
27

28

29
def _compile_wrapper(self, output_folder=None):
    """Compile this SDFG and return a callable ``CompiledSDFG``.

    Delegates to ``StructuredSDFG._compile`` and wraps the path it returns,
    together with this SDFG, in a ``CompiledSDFG``.
    """
    return CompiledSDFG(self._compile(output_folder), self)


# Install the wrapper as StructuredSDFG.compile (monkey-patch at import time).
setattr(StructuredSDFG, "compile", _compile_wrapper)
4✔
37

38

39
def _map_python_type(dtype):
    """Map a Python/numpy type annotation to an SDFG type.

    Supported forms:
      * an SDFG ``Type`` -> returned unchanged;
      * ``Annotated[np.ndarray, shape, dtype]`` -> ``Pointer(<dtype>)``; the
        element type defaults to double when no dtype is given;
      * ``Annotated[T, ...]`` for any other ``T`` -> the mapping of ``T``;
      * ``np.ndarray[shape, dtype]`` (this includes ``numpy.typing.NDArray``,
        whose dtype argument is ``np.dtype[scalar]``) -> ``Pointer(<dtype>)``;
      * ``np.dtype`` instances and ``np.dtype[scalar]`` aliases -> the scalar;
      * Python/numpy scalar types -> ``Scalar``;
      * any other class -> ``Pointer(Structure(<class name>))``.

    Anything else is returned unchanged. That includes a bare ``np.ndarray``,
    whose element type is unknown. Callers treat a result that is not an SDFG
    ``Type`` as "could not map".
    """
    if isinstance(dtype, Type):
        return dtype

    # np.dtype("float32") instances -> their scalar type (np.float32).
    if isinstance(dtype, np.dtype):
        dtype = dtype.type

    origin = get_origin(dtype)

    if origin is Annotated:
        base_type, *metadata = get_args(dtype)
        if base_type is np.ndarray:
            # Convention: Annotated[np.ndarray, shape, dtype]. The shape is not
            # needed to form the pointer type.
            elem_type = Scalar(PrimitiveType.Double)  # Default
            if len(metadata) > 1:
                elem_type = _map_python_type(metadata[1])
            # An unmappable element type must not be wrapped in a Pointer.
            return Pointer(elem_type) if isinstance(elem_type, Type) else dtype
        return _map_python_type(base_type)

    if origin is np.ndarray:
        # np.ndarray[Shape, DType]: args[0] is the shape, args[1] the dtype.
        args = get_args(dtype)
        if len(args) >= 2:
            elem_type = _map_python_type(args[1])
            if isinstance(elem_type, Type):
                return Pointer(elem_type)
        return dtype

    if origin is np.dtype:
        # np.dtype[np.float64] -> np.float64
        args = get_args(dtype)
        return _map_python_type(args[0]) if args else dtype

    if dtype is np.ndarray:
        # Element type unknown: let the caller fall back to inference instead of
        # treating the ndarray class as a user structure.
        return dtype

    # Identity comparison (not ==) so only the exact type objects match.
    scalar_table = (
        ((float, np.float64), PrimitiveType.Double),
        ((np.float32,), PrimitiveType.Float),
        ((np.float16,), PrimitiveType.Half),
        ((ml_dtypes.bfloat16,), PrimitiveType.BFloat),
        ((bool, np.bool_), PrimitiveType.Bool),
        ((int, np.int64), PrimitiveType.Int64),
        ((np.int32,), PrimitiveType.Int32),
        ((np.int16,), PrimitiveType.Int16),
        ((np.int8,), PrimitiveType.Int8),
        ((np.uint64,), PrimitiveType.UInt64),
        ((np.uint32,), PrimitiveType.UInt32),
        ((np.uint16,), PrimitiveType.UInt16),
        ((np.uint8,), PrimitiveType.UInt8),
    )
    for candidates, primitive in scalar_table:
        if any(dtype is candidate for candidate in candidates):
            return Scalar(primitive)

    # Python classes map to a pointer to a Structure named after the class.
    if inspect.isclass(dtype):
        return Pointer(Structure(dtype.__name__))

    return dtype
×
100

101

102
class PythonProgram(DoccProgram):
    """DoccProgram frontend that JIT-compiles a Python/numpy function.

    The function source is parsed by ``ASTParser`` into a StructuredSDFG, which
    is compiled via ``sdfg_pipe``. The result is cached per argument signature
    whenever the default output folder is used.
    """

    def __init__(
        self,
        func,
        target: str = "none",
        category: str = "server",
        instrumentation_mode: Optional[str] = None,
        capture_args: Optional[bool] = None,
        remote_tuning: bool = False,
    ):
        super().__init__(
            name=func.__name__,
            target=target,
            category=category,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
            remote_tuning=remote_tuning,
        )
        # The Python function whose source is compiled.
        self.func = func
        # struct_name -> {member_name: (index, type)} from the last _build_sdfg.
        self._last_structure_member_info = {}
        # Set once the host-fallback warning for GPU inputs has been emitted.
        self._warned_cpu_fallback = False

    @staticmethod
    def _is_device_array(value):
        """True for objects exposing ``__cuda_array_interface__`` that are not numpy arrays."""
        return hasattr(value, "__cuda_array_interface__") and not isinstance(
            value, np.ndarray
        )

    @staticmethod
    def _to_host(value):
        """Copy a device array to host memory as a numpy array."""
        get = getattr(value, "get", None)  # cupy
        return get() if callable(get) else np.asarray(value)

    def _prepare_device_args(self, compiled, args):
        """Handle GPU (cupy) inputs for the numpy frontend.

        When the artifact is device-resident, GPU arrays are passed through
        unchanged (executed directly on device). Otherwise they are copied to
        host. For GPU targets, a warning about the transfer overhead is emitted
        once.
        """
        has_gpu = any(self._is_device_array(a) for a in args)
        if not has_gpu or compiled.device_resident:
            return args

        if self.target in ("cuda", "rocm") and not self._warned_cpu_fallback:
            from docc.compiler import warn_cpu_fallback

            warn_cpu_fallback(self.target)
            self._warned_cpu_fallback = True

        return tuple(self._to_host(a) if self._is_device_array(a) else a for a in args)

    def __call__(self, *args: Any) -> Any:
        """JIT-compile for ``args`` (cached), run, and shape the result.

        Two shaping sources are tried in order:
          * a return annotation ``Annotated[np.ndarray, shape, ...]`` with a
            shape, which is used to view the raw result as an array;
          * otherwise, ``compiled.get_return_shape`` when available.

        Shaping is best-effort: on failure the raw result is returned.
        """
        compiled = self.compile(*args)
        args = self._prepare_device_args(compiled, args)
        res = compiled(*args)

        ret_annotation = inspect.signature(self.func).return_annotation
        if ret_annotation is not inspect.Signature.empty and (
            get_origin(ret_annotation) is Annotated
        ):
            type_args = get_args(ret_annotation)
            if len(type_args) >= 2 and type_args[0] is np.ndarray:
                shape = type_args[1]
                if shape is not None:
                    try:
                        return np.ctypeslib.as_array(res, shape=shape)
                    except Exception:
                        # Best-effort: fall through to the metadata shape below.
                        pass

        # Try to infer return shape from metadata
        if hasattr(compiled, "get_return_shape"):
            shape = compiled.get_return_shape(*args)
            if shape is not None:
                try:
                    return np.ctypeslib.as_array(res, shape=shape)
                except Exception:
                    # Best-effort: return the raw result unchanged.
                    pass

        return res

    @staticmethod
    def _analyze_shapes(args):
        """Deduplicate the dimension sizes of all ndarray arguments.

        Returns:
            ``(shape_values, shape_sources, arg_shape_mapping)``:
              * ``shape_values`` lists the distinct sizes in first-seen order;
              * ``shape_sources[k]`` is the first ``(arg_idx, dim_idx)`` where
                ``shape_values[k]`` occurs;
              * ``arg_shape_mapping`` maps every ``(arg_idx, dim_idx)`` to its
                index in ``shape_values``.
        """
        shape_values = []
        shape_sources = []
        arg_shape_mapping = {}
        index_of = {}
        for arg_idx, arg in enumerate(args):
            if not isinstance(arg, np.ndarray):
                continue
            for dim_idx, dim_val in enumerate(arg.shape):
                u_idx = index_of.get(dim_val)
                if u_idx is None:
                    u_idx = len(shape_values)
                    index_of[dim_val] = u_idx
                    shape_values.append(dim_val)
                    shape_sources.append((arg_idx, dim_idx))
                arg_shape_mapping[(arg_idx, dim_idx)] = u_idx
        return shape_values, shape_sources, arg_shape_mapping

    @staticmethod
    def _layout_key(arg):
        """Describe the parts of ``arg`` that ``_build_sdfg`` bakes into the SDFG.

        Element types and dimension-sharing are covered elsewhere in the
        signature. This key adds the remaining specializing details:
          * for arrays, which dimensions are size 1 (emitted as the literal
            "1") and the memory order, or the literal strides when the array
            is non-contiguous;
          * whether a scalar is a numpy scalar (0-d tensor) or a Python scalar;
          * for structure arguments, the public member names and value types.
        """
        if isinstance(arg, np.ndarray):
            unit_dims = tuple(i for i, d in enumerate(arg.shape) if d == 1)
            if arg.flags["C_CONTIGUOUS"]:
                order = "C"
            elif arg.flags["F_CONTIGUOUS"]:
                order = "F"
            else:
                order = tuple(s // arg.itemsize for s in arg.strides)
            return ("ndarray", unit_dims, order)
        if isinstance(arg, np.generic):
            return "np-scalar"
        if isinstance(arg, (bool, int, float)):
            return "py-scalar"
        members = getattr(arg, "__dict__", None)
        if isinstance(members, dict):
            return tuple(
                (attr, type(value).__name__)
                for attr, value in sorted(members.items())
                if not attr.startswith("_")
            )
        return "object"

    def compile(
        self,
        *args: Any,
        output_folder: Optional[str] = None,
        instrumentation_mode: Optional[str] = None,
        capture_args: Optional[bool] = None,
        remote_tuning: Optional[bool] = None,
    ) -> CompiledSDFG:
        """Compile the function, specialized for the given example arguments.

        Args:
            *args: Example arguments. Their types, shapes and memory layouts
                specialize the generated SDFG.
            output_folder: Build directory. When omitted, a per-signature
                directory under ``$DOCC_TMP`` or ``/tmp/<user>/DOCC`` is used
                and the result is cached on this program.
            instrumentation_mode, capture_args, remote_tuning: Per-call
                overrides, resolved by ``_resolve_compile_options``.

        Returns:
            The ``CompiledSDFG`` for these arguments.
        """
        use_default_folder = output_folder is None

        # Resolve options
        instrumentation_mode, capture_args, remote_tuning = (
            self._resolve_compile_options(
                instrumentation_mode, capture_args, remote_tuning
            )
        )

        # 1. Analyze argument types and deduplicated array dimensions.
        arg_types = [self._infer_type(arg) for arg in args]
        shape_values, shape_sources, arg_shape_mapping = self._analyze_shapes(args)

        # 2. Cache signature. Everything that changes the generated SDFG must be
        #    part of it; otherwise a stale kernel is reused. That covers element
        #    types, which dims share a size, size-1 dims, array memory order or
        #    strides, numpy-vs-Python scalars and the resolved compile options.
        mapping_sig = sorted(arg_shape_mapping.items())
        type_sig = ", ".join(self._type_to_str(t) for t in arg_types)
        layout_sig = [self._layout_key(arg) for arg in args]
        options_sig = (instrumentation_mode, capture_args, remote_tuning)
        signature = f"{type_sig}|{mapping_sig}|{layout_sig}|{options_sig}"

        if use_default_folder:
            source_path = inspect.getsourcefile(self.func)
            hash_input = (
                f"{source_path}|{self.name}|{self.target}|{self.category}|"
                f"{self.capture_args}|{self.instrumentation_mode}|{signature}"
            ).encode("utf-8")
            stable_id = hashlib.sha256(hash_input).hexdigest()[:16]
            filename = os.path.basename(source_path)

            docc_tmp = os.environ.get("DOCC_TMP")
            if docc_tmp:
                output_folder = (
                    f"{docc_tmp}/{filename}-{self.name}-{self.target}-{stable_id}"
                )
            else:
                user = os.getenv("USER") or getpass.getuser()
                output_folder = f"/tmp/{user}/DOCC/{self.name}-{stable_id}"

        if use_default_folder and signature in self.cache:
            return self.cache[signature]

        # 3. Build SDFG
        if os.path.exists(output_folder):
            # Multiple python processes running the same code?
            # NOTE(review): this also wipes a caller-supplied output_folder.
            shutil.rmtree(output_folder)
        sdfg, out_args, out_shapes, out_strides = self._build_sdfg(
            arg_types, args, arg_shape_mapping, shape_values
        )

        lib_path = self.sdfg_pipe(
            sdfg, output_folder, instrumentation_mode, capture_args, remote_tuning
        )

        # 4. Create CompiledSDFG
        compiled = CompiledSDFG(
            lib_path,
            sdfg,
            shape_sources,
            self._last_structure_member_info,
            out_args,
            out_shapes,
            out_strides,
            device_resident=self._device_resident,
            device_backend=self._device_backend,
        )

        # Cache only builds that used the default (signature-derived) folder.
        if use_default_folder:
            self.cache[signature] = compiled

        return compiled

    def to_sdfg(self, *args: Any) -> StructuredSDFG:
        """Build (without compiling) the SDFG specialized for ``args``."""
        arg_types = [self._infer_type(arg) for arg in args]
        shape_values, _, arg_shape_mapping = self._analyze_shapes(args)
        sdfg, _, _, _ = self._build_sdfg(
            arg_types, args, arg_shape_mapping, shape_values
        )
        return sdfg

    def _convert_inputs(self, args: tuple) -> tuple:
        return args

    def _convert_outputs(self, result: Any, original_args: tuple) -> Any:
        return result

    def _get_signature(self, arg_types):
        return ", ".join(self._type_to_str(t) for t in arg_types)

    def _type_to_str(self, t):
        """Stable textual form of an SDFG type, used in cache signatures."""
        if isinstance(t, Scalar):
            return f"Scalar({t.primitive_type})"
        elif isinstance(t, Array):
            return f"Array({self._type_to_str(t.element_type)}, {t.num_elements})"
        elif isinstance(t, Pointer):
            return f"Pointer({self._type_to_str(t.pointee_type)})"
        elif isinstance(t, Structure):
            return f"Structure({t.name})"
        return str(t)

    @staticmethod
    def _scalar_type_from_dtype(dtype):
        """Map a numpy dtype to an SDFG ``Scalar``.

        Raises:
            ValueError: if the dtype has no SDFG equivalent.
        """
        candidates = (
            (np.float64, PrimitiveType.Double),
            (np.float32, PrimitiveType.Float),
            (np.float16, PrimitiveType.Half),
            (ml_dtypes.bfloat16, PrimitiveType.BFloat),
            (np.bool_, PrimitiveType.Bool),
            (np.int64, PrimitiveType.Int64),
            (np.int32, PrimitiveType.Int32),
            (np.int16, PrimitiveType.Int16),
            (np.int8, PrimitiveType.Int8),
            (np.uint64, PrimitiveType.UInt64),
            (np.uint32, PrimitiveType.UInt32),
            (np.uint16, PrimitiveType.UInt16),
            (np.uint8, PrimitiveType.UInt8),
        )
        for candidate, primitive in candidates:
            if dtype == candidate:
                return Scalar(primitive)
        raise ValueError(f"Unsupported numpy dtype: {dtype}")

    def _infer_type(self, arg):
        """Infer the SDFG type of a runtime argument.

        Mapping:
          * Python ``bool``/``float``/``int`` -> Bool/Double/Int64;
          * numpy scalars -> the ``Scalar`` of their dtype;
          * numpy arrays -> ``Pointer`` to their element type;
          * other instances -> ``Pointer(Structure(<class name>))``.

        Raises:
            ValueError: for strings, bytes, complex numbers, ``None``, builtin
                containers, classes, and numpy dtypes without an SDFG
                equivalent.
        """
        # bool before int: bool is a subclass of int.
        if isinstance(arg, bool):
            return Scalar(PrimitiveType.Bool)
        if isinstance(arg, float):  # also np.float64, a float subclass
            return Scalar(PrimitiveType.Double)
        if isinstance(arg, int):
            return Scalar(PrimitiveType.Int64)
        # Explicitly reject values that would otherwise be mistaken for
        # user-defined structures.
        unsupported = (str, bytes, bytearray, complex, list, tuple, dict, set, frozenset)
        if arg is None or isinstance(arg, unsupported):
            raise ValueError(f"Unsupported argument type: {type(arg)}")
        if isinstance(arg, np.generic):
            # All numpy scalar types, including float16/bfloat16/longlong.
            return self._scalar_type_from_dtype(arg.dtype)
        if isinstance(arg, np.ndarray):
            return Pointer(self._scalar_type_from_dtype(arg.dtype))
        if isinstance(arg, type):
            raise ValueError(f"Unsupported argument type: {type(arg)}")
        # An instance of a class: pointer to a Structure named after the class.
        return Pointer(Structure(arg.__class__.__name__))

    def _parse_return_annotation(self, sig):
        """Map the function's return annotation to explicit SDFG return types.

        Returns:
            ``(infer_return_type, explicit_returns)``. When there is no
            annotation, or some annotated type does not map to an SDFG
            ``Type``, the parser infers the return (``True, []``). A
            ``tuple[...]`` annotation yields one entry per element, and
            ``tuple[()]`` yields none.
        """
        ann = sig.return_annotation
        if ann is inspect.Signature.empty:
            return True, []
        if get_origin(ann) is tuple:
            # A trailing Ellipsis (tuple[T, ...]) is mapped like any other
            # element. It is not a Type, so this falls back to inference.
            explicit_returns = [_map_python_type(t) for t in get_args(ann)]
        else:
            explicit_returns = [_map_python_type(ann)]
        if all(isinstance(rt, Type) for rt in explicit_returns):
            return False, explicit_returns
        return True, []

    @staticmethod
    def _member_type(value):
        """SDFG scalar type for a structure member value, or None if unsupported."""
        # TODO: Consider using np.integer and np.floating abstract types
        # for more comprehensive numpy type coverage
        # TODO: Add support for nested structures and arrays
        # Check bool before int since bool is a subclass of int.
        if isinstance(value, (bool, np.bool_)):
            return Scalar(PrimitiveType.Bool)
        if isinstance(value, (int, np.int64)):
            return Scalar(PrimitiveType.Int64)
        if isinstance(value, (float, np.float64)):
            return Scalar(PrimitiveType.Double)
        if isinstance(value, np.int32):
            return Scalar(PrimitiveType.Int32)
        if isinstance(value, np.float32):
            return Scalar(PrimitiveType.Float)
        return None

    def _collect_structures(self, args, arg_types):
        """Derive structure layouts from class-instance arguments.

        Members are the public instance attributes (``__dict__``), sorted by
        name. This alphabetical order defines the structure layout and must
        match the order expected by backend code generation. Attributes of
        unsupported types are skipped.

        Returns:
            ``(structures, member_info)``:
              * ``structures`` maps struct_name -> member types;
              * ``member_info`` maps struct_name ->
                {member_name: (index, type)}.
        """
        structures = {}
        member_info = {}
        for arg, dtype in zip(args, arg_types):
            if not (isinstance(dtype, Pointer) and dtype.has_pointee_type()):
                continue
            pointee = dtype.pointee_type
            if not isinstance(pointee, Structure):
                continue
            struct_name = pointee.name
            if struct_name in structures or not hasattr(arg, "__dict__"):
                continue
            members = [
                (attr, self._member_type(value))
                for attr, value in sorted(arg.__dict__.items())
                if not attr.startswith("_")
            ]
            members = [(attr, mtype) for attr, mtype in members if mtype is not None]
            if members:
                structures[struct_name] = [mtype for _, mtype in members]
                member_info[struct_name] = {
                    attr: (idx, mtype) for idx, (attr, mtype) in enumerate(members)
                }
        return structures, member_info

    @staticmethod
    def _stride_product(factors):
        """Symbolic product of shape expressions.

        Returns "1" when ``factors`` is empty, the bare factor when there is
        only one, and a parenthesised product otherwise.
        """
        if not factors:
            return "1"
        if len(factors) == 1:
            return factors[0]
        return "(" + " * ".join(factors) + ")"

    def _array_tensor(self, arg_idx, arg, dtype, arg_shape_mapping):
        """Build the ``Tensor`` layout (shapes, strides, offset) for an ndarray argument."""
        element_type = element_type_from_sdfg_type(dtype)

        # Always use literal "1" for size-1 dimensions to enable proper
        # broadcasting detection; other dims use the unified _s<k> symbols.
        shapes = [
            "1" if dim_val == 1 else f"_s{arg_shape_mapping[(arg_idx, dim_idx)]}"
            for dim_idx, dim_val in enumerate(arg.shape)
        ]

        if arg.flags["C_CONTIGUOUS"]:
            # Row-major: stride[i] = product of shapes[i+1:]
            strides = [
                self._stride_product(shapes[dim_idx + 1 :]) for dim_idx in range(arg.ndim)
            ]
        elif arg.flags["F_CONTIGUOUS"]:
            # Column-major: stride[i] = product of shapes[:i]
            strides = [self._stride_product(shapes[:dim_idx]) for dim_idx in range(arg.ndim)]
        else:
            # Non-contiguous: bake in the actual element strides.
            strides = [f"{stride // arg.itemsize}" for stride in arg.strides]

        return Tensor(element_type, shapes, strides, "0")

    def _build_sdfg(
        self,
        arg_types,
        args,
        arg_shape_mapping,
        shape_values,
    ):
        """Build a StructuredSDFG from the function source for these arguments.

        Returns:
            ``(sdfg, out_args, out_shapes, out_strides)``:
              * ``out_args`` lists the ``_docc_ret_*`` argument names;
              * ``out_shapes``/``out_strides`` are the return shapes and
                strides captured by the parser.

        Raises:
            ValueError: if the argument count differs from the signature.
        """
        sig = inspect.signature(self.func)

        # The SDFG itself returns void; results travel through _docc_ret_* arguments.
        infer_return_type, explicit_returns = self._parse_return_annotation(sig)
        builder = StructuredSDFGBuilder(f"{self.name}_sdfg", Scalar(PrimitiveType.Void))

        for i, ret_type in enumerate(explicit_returns):
            # Scalars are returned through a pointer; array returns already are one.
            arg_type = Pointer(ret_type) if isinstance(ret_type, Scalar) else ret_type
            builder.add_container(f"_docc_ret_{i}", arg_type, is_argument=True)

        # Register structure types for class arguments; the member info is
        # also stored for CompiledSDFG.
        structures_to_register, structure_member_info = self._collect_structures(
            args, arg_types
        )
        self._last_structure_member_info = structure_member_info
        for struct_name, member_types in structures_to_register.items():
            builder.add_structure(struct_name, member_types)

        params = list(sig.parameters.items())
        if len(params) != len(arg_types):
            raise ValueError(
                f"Argument count mismatch: expected {len(params)}, got {len(arg_types)}"
            )

        # Add regular arguments and record tensor layouts.
        tensor_table = {}
        for i, ((name, _param), dtype, arg) in enumerate(zip(params, arg_types, args)):
            builder.add_container(name, dtype, is_argument=True)
            if isinstance(arg, np.ndarray):
                tensor_table[name] = self._array_tensor(i, arg, dtype, arg_shape_mapping)
            elif isinstance(arg, np.generic):
                # NumPy scalars are 0-d arrays for type promotion: they trigger
                # full promotion, unlike Python literals which adapt to the
                # array dtype.
                tensor_table[name] = Tensor(
                    element_type_from_sdfg_type(dtype), [], [], "0"
                )

        # Symbol table for the parser: arguments first, then shape symbols.
        container_table = {
            name: dtype for (name, _param), dtype in zip(params, arg_types)
        }

        # Unified shape arguments; size-1 dimensions use the literal "1" instead.
        for i, dim_val in enumerate(shape_values):
            if dim_val == 1:
                continue
            shape_name = f"_s{i}"
            builder.add_container(
                shape_name, Scalar(PrimitiveType.Int64), is_argument=True
            )
            builder.add_assumption_lb(shape_name, "1")  # Shapes must be positive
            builder.add_assumption_const(shape_name, True)  # Shapes are constant
            container_table[shape_name] = Scalar(PrimitiveType.Int64)

        # Parse the function source, keeping original line numbers.
        source_lines, start_line = inspect.getsourcelines(self.func)
        tree = ast.parse(textwrap.dedent("".join(source_lines)))
        ast.increment_lineno(tree, start_line - 1)
        func_def = tree.body[0]

        # Combine globals with closure variables (closure takes precedence).
        combined_globals = dict(self.func.__globals__)
        closure = self.func.__closure__
        if closure is not None and self.func.__code__.co_freevars:
            for var_name, cell in zip(self.func.__code__.co_freevars, closure):
                try:
                    combined_globals[var_name] = cell.cell_contents
                except ValueError:
                    # Empty cell: the free variable is not bound yet.
                    continue

        parser = ASTParser(
            builder,
            tensor_table,
            container_table,
            inspect.getsourcefile(self.func),
            self.func.__name__,
            infer_return_type=infer_return_type,
            globals_dict=combined_globals,
            structure_member_info=structure_member_info,
        )
        for node in func_def.body:
            parser.visit(node)

        # Emit hoisted allocations at function entry
        parser.memory_handler.emit_allocations()

        sdfg = builder.move()
        out_args = [
            arg_name for arg_name in sdfg.arguments if arg_name.startswith("_docc_ret_")
        ]

        return (
            sdfg,
            out_args,
            parser.captured_return_shapes,
            parser.captured_return_strides,
        )
673

674

675
def native(
    func=None,
    *,
    target="none",
    category="server",
    instrumentation_mode=None,
    capture_args=None,
    remote_tuning=False,
):
    """Wrap a Python function in a ``PythonProgram``.

    Works both bare (``@native``) and with options
    (``@native(target=...)``). Every keyword option is forwarded unchanged to
    ``PythonProgram``.

    Example:
        @native
        def my_function(x: np.ndarray) -> np.ndarray:
            return x * 2

        result = my_function(np.array([1.0, 2.0, 3.0]))
    """
    options = dict(
        target=target,
        category=category,
        instrumentation_mode=instrumentation_mode,
        capture_args=capture_args,
        remote_tuning=remote_tuning,
    )

    def decorate(f):
        return PythonProgram(f, **options)

    if func is not None:
        return decorate(func)
    return decorate
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc