• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.8
/python/docc/python/python_program.py
1
import inspect
4✔
2
import shutil
4✔
3
import textwrap
4✔
4
import ast
4✔
5
import os
4✔
6
import getpass
4✔
7
import hashlib
4✔
8
import numpy as np
4✔
9
from typing import Annotated, get_origin, get_args, Any, Optional
4✔
10

11
from docc.sdfg import (
4✔
12
    Scalar,
13
    PrimitiveType,
14
    Pointer,
15
    Structure,
16
    Array,
17
    Type,
18
    StructuredSDFG,
19
    StructuredSDFGBuilder,
20
)
21
from docc.compiler.docc_program import DoccProgram
4✔
22
from docc.compiler.compiled_sdfg import CompiledSDFG
4✔
23
from docc.python.ast_parser import ASTParser
4✔
24

25

26
def _compile_wrapper(self, output_folder=None):
    """Compile this SDFG and wrap the produced library in a ``CompiledSDFG``.

    Installed as ``StructuredSDFG.compile`` so that compiling an SDFG directly
    yields a ``CompiledSDFG`` rather than a bare library path.
    """
    library = self._compile(output_folder)
    return CompiledSDFG(library, self)
30

31

32
# Monkey-patch: give StructuredSDFG a ``compile`` method that returns a
# CompiledSDFG (see _compile_wrapper) instead of a raw library path.
setattr(StructuredSDFG, "compile", _compile_wrapper)
34

35

36
def _map_python_type(dtype):
    """Map a Python/NumPy type annotation to an SDFG type.

    Handled forms:

    * an SDFG ``Type`` -- returned unchanged;
    * ``Annotated[np.ndarray, shape]`` / ``Annotated[np.ndarray, shape, dtype]``
      -- ``Pointer(<element>)``; the element defaults to ``Double`` and the
      shape does not influence the result;
    * ``np.ndarray[shape, dtype]`` (e.g. ``numpy.typing.NDArray[np.float64]``)
      -- ``Pointer(<element>)``;
    * Python/NumPy scalar types, NumPy ``dtype`` instances and
      ``np.dtype[...]`` aliases of them -- ``Scalar``;
    * any other class except ``np.ndarray`` -- ``Pointer(Structure(<name>))``.

    Anything else (including a bare ``np.ndarray`` or ``np.ndarray[...]``
    without a dtype, whose element type is unknown) is returned unchanged.
    Callers detect that with ``isinstance(result, Type)`` and fall back to
    inferring the type instead.
    """
    if isinstance(dtype, Type):
        return dtype

    origin = get_origin(dtype)

    if origin is Annotated:
        base_type, *metadata = get_args(dtype)
        if base_type is np.ndarray:
            # Convention: Annotated[np.ndarray, shape, dtype].  metadata[0] is
            # the shape, which does not affect the pointer type.
            if len(metadata) > 1:
                return Pointer(_map_python_type(metadata[1]))
            return Pointer(Scalar(PrimitiveType.Double))
        # Other Annotated[...] forms fall through and end up unmapped.

    if origin is np.ndarray:
        args = get_args(dtype)
        if len(args) >= 2:
            # args == (shape, dtype)
            return Pointer(_map_python_type(args[1]))
        # No element type given: leave unmapped instead of letting the class
        # fallback below treat the alias as a user structure.
        return dtype

    if dtype is np.ndarray:
        # A bare ndarray carries no element type; leave it unmapped rather than
        # mapping it to Pointer(Structure("ndarray")).
        return dtype

    # Normalise NumPy dtype spellings to their scalar type object, e.g.
    # np.dtype("float32") -> np.float32 and np.dtype[np.float32] -> np.float32.
    scalar_type = dtype
    if origin is np.dtype and get_args(dtype):
        scalar_type = get_args(dtype)[0]
    elif isinstance(dtype, np.dtype):
        scalar_type = dtype.type

    # Keys are type objects, which compare by identity; float and np.float64
    # (or int and np.int64) are separate keys sharing a primitive.
    scalar_primitives = {
        float: PrimitiveType.Double,
        np.float64: PrimitiveType.Double,
        np.float32: PrimitiveType.Float,
        bool: PrimitiveType.Bool,
        np.bool_: PrimitiveType.Bool,
        int: PrimitiveType.Int64,
        np.int64: PrimitiveType.Int64,
        np.int32: PrimitiveType.Int32,
        np.int16: PrimitiveType.Int16,
        np.int8: PrimitiveType.Int8,
        np.uint64: PrimitiveType.UInt64,
        np.uint32: PrimitiveType.UInt32,
        np.uint16: PrimitiveType.UInt16,
        np.uint8: PrimitiveType.UInt8,
    }
    try:
        primitive = scalar_primitives.get(scalar_type)
    except TypeError:
        # Unhashable annotation object: certainly not a scalar type.
        primitive = None
    if primitive is not None:
        return Scalar(primitive)

    if inspect.isclass(dtype):
        # Any other class is a user structure passed by pointer, named after
        # the class.
        return Pointer(Structure(dtype.__name__))

    return dtype
97

98

99
class PythonProgram(DoccProgram):
    """A Python function compiled to native code through a StructuredSDFG.

    The function's source is parsed by ``ASTParser`` into an SDFG that is
    specialized on the argument types and array shapes seen at call time.
    Compiled specializations are cached (see :meth:`compile`).
    """

    def __init__(
        self,
        func,
        target: str = "none",
        category: str = "server",
        instrumentation_mode: Optional[str] = None,
        capture_args: Optional[bool] = None,
        remote_tuning: bool = False,
    ):
        """Wrap ``func``; the options are forwarded to ``DoccProgram``.

        Args:
            func: The Python function to compile; its ``__name__`` becomes
                the program name.
            target: Scheduling target. ``"none"`` skips normalization and
                scheduling; ``"onnx"`` skips tensor expansion and additionally
                converts the build output to ONNX.
            category: Category passed to ``sdfg.schedule``.
            instrumentation_mode: Default instrumentation mode; may be
                overridden per :meth:`compile` call or via ``DOCC_CI``.
            capture_args: Default argument-capture flag; same overrides.
            remote_tuning: Passed to ``sdfg.schedule``.
        """
        super().__init__(
            name=func.__name__,
            target=target,
            category=category,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
            remote_tuning=remote_tuning,
        )
        self.func = func
        # Structure member layout from the most recent _build_sdfg call;
        # compile() hands it to CompiledSDFG.
        self._last_structure_member_info = {}

    def __call__(self, *args: Any) -> Any:
        """Compile (or fetch from cache) the specialization for ``args`` and run it.

        The result is wrapped with ``np.ctypeslib.as_array`` when a shape is
        known. The shape comes either from an
        ``Annotated[np.ndarray, shape, ...]`` return annotation or from
        ``compiled.get_return_shape``. If wrapping fails, the raw result is
        returned.
        """
        compiled = self.compile(*args)
        result = compiled(*args)

        # Shape declared in the return annotation takes precedence.
        ret_annotation = inspect.signature(self.func).return_annotation
        if get_origin(ret_annotation) is Annotated:
            type_args = get_args(ret_annotation)
            if type_args[0] is np.ndarray and len(type_args) >= 2:
                shape = type_args[1]
                if shape is not None:
                    array = self._try_as_array(result, shape)
                    if array is not None:
                        return array

        # Otherwise ask the compiled program for the return shape, if it can.
        if hasattr(compiled, "get_return_shape"):
            shape = compiled.get_return_shape(*args)
            if shape is not None:
                array = self._try_as_array(result, shape)
                if array is not None:
                    return array

        return result

    @staticmethod
    def _try_as_array(result, shape):
        """Return ``result`` viewed as an ndarray of ``shape``, or None on failure."""
        try:
            return np.ctypeslib.as_array(result, shape=shape)
        except Exception:
            # Deliberately best effort: the caller falls back to the raw result.
            return None

    def compile(
        self,
        *args: Any,
        output_folder: Optional[str] = None,
        instrumentation_mode: Optional[str] = None,
        capture_args: Optional[bool] = None,
    ) -> CompiledSDFG:
        """Build and compile the specialization of the function for ``args``.

        A specialization is identified by:

        * the inferred argument types;
        * which array dimensions share a size;
        * which of those sizes equal an integer argument;
        * the resolved build options.

        With ``output_folder=None`` the result is cached under that identity.
        The build then goes to a folder derived from a hash of it: under
        ``$DOCC_TMP`` if set, else ``/tmp/<user>/DOCC``. An explicit
        ``output_folder`` always rebuilds and is not cached.

        Args:
            *args: Example call arguments; their types, array shapes and
                integer values are inspected.
            output_folder: Where to write build artifacts. An existing
                folder at this path is deleted first.
            instrumentation_mode: Overrides the instance default.
            capture_args: Overrides the instance default.

        Returns:
            The ``CompiledSDFG`` for this specialization.
        """
        use_cache = output_folder is None
        instrumentation_mode, capture_args = self._resolve_build_options(
            instrumentation_mode, capture_args
        )

        # 1. Analyze arguments and shapes.
        (
            arg_types,
            shape_values,
            shape_sources,
            arg_shape_mapping,
            shape_to_scalar,
        ) = self._analyze_call_args(args)

        # 2. Specialization signature: types, shared dimension sizes, and
        # scalar-shape equivalences.
        signature = (
            f"{self._get_signature(arg_types)}"
            f"|{sorted(arg_shape_mapping.items())}"
            f"|{sorted(shape_to_scalar.items())}"
        )
        # The resolved build options change the generated library, so they
        # must be part of the cache key (and of the output-folder hash);
        # otherwise a per-call override would be served a stale build.
        cache_key = f"{signature}|{instrumentation_mode}|{capture_args}"
        if use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        if output_folder is None:
            output_folder = self._default_output_folder(
                instrumentation_mode, capture_args, signature
            )

        # 3. Build the SDFG.
        if os.path.exists(output_folder):
            # Stale artifacts, e.g. from another Python process running the
            # same code; start from a clean folder.
            # NOTE(review): this also deletes a caller-supplied output_folder.
            shutil.rmtree(output_folder)
        sdfg, out_args, out_shapes = self._build_sdfg(
            arg_types, args, arg_shape_mapping, len(shape_values), shape_to_scalar
        )
        sdfg.validate()

        # Tensor targets keep tensor nodes.
        if self.target != "onnx":
            sdfg.expand()

        sdfg.simplify()

        # Normalization for scheduling.
        if self.target != "none":
            sdfg.normalize()

        sdfg.dump(output_folder)

        if self.target != "none":
            sdfg.schedule(self.target, self.category, self.remote_tuning)

        self.last_sdfg = sdfg

        # 4. Compile.
        lib_path = sdfg._compile(
            output_folder=output_folder,
            target=self.target,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
        )

        # Build ONNX models from the JSON that _compile writes for onnx.
        if self.target == "onnx":
            from docc.python.targets.onnx_model_builder import convert_json_to_onnx

            onnx_model_path = convert_json_to_onnx(output_folder)
            if onnx_model_path:
                print(f"Generated ONNX models: {onnx_model_path}")

        # 5. Wrap as CompiledSDFG.
        compiled = CompiledSDFG(
            lib_path,
            sdfg,
            shape_sources,
            self._last_structure_member_info,
            out_args,
            out_shapes,
        )

        if use_cache:
            self.cache[cache_key] = compiled

        return compiled

    def _resolve_build_options(self, instrumentation_mode, capture_args):
        """Resolve the effective ``(instrumentation_mode, capture_args)``.

        Precedence: explicit argument, then instance default, then the
        ``DOCC_CI`` environment variable, then ``""`` / ``False``.
        ``DOCC_CI`` values:

        * ``"regions"`` enables ``"ols"`` instrumentation only;
        * ``"arg-capture"`` enables argument capture only;
        * any other non-empty value enables both.
        """
        if instrumentation_mode is None:
            instrumentation_mode = self.instrumentation_mode
        if capture_args is None:
            capture_args = self.capture_args

        docc_ci = os.environ.get("DOCC_CI", "")
        if docc_ci:
            if docc_ci != "arg-capture" and instrumentation_mode is None:
                instrumentation_mode = "ols"
            if docc_ci != "regions" and capture_args is None:
                capture_args = True

        if instrumentation_mode is None:
            instrumentation_mode = ""
        if capture_args is None:
            capture_args = False
        return instrumentation_mode, capture_args

    def _default_output_folder(self, instrumentation_mode, capture_args, signature):
        """Return the per-specialization build folder used when none is given.

        The folder name embeds a SHA-256 prefix computed over:

        * the source file;
        * the program name, target and category;
        * the resolved build options;
        * ``signature``.
        """
        filename = inspect.getsourcefile(self.func)
        hash_input = (
            f"{filename}|{self.name}|{self.target}|{self.category}"
            f"|{capture_args}|{instrumentation_mode}|{signature}"
        )
        stable_id = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:16]
        folder_name = f"{self.name}-{stable_id}"

        docc_tmp = os.environ.get("DOCC_TMP")
        if docc_tmp:
            return f"{docc_tmp}/{folder_name}"
        return f"/tmp/{getpass.getuser()}/DOCC/{folder_name}"

    def to_sdfg(self, *args: Any) -> StructuredSDFG:
        """Build (but do not compile) the SDFG specialized for ``args``."""
        (
            arg_types,
            shape_values,
            _shape_sources,
            arg_shape_mapping,
            shape_to_scalar,
        ) = self._analyze_call_args(args)
        sdfg, _, _ = self._build_sdfg(
            arg_types, args, arg_shape_mapping, len(shape_values), shape_to_scalar
        )
        return sdfg

    def _analyze_call_args(self, args):
        """Infer argument types and unify array dimensions by size.

        Returns:
            A tuple ``(arg_types, shape_values, shape_sources,
            arg_shape_mapping, shape_to_scalar)``:

            - ``arg_types``: SDFG type of each argument (``_infer_type``).
            - ``shape_values``: distinct dimension sizes, in first-seen order.
            - ``shape_sources``: ``(arg_idx, dim_idx)`` where each distinct
              size first appears.
            - ``arg_shape_mapping``: ``(arg_idx, dim_idx)`` -> index into
              ``shape_values`` for every array dimension.
            - ``shape_to_scalar``: index into ``shape_values`` -> name of the
              first integer (non-bool) parameter whose value equals that
              size.
        """
        # Integer scalar arguments by value; the first parameter wins.
        params = inspect.signature(self.func).parameters
        scalar_int_params = {}
        for name, arg in zip(params, args):
            if isinstance(arg, (int, np.integer)) and not isinstance(
                arg, (bool, np.bool_)
            ):
                scalar_int_params.setdefault(int(arg), name)

        arg_types = []
        shape_values = []
        shape_sources = []
        arg_shape_mapping = {}
        shape_index = {}  # dimension size -> index into shape_values
        for arg_idx, arg in enumerate(args):
            arg_types.append(self._infer_type(arg))
            if not isinstance(arg, np.ndarray):
                continue
            for dim_idx, dim_val in enumerate(arg.shape):
                if dim_val not in shape_index:
                    shape_index[dim_val] = len(shape_values)
                    shape_values.append(dim_val)
                    shape_sources.append((arg_idx, dim_idx))
                arg_shape_mapping[(arg_idx, dim_idx)] = shape_index[dim_val]

        # NOTE(review): matching is by value only, so an unrelated integer
        # argument that happens to equal a dimension size is also treated as
        # that dimension.
        shape_to_scalar = {
            s_idx: scalar_int_params[s_val]
            for s_idx, s_val in enumerate(shape_values)
            if s_val in scalar_int_params
        }
        return arg_types, shape_values, shape_sources, arg_shape_mapping, shape_to_scalar

    def _convert_inputs(self, args: tuple) -> tuple:
        """Hook: arguments are passed through unchanged."""
        return args

    def _convert_outputs(self, result: Any, original_args: tuple) -> Any:
        """Hook: results are passed through unchanged."""
        return result

    def _get_signature(self, arg_types):
        """Comma-separated string form of ``arg_types``."""
        return ", ".join(self._type_to_str(t) for t in arg_types)

    def _type_to_str(self, t):
        """Stable string form of an SDFG type, used in cache signatures."""
        if isinstance(t, Scalar):
            return f"Scalar({t.primitive_type})"
        elif isinstance(t, Array):
            return f"Array({self._type_to_str(t.element_type)}, {t.num_elements})"
        elif isinstance(t, Pointer):
            return f"Pointer({self._type_to_str(t.pointee_type)})"
        elif isinstance(t, Structure):
            return f"Structure({t.name})"
        return str(t)

    @staticmethod
    def _scalar_from_dtype(dtype):
        """Return the SDFG ``Scalar`` for a NumPy dtype, or None if unsupported."""
        for np_type, primitive in (
            (np.float64, PrimitiveType.Double),
            (np.float32, PrimitiveType.Float),
            (np.bool_, PrimitiveType.Bool),
            (np.int64, PrimitiveType.Int64),
            (np.int32, PrimitiveType.Int32),
            (np.int16, PrimitiveType.Int16),
            (np.int8, PrimitiveType.Int8),
            (np.uint64, PrimitiveType.UInt64),
            (np.uint32, PrimitiveType.UInt32),
            (np.uint16, PrimitiveType.UInt16),
            (np.uint8, PrimitiveType.UInt8),
        ):
            if dtype == np_type:
                return Scalar(primitive)
        return None

    def _infer_type(self, arg):
        """Infer the SDFG type of a call argument.

        Type mapping:

        * Python ``bool``/``int``/``float`` -> Bool/Int64/Double;
        * NumPy scalars -> the Scalar for their dtype;
        * NumPy arrays -> ``Pointer(<element>)`` for their dtype;
        * any other non-class object -> ``Pointer(Structure(<class name>))``.

        Raises:
            ValueError: For strings, classes, and NumPy scalars/arrays whose
                dtype has no SDFG equivalent.
        """
        # bool subclasses int, so it must be tested before int.
        if isinstance(arg, bool):
            return Scalar(PrimitiveType.Bool)
        if isinstance(arg, int):
            return Scalar(PrimitiveType.Int64)
        # np.float64 subclasses float and is covered here as well.
        if isinstance(arg, float):
            return Scalar(PrimitiveType.Double)
        if isinstance(arg, (str, type)):
            # Strings and classes are not supported as arguments.
            raise ValueError(f"Unsupported argument type: {type(arg)}")
        if isinstance(arg, np.generic):
            # Match by dtype so platform aliases that are distinct Python
            # types (e.g. possibly np.longlong vs np.int64) still map, and
            # unsupported scalars fail here instead of becoming a Structure.
            scalar = self._scalar_from_dtype(arg.dtype)
            if scalar is None:
                raise ValueError(f"Unsupported numpy scalar type: {type(arg)}")
            return scalar
        if isinstance(arg, np.ndarray):
            elem_type = self._scalar_from_dtype(arg.dtype)
            if elem_type is None:
                raise ValueError(f"Unsupported numpy dtype: {arg.dtype}")
            return Pointer(elem_type)
        # Any other object is passed as a pointer to a structure named after
        # its class.
        return Pointer(Structure(arg.__class__.__name__))

    @staticmethod
    def _explicit_return_types(annotation):
        """Map a return annotation to a list of SDFG types, or None if unusable.

        ``Tuple[A, B]`` yields one type per element, and an empty tuple type
        yields ``[]``. Any other annotation yields a one-element list. None
        is returned when some element does not map to an SDFG ``Type``; this
        includes the ``...`` of a variadic ``Tuple[A, ...]``. The parser then
        infers the return types itself.
        """
        if get_origin(annotation) is tuple:
            elements = get_args(annotation)
        else:
            elements = (annotation,)
        mapped = [_map_python_type(t) for t in elements]
        if all(isinstance(t, Type) for t in mapped):
            return mapped
        return None

    @staticmethod
    def _struct_member_type(value):
        """SDFG scalar type for a structure member value, or None to skip it."""
        # bool subclasses int, so it must be tested before int.
        if isinstance(value, bool):
            return Scalar(PrimitiveType.Bool)
        if isinstance(value, (int, np.int64)):
            return Scalar(PrimitiveType.Int64)
        if isinstance(value, (float, np.float64)):
            return Scalar(PrimitiveType.Double)
        if isinstance(value, np.int32):
            return Scalar(PrimitiveType.Int32)
        if isinstance(value, np.float32):
            return Scalar(PrimitiveType.Float)
        # TODO: Consider np.integer / np.floating for broader numpy coverage,
        # and add support for nested structures and arrays.
        return None

    def _build_sdfg(
        self,
        arg_types,
        args,
        arg_shape_mapping,
        num_unique_shapes,
        shape_to_scalar=None,
    ):
        """Parse ``self.func`` into a StructuredSDFG specialized on ``args``.

        Args:
            arg_types: SDFG type of each argument (from ``_infer_type``).
            args: The call arguments; structure members and array ranks are
                read from them.
            arg_shape_mapping: ``(arg_idx, dim_idx)`` -> unique shape index.
            num_unique_shapes: Number of distinct dimension sizes.
            shape_to_scalar: Unique shape index -> name of the integer
                parameter standing in for that size. No ``_s<i>`` argument
                is created for such sizes.

        Returns:
            ``(sdfg, out_args, return_shapes)``:

            - ``out_args``: the ``_docc_ret_<i>`` argument names.
            - ``return_shapes``: the parser's ``captured_return_shapes``.

        Raises:
            ValueError: If the number of arguments differs from the number of
                function parameters.
        """
        if shape_to_scalar is None:
            shape_to_scalar = {}
        sig = inspect.signature(self.func)

        # The SDFG always returns void; results travel through output
        # arguments. With a usable return annotation those are declared up
        # front, otherwise the parser infers them.
        explicit_returns = []
        infer_return_type = True
        if sig.return_annotation is not inspect.Signature.empty:
            mapped = self._explicit_return_types(sig.return_annotation)
            if mapped is not None:
                explicit_returns = mapped
                infer_return_type = False

        builder = StructuredSDFGBuilder(
            f"{self.name}_sdfg", Scalar(PrimitiveType.Void)
        )

        for i, ret_type in enumerate(explicit_returns):
            # Scalars are returned through a pointer; array results are
            # already Pointer(element) and are declared as-is.
            if isinstance(ret_type, Scalar):
                ret_type = Pointer(ret_type)
            builder.add_container(f"_docc_ret_{i}", ret_type, is_argument=True)

        # Register a structure type for each class-instance argument and record
        # member name -> (index, type) for it.
        structures_to_register = {}
        structure_member_info = {}
        for arg, dtype in zip(args, arg_types):
            if not (isinstance(dtype, Pointer) and dtype.has_pointee_type()):
                continue
            pointee = dtype.pointee_type
            if not isinstance(pointee, Structure):
                continue
            struct_name = pointee.name
            if struct_name in structures_to_register or not hasattr(arg, "__dict__"):
                continue
            # Only public instance attributes (from __dict__) become members.
            # They are sorted by name; this alphabetical order defines the
            # structure layout and must match the order expected by the
            # backend code generation.
            members = []
            for attr_name, attr_value in sorted(arg.__dict__.items()):
                if attr_name.startswith("_"):
                    continue
                member_type = self._struct_member_type(attr_value)
                if member_type is not None:
                    members.append((attr_name, member_type))
            if members:
                structures_to_register[struct_name] = [mtype for _, mtype in members]
                structure_member_info[struct_name] = {
                    member_name: (idx, mtype)
                    for idx, (member_name, mtype) in enumerate(members)
                }

        # Kept for compile(), which passes it to CompiledSDFG.
        self._last_structure_member_info = structure_member_info

        for struct_name, member_types in structures_to_register.items():
            builder.add_structure(struct_name, member_types)

        params = list(sig.parameters.items())
        if len(params) != len(arg_types):
            raise ValueError(
                f"Argument count mismatch: expected {len(params)}, got {len(arg_types)}"
            )

        # Register the regular arguments, their symbols, and array shape info.
        array_info = {}
        symbol_table = {}
        for i, ((name, _param), dtype, arg) in enumerate(
            zip(params, arg_types, args)
        ):
            builder.add_container(name, dtype, is_argument=True)
            symbol_table[name] = dtype
            if isinstance(arg, np.ndarray):
                # Name each dimension by its equivalent integer parameter if
                # there is one, otherwise by its unified _s<idx> argument.
                shapes = []
                for dim_idx in range(arg.ndim):
                    u_idx = arg_shape_mapping[(i, dim_idx)]
                    shapes.append(shape_to_scalar.get(u_idx, f"_s{u_idx}"))
                array_info[name] = {"ndim": arg.ndim, "shapes": shapes}

        # One Int64 argument per distinct dimension size without a scalar
        # equivalent.
        for i in range(num_unique_shapes):
            if i not in shape_to_scalar:
                builder.add_container(
                    f"_s{i}", Scalar(PrimitiveType.Int64), is_argument=True
                )
                symbol_table[f"_s{i}"] = Scalar(PrimitiveType.Int64)

        # Parse the function; shift line numbers so AST nodes report their
        # line in the original file.
        source_lines, start_line = inspect.getsourcelines(self.func)
        tree = ast.parse(textwrap.dedent("".join(source_lines)))
        ast.increment_lineno(tree, start_line - 1)
        func_def = tree.body[0]

        # Globals overlaid with closure variables (closure takes precedence).
        combined_globals = dict(self.func.__globals__)
        if self.func.__closure__:
            for var_name, cell in zip(
                self.func.__code__.co_freevars, self.func.__closure__
            ):
                try:
                    combined_globals[var_name] = cell.cell_contents
                except ValueError:
                    # Empty cell: the enclosing variable is not bound (yet).
                    continue

        parser = ASTParser(
            builder,
            array_info,
            symbol_table,
            inspect.getsourcefile(self.func),
            self.func.__name__,
            infer_return_type=infer_return_type,
            globals_dict=combined_globals,
            structure_member_info=structure_member_info,
        )
        for node in func_def.body:
            parser.visit(node)

        sdfg = builder.move()
        out_args = [
            arg_name
            for arg_name in sdfg.arguments
            if arg_name.startswith("_docc_ret_")
        ]
        return sdfg, out_args, parser.captured_return_shapes
640

641

642
def native(
    func=None,
    *,
    target="none",
    category="desktop",
    instrumentation_mode=None,
    capture_args=None,
    remote_tuning=False,
):
    """Decorator to create a PythonProgram from a Python function.

    Usable bare (``@native``) or with keyword options
    (``@native(target=..., ...)``). All options are forwarded to
    ``PythonProgram``. Note that ``category`` defaults to ``"desktop"``
    here, whereas ``PythonProgram`` itself defaults to ``"server"``.

    Args:
        func: The function to wrap. It is None when the decorator is used
            with keyword options.
        target: Scheduling target.
        category: Scheduling category.
        instrumentation_mode: Default instrumentation mode.
        capture_args: Default argument-capture flag.
        remote_tuning: Forwarded to ``PythonProgram`` (default False, as
            before this option was exposed here).

    Example:
        @native
        def my_function(x: np.ndarray) -> np.ndarray:
            return x * 2

        result = my_function(np.array([1.0, 2.0, 3.0]))
    """

    def wrap(f):
        return PythonProgram(
            f,
            target=target,
            category=category,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
            remote_tuning=remote_tuning,
        )

    if func is None:
        return wrap
    return wrap(func)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc