• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23201411123

17 Mar 2026 03:13PM UTC coverage: 63.982% (+0.009%) from 63.973%
23201411123

push

github

web-flow
Mlir target plugins (#593)

* Made dumping of SDFGs for python-based frontends optional (DOCC_DEBUG and DOCC_CI).

 ~ ensure correct json dump is designated for arg_capture & instrumentation
 + enable mlir front end to use python target-overrides and make the same debug outputs as regular python front end

* Dump MLIR code and translated SDFG string on parsing-failure

* Removing dead code that assumed presence of an optional metadata entry pointing to an SDFG Json dump that does not exist.

12 of 17 new or added lines in 1 file covered. (70.59%)

1 existing line in 1 file now uncovered.

26052 of 40718 relevant lines covered (63.98%)

403.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.24
/python/docc/python/python_program.py
1
import inspect
4✔
2
import shutil
4✔
3
import textwrap
4✔
4
import ast
4✔
5
import os
4✔
6
import getpass
4✔
7
import hashlib
4✔
8
import ml_dtypes
4✔
9
import numpy as np
4✔
10
from typing import Annotated, get_origin, get_args, Any, Optional
4✔
11

12
from docc.sdfg import (
4✔
13
    Scalar,
14
    PrimitiveType,
15
    Pointer,
16
    Structure,
17
    Array,
18
    Type,
19
    Tensor,
20
    StructuredSDFG,
21
    StructuredSDFGBuilder,
22
)
23
from docc.compiler.docc_program import DoccProgram
4✔
24
from docc.compiler.compiled_sdfg import CompiledSDFG
4✔
25
from docc.python.ast_parser import ASTParser
4✔
26
from docc.python.types import element_type_from_sdfg_type
4✔
27
from docc.python.target_registry import get_target_schedule_fn, get_target_compile_fn
4✔
28

29

30
def _compile_wrapper(self, output_folder=None):
4✔
31
    """Wrapper to allow StructuredSDFG.compile() to return a CompiledSDFG."""
32
    lib_path = self._compile(output_folder)
×
33
    return CompiledSDFG(lib_path, self)
×
34

35

36
# Monkey-patch StructuredSDFG to add compile method
37
StructuredSDFG.compile = _compile_wrapper
4✔
38

39

40
def _map_python_type(dtype):
4✔
41
    """Map Python/numpy types to SDFG types."""
42
    # If it is already a sdfg Type, return it
43
    if isinstance(dtype, Type):
4✔
44
        return dtype
×
45

46
    # Handle Annotated for Arrays
47
    if get_origin(dtype) is Annotated:
4✔
48
        args = get_args(dtype)
4✔
49
        base_type = args[0]
4✔
50
        metadata = args[1:]
4✔
51

52
        if base_type is np.ndarray:
4✔
53
            # Convention: Annotated[np.ndarray, shape, dtype]
54
            shape = metadata[0]
4✔
55
            elem_type = Scalar(PrimitiveType.Double)  # Default
4✔
56

57
            if len(metadata) > 1:
4✔
58
                possible_dtype = metadata[1]
4✔
59
                elem_type = _map_python_type(possible_dtype)
4✔
60

61
            return Pointer(elem_type)
4✔
62

63
    # Handle numpy.ndarray[Shape, DType]
64
    if get_origin(dtype) is np.ndarray:
4✔
65
        args = get_args(dtype)
×
66
        # args[0] is shape, args[1] is dtype
67
        if len(args) >= 2:
×
68
            elem_type = _map_python_type(args[1])
×
69
            return Pointer(elem_type)
×
70

71
    # Simple mapping for python types
72
    if dtype is float or dtype is np.float64:
4✔
73
        return Scalar(PrimitiveType.Double)
4✔
74
    elif dtype is np.float32:
4✔
75
        return Scalar(PrimitiveType.Float)
×
76
    elif dtype is bool or dtype is np.bool_:
4✔
77
        return Scalar(PrimitiveType.Bool)
4✔
78
    elif dtype is int or dtype is np.int64:
4✔
79
        return Scalar(PrimitiveType.Int64)
4✔
80
    elif dtype is np.int32:
4✔
81
        return Scalar(PrimitiveType.Int32)
4✔
82
    elif dtype is np.int16:
×
83
        return Scalar(PrimitiveType.Int16)
×
84
    elif dtype is np.int8:
×
85
        return Scalar(PrimitiveType.Int8)
×
86
    elif dtype is np.uint64:
×
87
        return Scalar(PrimitiveType.UInt64)
×
88
    elif dtype is np.uint32:
×
89
        return Scalar(PrimitiveType.UInt32)
×
90
    elif dtype is np.uint16:
×
91
        return Scalar(PrimitiveType.UInt16)
×
92
    elif dtype is np.uint8:
×
93
        return Scalar(PrimitiveType.UInt8)
×
94

95
    # Handle Python classes - map to Structure type
96
    if inspect.isclass(dtype):
×
97
        # Use the class name as the structure name
98
        return Pointer(Structure(dtype.__name__))
×
99

100
    return dtype
×
101

102

103
def _is_debug_dump() -> bool:
4✔
104
    return bool(os.environ.get("DOCC_DEBUG"))
4✔
105

106

107
class PythonProgram(DoccProgram):
4✔
108

109
    def __init__(
4✔
110
        self,
111
        func,
112
        target: str = "none",
113
        category: str = "server",
114
        instrumentation_mode: Optional[str] = None,
115
        capture_args: Optional[bool] = None,
116
        remote_tuning: bool = False,
117
    ):
118
        super().__init__(
4✔
119
            name=func.__name__,
120
            target=target,
121
            category=category,
122
            instrumentation_mode=instrumentation_mode,
123
            capture_args=capture_args,
124
            remote_tuning=remote_tuning,
125
        )
126
        self.func = func
4✔
127
        self._last_structure_member_info = {}
4✔
128

129
    def __call__(self, *args: Any) -> Any:
4✔
130
        # JIT compile and run
131
        compiled = self.compile(*args)
4✔
132
        res = compiled(*args)
4✔
133

134
        # Handle return value conversion based on annotation
135
        sig = inspect.signature(self.func)
4✔
136
        ret_annotation = sig.return_annotation
4✔
137

138
        if ret_annotation is not inspect.Signature.empty:
4✔
139
            if get_origin(ret_annotation) is Annotated:
4✔
140
                type_args = get_args(ret_annotation)
4✔
141
                if len(type_args) >= 1 and type_args[0] is np.ndarray:
4✔
142
                    shape = None
4✔
143
                    if len(type_args) >= 2:
4✔
144
                        shape = type_args[1]
4✔
145

146
                    if shape is not None:
4✔
147
                        try:
4✔
148
                            return np.ctypeslib.as_array(res, shape=shape)
4✔
149
                        except Exception:
×
150
                            pass
×
151

152
        # Try to infer return shape from metadata
153
        if hasattr(compiled, "get_return_shape"):
4✔
154
            shape = compiled.get_return_shape(*args)
4✔
155
            if shape is not None:
4✔
156
                try:
×
157
                    return np.ctypeslib.as_array(res, shape=shape)
×
158
                except Exception:
×
159
                    pass
×
160

161
        return res
4✔
162

163
    def compile(
4✔
164
        self,
165
        *args: Any,
166
        output_folder: Optional[str] = None,
167
        instrumentation_mode: Optional[str] = None,
168
        capture_args: Optional[bool] = None,
169
    ) -> CompiledSDFG:
170
        original_output_folder = output_folder
4✔
171

172
        # Resolve options
173
        if instrumentation_mode is None:
4✔
174
            instrumentation_mode = self.instrumentation_mode
4✔
175
        if capture_args is None:
4✔
176
            capture_args = self.capture_args
4✔
177

178
        # Check environment variable DOCC_CI
179
        docc_ci = os.environ.get("DOCC_CI", "")
4✔
180
        if docc_ci:
4✔
181
            if docc_ci == "regions":
×
182
                if instrumentation_mode is None:
×
183
                    instrumentation_mode = "ols"
×
184
            elif docc_ci == "arg-capture":
×
185
                if capture_args is None:
×
186
                    capture_args = True
×
187
            else:
188
                # Full mode (or unknown value treated as full)
189
                if instrumentation_mode is None:
×
190
                    instrumentation_mode = "ols"
×
191
                if capture_args is None:
×
192
                    capture_args = True
×
193

194
        # Defaults
195
        if instrumentation_mode is None:
4✔
196
            instrumentation_mode = ""
4✔
197
        if capture_args is None:
4✔
198
            capture_args = False
4✔
199

200
        # 1. Analyze arguments and shapes
201
        arg_types = []
4✔
202
        shape_values = []  # List of unique shape values found
4✔
203
        shape_sources = []  # List of (arg_idx, dim_idx) for each unique shape value
4✔
204

205
        # Mapping from (arg_idx, dim_idx) -> unique_shape_idx
206
        arg_shape_mapping = {}
4✔
207

208
        # First pass: collect scalar integer arguments and their values
209
        sig = inspect.signature(self.func)
4✔
210
        params = list(sig.parameters.items())
4✔
211
        scalar_int_params = {}  # Maps value -> parameter name (first one wins)
4✔
212
        for i, ((name, param), arg) in enumerate(zip(params, args)):
4✔
213
            if isinstance(arg, (int, np.integer)) and not isinstance(
4✔
214
                arg, (bool, np.bool_)
215
            ):
216
                val = int(arg)
4✔
217
                if val not in scalar_int_params:
4✔
218
                    scalar_int_params[val] = name
4✔
219

220
        for i, arg in enumerate(args):
4✔
221
            t = self._infer_type(arg)
4✔
222
            arg_types.append(t)
4✔
223

224
            if isinstance(arg, np.ndarray):
4✔
225
                for dim_idx, dim_val in enumerate(arg.shape):
4✔
226
                    # Check if we've seen this value
227
                    if dim_val in shape_values:
4✔
228
                        # Reuse
229
                        u_idx = shape_values.index(dim_val)
4✔
230
                    else:
231
                        # New
232
                        u_idx = len(shape_values)
4✔
233
                        shape_values.append(dim_val)
4✔
234
                        shape_sources.append((i, dim_idx))
4✔
235

236
                    arg_shape_mapping[(i, dim_idx)] = u_idx
4✔
237

238
        # Detect scalar-shape equivalences: which shape indices have a matching scalar param
239
        # Maps unique_shape_idx -> scalar parameter name
240
        shape_to_scalar = {}
4✔
241
        for s_idx, s_val in enumerate(shape_values):
4✔
242
            if s_val in scalar_int_params:
4✔
243
                shape_to_scalar[s_idx] = scalar_int_params[s_val]
4✔
244

245
        # 2. Signature - include scalar-shape equivalences for correct caching
246
        mapping_sig = sorted(arg_shape_mapping.items())
4✔
247
        equiv_sig = sorted(shape_to_scalar.items())
4✔
248
        type_sig = ", ".join(self._type_to_str(t) for t in arg_types)
4✔
249
        signature = f"{type_sig}|{mapping_sig}|{equiv_sig}"
4✔
250

251
        if output_folder is None:
4✔
252
            source_path = inspect.getsourcefile(self.func)
4✔
253
            hash_input = f"{source_path}|{self.name}|{self.target}|{self.category}|{self.capture_args}|{self.instrumentation_mode}|{signature}".encode(
4✔
254
                "utf-8"
255
            )
256
            stable_id = hashlib.sha256(hash_input).hexdigest()[:16]
4✔
257
            filename = os.path.basename(inspect.getsourcefile(self.func))
4✔
258

259
            docc_tmp = os.environ.get("DOCC_TMP")
4✔
260
            if docc_tmp:
4✔
NEW
261
                output_folder = (
×
262
                    f"{docc_tmp}/{filename}-{self.name}-{self.target}-{stable_id}"
263
                )
264
            else:
265
                user = os.getenv("USER")
4✔
266
                if not user:
4✔
267
                    user = getpass.getuser()
4✔
268
                output_folder = f"/tmp/{user}/DOCC/{self.name}-{stable_id}"
4✔
269

270
        if original_output_folder is None and signature in self.cache:
4✔
271
            return self.cache[signature]
4✔
272

273
        # 3. Build SDFG
274
        if os.path.exists(output_folder):
4✔
275
            # Multiple python processes running the same code?
276
            shutil.rmtree(output_folder)
4✔
277
        sdfg, out_args, out_shapes, out_strides = self._build_sdfg(
4✔
278
            arg_types, args, arg_shape_mapping, shape_values, shape_to_scalar
279
        )
280
        sdfg.validate()
4✔
281

282
        debug_dump = _is_debug_dump()
4✔
283

284
        if debug_dump:
4✔
NEW
285
            sdfg.dump(output_folder, "py0.parsed", dump_dot=True)
×
286

287
        # Tensor targets keep tensor nodes
288
        if self.target != "onnx":
4✔
289
            sdfg.expand()
4✔
290
            if debug_dump:
4✔
NEW
291
                sdfg.dump(output_folder, "py1.expanded", dump_dot=True)
×
292

293
        # Simplify pipelines
294
        sdfg.simplify()
4✔
295
        if debug_dump:
4✔
NEW
296
            sdfg.dump(output_folder, "py2.opt", dump_dot=True)
×
297

298
        # Normalization for scheduling
299
        if self.target != "none":
4✔
300
            sdfg.normalize()
4✔
301

302
        if debug_dump or instrumentation_mode or capture_args:
4✔
303
            sdfg.dump(
4✔
304
                output_folder,
305
                "py3.norm",
306
                dump_dot=debug_dump,
307
                dump_json=True,
308
                record_for_instrumentation=True,
309
            )
310

311
        # Schedule if target is specified
312
        if self.target != "none":
4✔
313
            # Check for custom registered target first
314
            custom_schedule_fn = get_target_schedule_fn(self.target)
4✔
315
            if custom_schedule_fn is not None:
4✔
316
                custom_schedule_fn(
4✔
317
                    sdfg, self.category, {"remote_tuning": self.remote_tuning}
318
                )
319
            else:
320
                sdfg.schedule(self.target, self.category, self.remote_tuning)
4✔
321

322
        self.last_sdfg = sdfg
4✔
323

324
        if debug_dump:
4✔
NEW
325
            sdfg.dump(output_folder, "py4.post_sched", dump_dot=True)
×
326

327
        custom_compile_fn = get_target_compile_fn(self.target)
4✔
328
        if custom_compile_fn is not None:
4✔
329
            lib_path = custom_compile_fn(
4✔
330
                sdfg, output_folder, instrumentation_mode, capture_args, {}
331
            )
332
        else:
333
            lib_path = sdfg._compile(
4✔
334
                output_folder=output_folder,
335
                target=self.target,
336
                instrumentation_mode=instrumentation_mode,
337
                capture_args=capture_args,
338
            )
339

340
        # Build ONNX model from JSON if target is onnx (after _compile creates the JSON)
341
        if self.target == "onnx":
4✔
342
            from docc.python.targets.onnx_model_builder import convert_json_to_onnx
×
343

344
            onnx_model_path = convert_json_to_onnx(output_folder)
×
345
            if onnx_model_path:
×
346
                print(f"Generated ONNX models: {onnx_model_path}")
×
347

348
        # 5. Create CompiledSDFG
349
        compiled = CompiledSDFG(
4✔
350
            lib_path,
351
            sdfg,
352
            shape_sources,
353
            self._last_structure_member_info,
354
            out_args,
355
            out_shapes,
356
            out_strides,
357
        )
358

359
        # Cache if using default output folder
360
        if original_output_folder is None:
4✔
361
            self.cache[signature] = compiled
4✔
362

363
        return compiled
4✔
364

365
    def to_sdfg(self, *args: Any) -> StructuredSDFG:
4✔
366
        arg_types = [self._infer_type(arg) for arg in args]
×
367

368
        # Build shape mapping
369
        shape_values = []
×
370
        shape_sources = []
×
371
        arg_shape_mapping = {}
×
372

373
        sig = inspect.signature(self.func)
×
374
        params = list(sig.parameters.items())
×
375
        scalar_int_params = {}
×
376
        for i, ((name, param), arg) in enumerate(zip(params, args)):
×
377
            if isinstance(arg, (int, np.integer)) and not isinstance(
×
378
                arg, (bool, np.bool_)
379
            ):
380
                val = int(arg)
×
381
                if val not in scalar_int_params:
×
382
                    scalar_int_params[val] = name
×
383

384
        for i, arg in enumerate(args):
×
385
            if isinstance(arg, np.ndarray):
×
386
                for dim_idx, dim_val in enumerate(arg.shape):
×
387
                    if dim_val in shape_values:
×
388
                        u_idx = shape_values.index(dim_val)
×
389
                    else:
390
                        u_idx = len(shape_values)
×
391
                        shape_values.append(dim_val)
×
392
                        shape_sources.append((i, dim_idx))
×
393
                    arg_shape_mapping[(i, dim_idx)] = u_idx
×
394

395
        shape_to_scalar = {}
×
396
        for s_idx, s_val in enumerate(shape_values):
×
397
            if s_val in scalar_int_params:
×
398
                shape_to_scalar[s_idx] = scalar_int_params[s_val]
×
399

400
        sdfg, _, _, _ = self._build_sdfg(
×
401
            arg_types, args, arg_shape_mapping, shape_values, shape_to_scalar
402
        )
403
        return sdfg
×
404

405
    def _convert_inputs(self, args: tuple) -> tuple:
4✔
406
        return args
×
407

408
    def _convert_outputs(self, result: Any, original_args: tuple) -> Any:
4✔
409
        return result
×
410

411
    def _get_signature(self, arg_types):
4✔
412
        return ", ".join(self._type_to_str(t) for t in arg_types)
×
413

414
    def _type_to_str(self, t):
4✔
415
        if isinstance(t, Scalar):
4✔
416
            return f"Scalar({t.primitive_type})"
4✔
417
        elif isinstance(t, Array):
4✔
418
            return f"Array({self._type_to_str(t.element_type)}, {t.num_elements})"
×
419
        elif isinstance(t, Pointer):
4✔
420
            return f"Pointer({self._type_to_str(t.pointee_type)})"
4✔
421
        elif isinstance(t, Structure):
4✔
422
            return f"Structure({t.name})"
4✔
423
        return str(t)
×
424

425
    def _infer_type(self, arg):
4✔
426
        if isinstance(arg, (float, np.float64)):
4✔
427
            return Scalar(PrimitiveType.Double)
4✔
428
        elif isinstance(arg, np.float32):
4✔
429
            return Scalar(PrimitiveType.Float)
4✔
430
        elif isinstance(arg, (bool, np.bool_)):
4✔
431
            return Scalar(PrimitiveType.Bool)
4✔
432
        elif isinstance(arg, (int, np.int64)):
4✔
433
            return Scalar(PrimitiveType.Int64)
4✔
434
        elif isinstance(arg, np.int32):
4✔
435
            return Scalar(PrimitiveType.Int32)
4✔
436
        elif isinstance(arg, np.int16):
4✔
437
            return Scalar(PrimitiveType.Int16)
×
438
        elif isinstance(arg, np.int8):
4✔
439
            return Scalar(PrimitiveType.Int8)
×
440
        elif isinstance(arg, np.uint64):
4✔
441
            return Scalar(PrimitiveType.UInt64)
×
442
        elif isinstance(arg, np.uint32):
4✔
443
            return Scalar(PrimitiveType.UInt32)
×
444
        elif isinstance(arg, np.uint16):
4✔
445
            return Scalar(PrimitiveType.UInt16)
×
446
        elif isinstance(arg, np.uint8):
4✔
447
            return Scalar(PrimitiveType.UInt8)
×
448
        elif isinstance(arg, np.ndarray):
4✔
449
            # Map dtype
450
            if arg.dtype == np.float64:
4✔
451
                elem_type = Scalar(PrimitiveType.Double)
4✔
452
            elif arg.dtype == np.float32:
4✔
453
                elem_type = Scalar(PrimitiveType.Float)
4✔
454
            elif arg.dtype == np.float16:
4✔
455
                elem_type = Scalar(PrimitiveType.Half)
×
456
            elif arg.dtype == ml_dtypes.bfloat16:
4✔
457
                elem_type = Scalar(PrimitiveType.BFloat)
4✔
458
            elif arg.dtype == np.bool_:
4✔
459
                elem_type = Scalar(PrimitiveType.Bool)
4✔
460
            elif arg.dtype == np.int64:
4✔
461
                elem_type = Scalar(PrimitiveType.Int64)
4✔
462
            elif arg.dtype == np.int32:
4✔
463
                elem_type = Scalar(PrimitiveType.Int32)
4✔
464
            elif arg.dtype == np.int16:
×
465
                elem_type = Scalar(PrimitiveType.Int16)
×
466
            elif arg.dtype == np.int8:
×
467
                elem_type = Scalar(PrimitiveType.Int8)
×
468
            elif arg.dtype == np.uint64:
×
469
                elem_type = Scalar(PrimitiveType.UInt64)
×
470
            elif arg.dtype == np.uint32:
×
471
                elem_type = Scalar(PrimitiveType.UInt32)
×
472
            elif arg.dtype == np.uint16:
×
473
                elem_type = Scalar(PrimitiveType.UInt16)
×
474
            elif arg.dtype == np.uint8:
×
475
                elem_type = Scalar(PrimitiveType.UInt8)
×
476
            else:
477
                raise ValueError(f"Unsupported numpy dtype: {arg.dtype}")
×
478

479
            return Pointer(elem_type)
4✔
480
        elif isinstance(arg, str):
4✔
481
            # Explicitly reject strings - they are not supported
482
            raise ValueError(f"Unsupported argument type: {type(arg)}")
4✔
483
        else:
484
            # Check if it's a class instance
485
            if hasattr(arg, "__class__") and not isinstance(arg, type):
4✔
486
                # It's an instance of a class, return pointer to Structure
487
                return Pointer(Structure(arg.__class__.__name__))
4✔
488
            raise ValueError(f"Unsupported argument type: {type(arg)}")
×
489

490
    def _build_sdfg(
4✔
491
        self,
492
        arg_types,
493
        args,
494
        arg_shape_mapping,
495
        shape_values,
496
        shape_to_scalar=None,
497
    ):
498
        if shape_to_scalar is None:
4✔
499
            shape_to_scalar = {}
×
500
        sig = inspect.signature(self.func)
4✔
501

502
        # Handle return type - always void for SDFG, output args used for returns
503
        return_type = Scalar(PrimitiveType.Void)
4✔
504
        infer_return_type = True
4✔
505

506
        # Parse return annotation to determine output arguments if possible
507
        explicit_returns = []
4✔
508
        if sig.return_annotation is not inspect.Signature.empty:
4✔
509
            infer_return_type = False
4✔
510

511
            # Helper to normalize annotation to list of types
512
            def normalize_annotation(ann):
4✔
513
                # Handle Tuple[type, ...]
514
                origin = get_origin(ann)
4✔
515
                if origin is tuple:
4✔
516
                    type_args = get_args(ann)
×
517
                    # Tuple[()] or Tuple w/o args
518
                    if not type_args:
×
519
                        return []
×
520
                    # Tuple[int, float]
521
                    if len(type_args) > 0 and type_args[-1] is not Ellipsis:
×
522
                        return [_map_python_type(t) for t in type_args]
×
523
                    # Tuple[int, ...] - not supported for fixed number of returns yet?
524
                    # For now assume fixed tuple
525
                    return [_map_python_type(t) for t in type_args]
×
526
                else:
527
                    return [_map_python_type(ann)]
4✔
528

529
            explicit_returns = normalize_annotation(sig.return_annotation)
4✔
530
            for rt in explicit_returns:
4✔
531
                if not isinstance(rt, Type):
4✔
532
                    # Fallback if map failed (e.g. invalid annotation)
533
                    infer_return_type = True
×
534
                    explicit_returns = []
×
535
                    break
×
536

537
        builder = StructuredSDFGBuilder(f"{self.name}_sdfg", return_type)
4✔
538

539
        # Add pre-defined return arguments if we know them
540
        if not infer_return_type:
4✔
541
            for i, dtype in enumerate(explicit_returns):
4✔
542
                # Scalar -> Pointer(Scalar)
543
                # Array -> Already Pointer(Scalar). Keep it.
544
                arg_type = dtype
4✔
545
                if isinstance(dtype, Scalar):
4✔
546
                    arg_type = Pointer(dtype)
4✔
547

548
                builder.add_container(f"_docc_ret_{i}", arg_type, is_argument=True)
4✔
549

550
        # Register structure types for any class arguments
551
        # Also track member name to index mapping for each structure
552
        structures_to_register = {}
4✔
553
        structure_member_info = {}  # Maps struct_name -> {member_name: (index, type)}
4✔
554
        for i, (arg, dtype) in enumerate(zip(args, arg_types)):
4✔
555
            if isinstance(dtype, Pointer) and dtype.has_pointee_type():
4✔
556
                pointee = dtype.pointee_type
4✔
557
                if isinstance(pointee, Structure):
4✔
558
                    struct_name = pointee.name
4✔
559
                    if struct_name not in structures_to_register:
4✔
560
                        # Get class from arg to introspect members
561
                        if hasattr(arg, "__dict__"):
4✔
562
                            # Use __dict__ to get only instance attributes
563
                            # Sort by name to ensure consistent ordering
564
                            # Note: This alphabetical ordering is used to define the
565
                            # structure layout and must match the order expected by
566
                            # the backend code generation
567
                            member_types = []
4✔
568
                            member_names = []
4✔
569
                            for attr_name, attr_value in sorted(arg.__dict__.items()):
4✔
570
                                if not attr_name.startswith("_"):
4✔
571
                                    # Infer member type from instance attribute
572
                                    # Check bool before int since bool is subclass of int
573
                                    member_type = None
4✔
574
                                    if isinstance(attr_value, bool):
4✔
575
                                        member_type = Scalar(PrimitiveType.Bool)
×
576
                                    elif isinstance(attr_value, (int, np.int64)):
4✔
577
                                        member_type = Scalar(PrimitiveType.Int64)
×
578
                                    elif isinstance(attr_value, (float, np.float64)):
4✔
579
                                        member_type = Scalar(PrimitiveType.Double)
4✔
580
                                    elif isinstance(attr_value, np.int32):
×
581
                                        member_type = Scalar(PrimitiveType.Int32)
×
582
                                    elif isinstance(attr_value, np.float32):
×
583
                                        member_type = Scalar(PrimitiveType.Float)
×
584
                                    # TODO: Consider using np.integer and np.floating abstract types
585
                                    # for more comprehensive numpy type coverage
586
                                    # TODO: Add support for nested structures and arrays
587

588
                                    if member_type is not None:
4✔
589
                                        member_types.append(member_type)
4✔
590
                                        member_names.append(attr_name)
4✔
591

592
                            if member_types:
4✔
593
                                structures_to_register[struct_name] = member_types
4✔
594
                                # Build member name to (index, type) mapping
595
                                structure_member_info[struct_name] = {
4✔
596
                                    name: (idx, mtype)
597
                                    for idx, (name, mtype) in enumerate(
598
                                        zip(member_names, member_types)
599
                                    )
600
                                }
601

602
        # Store structure_member_info for later use in CompiledSDFG
603
        self._last_structure_member_info = structure_member_info
4✔
604

605
        # Register all discovered structures with the builder
606
        for struct_name, member_types in structures_to_register.items():
4✔
607
            builder.add_structure(struct_name, member_types)
4✔
608

609
        # Register arguments
610
        params = list(sig.parameters.items())
4✔
611
        if len(params) != len(arg_types):
4✔
612
            raise ValueError(
×
613
                f"Argument count mismatch: expected {len(params)}, got {len(arg_types)}"
614
            )
615

616
        # Add regular arguments
617
        tensor_table = {}
4✔
618
        for i, ((name, param), dtype, arg) in enumerate(zip(params, arg_types, args)):
4✔
619
            builder.add_container(name, dtype, is_argument=True)
4✔
620

621
            # Store layout information for arrays
622
            if isinstance(arg, np.ndarray):
4✔
623
                element_type = element_type_from_sdfg_type(dtype)
4✔
624

625
                shapes = []
4✔
626
                for dim_idx in range(arg.ndim):
4✔
627
                    dim_val = arg.shape[dim_idx]
4✔
628
                    if dim_val == 1:
4✔
629
                        # Always use literal "1" for size-1 dimensions to enable
630
                        # proper broadcasting detection
631
                        shapes.append("1")
4✔
632
                    else:
633
                        u_idx = arg_shape_mapping[(i, dim_idx)]
4✔
634
                        if u_idx in shape_to_scalar:
4✔
635
                            shapes.append(shape_to_scalar[u_idx])
4✔
636
                        else:
637
                            shapes.append(f"_s{u_idx}")
4✔
638

639
                strides = []
4✔
640
                if arg.flags["C_CONTIGUOUS"]:
4✔
641
                    # Row-major: stride[i] = product of shapes[i+1:]
642
                    for dim_idx in range(arg.ndim):
4✔
643
                        if dim_idx == arg.ndim - 1:
4✔
644
                            strides.append("1")
4✔
645
                        else:
646
                            suffix_shapes = shapes[dim_idx + 1 :]
4✔
647
                            if len(suffix_shapes) == 1:
4✔
648
                                strides.append(suffix_shapes[0])
4✔
649
                            else:
650
                                strides.append("(" + " * ".join(suffix_shapes) + ")")
4✔
651
                elif arg.flags["F_CONTIGUOUS"]:
4✔
652
                    # Column-major: stride[i] = product of shapes[:i]
653
                    for dim_idx in range(arg.ndim):
4✔
654
                        if dim_idx == 0:
4✔
655
                            strides.append("1")
4✔
656
                        else:
657
                            prefix_shapes = shapes[:dim_idx]
4✔
658
                            if len(prefix_shapes) == 1:
4✔
659
                                strides.append(prefix_shapes[0])
4✔
660
                            else:
661
                                strides.append("(" + " * ".join(prefix_shapes) + ")")
4✔
662
                else:
663
                    # Non-contiguous: use actual stride values
664
                    for dim_idx in range(arg.ndim):
4✔
665
                        stride_val = arg.strides[dim_idx] // arg.itemsize
4✔
666
                        strides.append(f"{stride_val}")
4✔
667

668
                offset = "0"
4✔
669
                tensor_table[name] = Tensor(element_type, shapes, strides, offset)
4✔
670

671
        # Add unified shape arguments only for shapes without scalar equivalents
672
        # and skip size-1 dimensions (they use literal "1" instead)
673
        for i in range(len(shape_values)):
4✔
674
            if i not in shape_to_scalar and shape_values[i] != 1:
4✔
675
                builder.add_container(
4✔
676
                    f"_s{i}", Scalar(PrimitiveType.Int64), is_argument=True
677
                )
678

679
        # Create symbol table for parser
680
        container_table = {}
4✔
681
        for i, ((name, param), dtype, arg) in enumerate(zip(params, arg_types, args)):
4✔
682
            container_table[name] = dtype
4✔
683

684
        for i in range(len(shape_values)):
4✔
685
            if i not in shape_to_scalar and shape_values[i] != 1:
4✔
686
                container_table[f"_s{i}"] = Scalar(PrimitiveType.Int64)
4✔
687

688
        # Parse AST
689
        source_lines, start_line = inspect.getsourcelines(self.func)
4✔
690
        source = textwrap.dedent("".join(source_lines))
4✔
691
        tree = ast.parse(source)
4✔
692
        ast.increment_lineno(tree, start_line - 1)
4✔
693
        func_def = tree.body[0]
4✔
694

695
        filename = inspect.getsourcefile(self.func)
4✔
696
        function_name = self.func.__name__
4✔
697

698
        # Combine globals with closure variables (closure takes precedence)
699
        combined_globals = dict(self.func.__globals__)
4✔
700
        if self.func.__closure__ is not None and self.func.__code__.co_freevars:
4✔
701
            for name, cell in zip(
4✔
702
                self.func.__code__.co_freevars, self.func.__closure__
703
            ):
704
                combined_globals[name] = cell.cell_contents
4✔
705

706
        parser = ASTParser(
4✔
707
            builder,
708
            tensor_table,
709
            container_table,
710
            filename,
711
            function_name,
712
            infer_return_type=infer_return_type,
713
            globals_dict=combined_globals,
714
            structure_member_info=structure_member_info,
715
        )
716
        for node in func_def.body:
4✔
717
            parser.visit(node)
4✔
718

719
        # Emit hoisted allocations at function entry
720
        parser.memory_handler.emit_allocations()
4✔
721

722
        sdfg = builder.move()
4✔
723
        # Mark return arguments metadata
724
        out_args = []
4✔
725
        for name in sdfg.arguments:
4✔
726
            if name.startswith("_docc_ret_"):
4✔
727
                out_args.append(name)
4✔
728

729
        return (
4✔
730
            sdfg,
731
            out_args,
732
            parser.captured_return_shapes,
733
            parser.captured_return_strides,
734
        )
735

736

737
def native(
4✔
738
    func=None,
739
    *,
740
    target="none",
741
    category="server",
742
    instrumentation_mode=None,
743
    capture_args=None,
744
    remote_tuning=False,
745
):
746
    """Decorator to create a PythonProgram from a Python function.
747

748
    Example:
749
        @native
750
        def my_function(x: np.ndarray) -> np.ndarray:
751
            return x * 2
752

753
        result = my_function(np.array([1.0, 2.0, 3.0]))
754
    """
755
    if func is None:
4✔
756
        return lambda f: PythonProgram(
4✔
757
            f,
758
            target=target,
759
            category=category,
760
            instrumentation_mode=instrumentation_mode,
761
            capture_args=capture_args,
762
            remote_tuning=remote_tuning,
763
        )
764
    return PythonProgram(
4✔
765
        func,
766
        target=target,
767
        category=category,
768
        instrumentation_mode=instrumentation_mode,
769
        capture_args=capture_args,
770
        remote_tuning=remote_tuning,
771
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc