• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21113623600

18 Jan 2026 02:50PM UTC coverage: 64.425% (+0.3%) from 64.154%
21113623600

Pull #462

github

web-flow
Merge d503e5691 into 92e9cbdc3
Pull Request #462: adds syntax support for multi-assignments and np.empty_like

221 of 258 new or added lines in 5 files covered. (85.66%)

21 existing lines in 4 files now uncovered.

19678 of 30544 relevant lines covered (64.43%)

385.56 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.59
/python/docc/__init__.py
1
import inspect
3✔
2
import shutil
3✔
3
import textwrap
3✔
4
import ast
3✔
5
import os
3✔
6
import getpass
3✔
7
import hashlib
3✔
8
import numpy as np
3✔
9
from typing import Annotated, get_origin, get_args
3✔
10
from ._sdfg import *
3✔
11
from .compiled_sdfg import CompiledSDFG
3✔
12
from .ast_parser import ASTParser
3✔
13

14

15
def _compile_wrapper(self, output_folder=None, **compile_options):
    """Compile this SDFG and wrap the resulting library in a ``CompiledSDFG``.

    Installed as ``StructuredSDFG.compile`` (assignment below) so that callers
    get a callable ``CompiledSDFG`` rather than the raw library path returned
    by ``StructuredSDFG._compile``.

    Args:
        output_folder: Build directory, forwarded to ``_compile``.
        **compile_options: Additional keyword options forwarded unchanged to
            ``_compile`` (for example ``instrumentation_mode`` or
            ``capture_args``, which ``_compile`` also accepts).

    Returns:
        CompiledSDFG: wrapper around the compiled library and this SDFG.
    """
    lib_path = self._compile(output_folder, **compile_options)
    return CompiledSDFG(lib_path, self)


StructuredSDFG.compile = _compile_wrapper
21

22

23
def _map_python_type(dtype):
    """Map a Python/NumPy type annotation to an SDFG ``Type``.

    Supported forms:
      * an SDFG ``Type`` -- returned unchanged;
      * ``Annotated[np.ndarray, shape, dtype]`` -- ``Pointer(<element type>)``,
        with the element type defaulting to double when ``dtype`` is omitted;
      * ``np.ndarray[shape, dtype]`` (including the NumPy typing form
        ``np.ndarray[Any, np.dtype[np.float64]]``) -- ``Pointer(<element type>)``;
      * ``np.dtype[...]`` aliases and ``np.dtype`` instances -- their scalar type;
      * ``float``/``int``/``bool`` and NumPy ``float64``/``float32``/``int64``/
        ``int32``/``bool_`` -- the matching ``Scalar``;
      * any other class -- ``Pointer(Structure(<class name>))``.

    Anything else is returned unchanged (i.e. not an SDFG ``Type``), which
    callers can detect with ``isinstance(result, Type)``.
    """
    # Already an SDFG type: nothing to map.
    if isinstance(dtype, Type):
        return dtype

    # np.dtype instances, e.g. np.dtype("float32") -> np.float32 -> Float.
    if isinstance(dtype, np.dtype):
        scalar = _map_python_type(dtype.type)
        return scalar if isinstance(scalar, Scalar) else dtype

    origin = get_origin(dtype)

    # np.dtype[np.float64] aliases, as they appear inside
    # np.ndarray[Any, np.dtype[np.float64]] annotations.
    if origin is np.dtype:
        dtype_args = get_args(dtype)
        if dtype_args:
            scalar = _map_python_type(dtype_args[0])
            if isinstance(scalar, Scalar):
                return scalar
        return dtype

    # Convention: Annotated[np.ndarray, shape, dtype]; arrays are element pointers.
    if origin is Annotated:
        base_type, *metadata = get_args(dtype)
        if base_type is np.ndarray:
            # metadata[0] is the shape, which the pointer type does not encode;
            # the optional metadata[1] is the element dtype.
            elem_type = Scalar(PrimitiveType.Double)  # Default
            if len(metadata) > 1:
                elem_type = _map_python_type(metadata[1])
            return Pointer(elem_type)

    # np.ndarray[shape, dtype]
    if origin is np.ndarray:
        ndarray_args = get_args(dtype)
        if len(ndarray_args) >= 2:
            elem_type = _map_python_type(ndarray_args[1])
            if isinstance(elem_type, Type):
                return Pointer(elem_type)
            # Unmappable element type: hand back the annotation itself so the
            # caller sees a non-Type and can fall back, instead of building
            # a Pointer around an arbitrary object.
            return dtype

    # Simple mapping for Python / NumPy scalar types.
    if dtype is float or dtype is np.float64:
        return Scalar(PrimitiveType.Double)
    elif dtype is int or dtype is np.int64:
        return Scalar(PrimitiveType.Int64)
    elif dtype is bool or dtype is np.bool_:
        return Scalar(PrimitiveType.Bool)
    elif dtype is np.float32:
        return Scalar(PrimitiveType.Float)
    elif dtype is np.int32:
        return Scalar(PrimitiveType.Int32)

    # Any other Python class is passed as a pointer to a structure named after it.
    if inspect.isclass(dtype):
        return Pointer(Structure(dtype.__name__))

    return dtype
71

72

73
class DoccProgram:
    """JIT-compiling wrapper around a Python function.

    On call, the wrapped function's source is parsed into a StructuredSDFG for
    the types and array shapes of the actual arguments, compiled to a library,
    and run. Builds in the default output folder are cached per instance, keyed
    by argument types, shape aliasing, and the resolved build options.

    Args:
        func: Function to compile; its source must be available to
            ``inspect.getsourcelines``.
        target: Scheduling target; ``"none"`` skips ``normalize``/``schedule``.
        category: Category passed to ``sdfg.schedule`` together with ``target``.
        instrumentation_mode: Default instrumentation mode for ``compile``;
            ``None`` defers to ``DOCC_CI`` and finally ``""``.
        capture_args: Default argument-capture flag for ``compile``; ``None``
            defers to ``DOCC_CI`` and finally ``False``.
    """

    def __init__(
        self,
        func,
        target="none",
        category="desktop",
        instrumentation_mode=None,
        capture_args=None,
    ):
        self.func = func
        self.name = func.__name__
        self.target = target
        self.category = category
        self.instrumentation_mode = instrumentation_mode
        self.capture_args = capture_args
        # SDFG from the most recent build.
        self.last_sdfg = None
        # cache key -> CompiledSDFG, only for builds in the default output folder.
        self.cache = {}
        # Structure layout from the most recent _build_sdfg call; initialized
        # here so the attribute exists before the first build.
        self._last_structure_member_info = {}

    def __call__(self, *args):
        """Compile (or reuse a cached build) and run the program on ``args``.

        The raw result is wrapped with ``np.ctypeslib.as_array`` when a shape
        is known, either from an ``Annotated[np.ndarray, shape, ...]`` return
        annotation or from ``compiled.get_return_shape``. If wrapping fails,
        the raw result is returned instead.
        """
        # JIT compile and run
        compiled = self.compile(*args)
        res = compiled(*args)

        # 1) Shape from an explicit return annotation.
        shape = self._annotated_return_shape()
        if shape is not None:
            arr = self._try_as_array(res, shape)
            if arr is not None:
                return arr

        # 2) Shape inferred from compiled metadata.
        if hasattr(compiled, "get_return_shape"):
            shape = compiled.get_return_shape(*args)
            if shape is not None:
                arr = self._try_as_array(res, shape)
                if arr is not None:
                    return arr

        return res

    def _annotated_return_shape(self):
        """Return ``shape`` from an ``Annotated[np.ndarray, shape, ...]`` return
        annotation of the wrapped function, or ``None`` if there is none."""
        ret_annotation = inspect.signature(self.func).return_annotation
        if ret_annotation is inspect.Signature.empty:
            return None
        if get_origin(ret_annotation) is not Annotated:
            return None
        type_args = get_args(ret_annotation)
        if len(type_args) >= 2 and type_args[0] is np.ndarray:
            return type_args[1]
        return None

    @staticmethod
    def _try_as_array(res, shape):
        """Best-effort ``np.ctypeslib.as_array(res, shape=shape)``; ``None`` on failure."""
        try:
            return np.ctypeslib.as_array(res, shape=shape)
        except Exception:
            # Deliberately best-effort: when ``res`` cannot be viewed as an
            # array of this shape, the caller falls back to the raw result.
            return None

    def compile(
        self, *args, output_folder=None, instrumentation_mode=None, capture_args=None
    ):
        """Build and compile the program for the given example arguments.

        Args:
            *args: Example arguments. Their types, array shapes and (for class
                instances) scalar attributes determine the generated code.
            output_folder: Build directory. When ``None``, a deterministic
                per-user directory is used and the result is cached on this
                instance; an explicit folder bypasses the cache.
            instrumentation_mode: Overrides the instance default for this build.
            capture_args: Overrides the instance default for this build.

        Returns:
            CompiledSDFG: the compiled program.

        Raises:
            ValueError: for unsupported argument types or when the argument
                count does not match the function signature.
        """
        original_output_folder = output_folder
        instrumentation_mode, capture_args = self._resolve_options(
            instrumentation_mode, capture_args
        )

        # 1. Analyze arguments and shapes
        arg_types = [self._infer_type(arg) for arg in args]
        shape_values, shape_sources, arg_shape_mapping = self._collect_shapes(args)

        # 2. Signature: argument types plus which dimensions share a shape symbol.
        mapping_sig = sorted(arg_shape_mapping.items())
        signature = f"{self._get_signature(arg_types)}|{mapping_sig}"
        # The build options change the generated library, so they belong in
        # the cache key (and therefore in the default folder hash) too.
        # NOTE(review): class-instance arguments contribute only their class
        # name, so instances whose attributes have different types share a key.
        cache_key = self._make_cache_key(signature, instrumentation_mode, capture_args)

        if original_output_folder is None:
            cached = self.cache.get(cache_key)
            if cached is not None:
                return cached
            output_folder = self._default_output_folder(cache_key)

        # 3. Build SDFG
        if os.path.exists(output_folder):
            # Stale build, e.g. multiple python processes running the same code.
            # NOTE(review): this also wipes a caller-supplied output_folder.
            shutil.rmtree(output_folder)
        sdfg, out_args, out_shapes = self._build_sdfg(
            arg_types, args, arg_shape_mapping, len(shape_values)
        )
        sdfg.expand()
        sdfg.simplify()

        if self.target != "none":
            sdfg.normalize()

        sdfg.dump(output_folder)

        # 4. Schedule if target is specified
        if self.target != "none":
            sdfg.schedule(self.target, self.category)

        self.last_sdfg = sdfg

        lib_path = sdfg._compile(
            output_folder=output_folder,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
        )

        # 5. Create CompiledSDFG
        compiled = CompiledSDFG(
            lib_path,
            sdfg,
            shape_sources,
            self._last_structure_member_info,
            out_args,
            out_shapes,
        )

        # Cache only builds that live in the default output folder.
        if original_output_folder is None:
            self.cache[cache_key] = compiled

        return compiled

    def _resolve_options(self, instrumentation_mode, capture_args):
        """Resolve build options.

        Precedence: explicit argument, then instance default, then the
        ``DOCC_CI`` environment variable, then ``""`` / ``False``.
        ``DOCC_CI="regions"`` enables instrumentation ("ols"),
        ``"arg-capture"`` enables argument capture, and any other non-empty
        value enables both.
        """
        if instrumentation_mode is None:
            instrumentation_mode = self.instrumentation_mode
        if capture_args is None:
            capture_args = self.capture_args

        docc_ci = os.environ.get("DOCC_CI", "")
        if docc_ci:
            # Unknown values are treated as full mode.
            if docc_ci != "arg-capture" and instrumentation_mode is None:
                instrumentation_mode = "ols"
            if docc_ci != "regions" and capture_args is None:
                capture_args = True

        if instrumentation_mode is None:
            instrumentation_mode = ""
        if capture_args is None:
            capture_args = False
        return instrumentation_mode, capture_args

    @staticmethod
    def _make_cache_key(signature, instrumentation_mode, capture_args):
        """Combine the type/shape signature with the resolved build options."""
        return f"{signature}|{instrumentation_mode}|{capture_args}"

    def _default_output_folder(self, cache_key):
        """Deterministic build directory for this function and cache key.

        ``$DOCC_TMP/<name>-<hash>`` when ``DOCC_TMP`` is set, otherwise
        ``/tmp/<user>/DOCC/<name>-<hash>``.
        """
        filename = inspect.getsourcefile(self.func)
        hash_input = (
            f"{filename}|{self.name}|{self.target}|{self.category}|{cache_key}"
        )
        stable_id = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:16]

        docc_tmp = os.environ.get("DOCC_TMP")
        if docc_tmp:
            return f"{docc_tmp}/{self.name}-{stable_id}"
        user = getpass.getuser()
        return f"/tmp/{user}/DOCC/{self.name}-{stable_id}"

    @staticmethod
    def _collect_shapes(args):
        """Assign one shape symbol per distinct array-dimension value.

        Dimensions with equal values (across all array arguments) share a
        symbol ``_s{i}``.

        Returns:
            ``(shape_values, shape_sources, arg_shape_mapping)``:
            the distinct values in first-seen order, the ``(arg_idx, dim_idx)``
            where each was first seen, and ``(arg_idx, dim_idx) -> symbol index``
            for every array dimension.
        """
        shape_values = []
        shape_sources = []
        arg_shape_mapping = {}
        value_to_index = {}  # dim value -> symbol index (O(1) lookup)

        for arg_idx, arg in enumerate(args):
            if not isinstance(arg, np.ndarray):
                continue
            for dim_idx, dim_val in enumerate(arg.shape):
                u_idx = value_to_index.get(dim_val)
                if u_idx is None:
                    u_idx = len(shape_values)
                    value_to_index[dim_val] = u_idx
                    shape_values.append(dim_val)
                    shape_sources.append((arg_idx, dim_idx))
                arg_shape_mapping[(arg_idx, dim_idx)] = u_idx

        return shape_values, shape_sources, arg_shape_mapping

    def _get_signature(self, arg_types):
        """Comma-separated textual form of a list of SDFG types."""
        return ", ".join(self._type_to_str(t) for t in arg_types)

    def _type_to_str(self, t):
        """Stable textual form of an SDFG type, used in cache signatures."""
        if isinstance(t, Scalar):
            return f"Scalar({t.primitive_type})"
        elif isinstance(t, Array):
            return f"Array({self._type_to_str(t.element_type)}, {t.num_elements})"
        elif isinstance(t, Pointer):
            return f"Pointer({self._type_to_str(t.pointee_type)})"
        elif isinstance(t, Structure):
            return f"Structure({t.name})"
        return str(t)

    def _infer_type(self, arg):
        """Infer the SDFG type of a concrete call argument.

        Scalars map to ``Scalar``; NumPy arrays map to ``Pointer(<element>)``;
        other non-string objects map to ``Pointer(Structure(<class name>))``.

        Raises:
            ValueError: for strings and unsupported array dtypes.
        """
        # bool first: bool is a subclass of int.
        if isinstance(arg, (bool, np.bool_)):
            return Scalar(PrimitiveType.Bool)
        elif isinstance(arg, (int, np.int64)):
            return Scalar(PrimitiveType.Int64)
        elif isinstance(arg, (float, np.float64)):
            return Scalar(PrimitiveType.Double)
        elif isinstance(arg, np.int32):
            return Scalar(PrimitiveType.Int32)
        elif isinstance(arg, np.float32):
            return Scalar(PrimitiveType.Float)
        elif isinstance(arg, np.ndarray):
            # Map dtype
            if arg.dtype == np.float64:
                elem_type = Scalar(PrimitiveType.Double)
            elif arg.dtype == np.float32:
                elem_type = Scalar(PrimitiveType.Float)
            elif arg.dtype == np.int64:
                elem_type = Scalar(PrimitiveType.Int64)
            elif arg.dtype == np.int32:
                elem_type = Scalar(PrimitiveType.Int32)
            elif arg.dtype == np.bool_:
                elem_type = Scalar(PrimitiveType.Bool)
            else:
                raise ValueError(f"Unsupported numpy dtype: {arg.dtype}")

            return Pointer(elem_type)
        elif isinstance(arg, str):
            # Explicitly reject strings - they are not supported
            raise ValueError(f"Unsupported argument type: {type(arg)}")
        else:
            # Any other instance (but not a class object) is passed as a
            # pointer to a structure named after its class.
            if hasattr(arg, "__class__") and not isinstance(arg, type):
                return Pointer(Structure(arg.__class__.__name__))
            raise ValueError(f"Unsupported argument type: {type(arg)}")

    @staticmethod
    def _parse_return_annotation(annotation):
        """Derive output arguments from the function's return annotation.

        Returns:
            ``(infer_return_type, explicit_returns)``. ``infer_return_type`` is
            True (and the list empty) when there is no annotation or when some
            element does not map to an SDFG ``Type``; otherwise the list holds
            one type per returned value (``Tuple[...]`` yields one per element,
            ``Tuple[()]`` yields none).
        """
        if annotation is inspect.Signature.empty:
            return True, []

        if get_origin(annotation) is tuple:
            # Tuple[int, ...] maps Ellipsis to a non-Type and so falls back
            # to inference below.
            explicit_returns = [_map_python_type(t) for t in get_args(annotation)]
        else:
            explicit_returns = [_map_python_type(annotation)]

        if not all(isinstance(rt, Type) for rt in explicit_returns):
            # Fallback if map failed (e.g. invalid annotation)
            return True, []
        return False, explicit_returns

    @staticmethod
    def _infer_member_type(value):
        """SDFG scalar type for an instance attribute, or ``None`` if unsupported."""
        # Check bool before int since bool is a subclass of int.
        if isinstance(value, bool):
            return Scalar(PrimitiveType.Bool)
        if isinstance(value, (int, np.int64)):
            return Scalar(PrimitiveType.Int64)
        if isinstance(value, (float, np.float64)):
            return Scalar(PrimitiveType.Double)
        if isinstance(value, np.int32):
            return Scalar(PrimitiveType.Int32)
        if isinstance(value, np.float32):
            return Scalar(PrimitiveType.Float)
        # TODO: Consider using np.integer and np.floating abstract types
        # for more comprehensive numpy type coverage
        # TODO: Add support for nested structures and arrays
        return None

    def _collect_structures(self, args, arg_types):
        """Introspect class-instance arguments to derive structure layouts.

        Members are the public instance attributes (``__dict__`` entries not
        starting with ``_``) with a supported scalar type, in alphabetical
        order. This ordering defines the structure layout and must match the
        order expected by the backend code generation.

        Returns:
            ``(structures_to_register, structure_member_info)``:
            ``struct_name -> [member types]`` and
            ``struct_name -> {member_name: (index, type)}``.
        """
        structures_to_register = {}
        structure_member_info = {}

        for arg, dtype in zip(args, arg_types):
            if not (isinstance(dtype, Pointer) and dtype.has_pointee_type()):
                continue
            pointee = dtype.pointee_type
            if not isinstance(pointee, Structure):
                continue
            struct_name = pointee.name
            if struct_name in structures_to_register or not hasattr(arg, "__dict__"):
                continue

            members = []  # (name, type) pairs in layout order
            for attr_name, attr_value in sorted(arg.__dict__.items()):
                if attr_name.startswith("_"):
                    continue
                member_type = self._infer_member_type(attr_value)
                if member_type is not None:
                    members.append((attr_name, member_type))

            if members:
                structures_to_register[struct_name] = [mtype for _, mtype in members]
                structure_member_info[struct_name] = {
                    name: (idx, mtype) for idx, (name, mtype) in enumerate(members)
                }

        return structures_to_register, structure_member_info

    def _build_sdfg(self, arg_types, args, arg_shape_mapping, num_unique_shapes):
        """Translate the wrapped function's AST into a StructuredSDFG.

        Args:
            arg_types: SDFG types inferred for ``args``.
            args: The concrete call arguments.
            arg_shape_mapping: ``(arg_idx, dim_idx) -> shape symbol index``.
            num_unique_shapes: Number of ``_s{i}`` shape symbols to declare.

        Returns:
            ``(sdfg, out_args, out_shapes)``: the built SDFG, the names of its
            ``_docc_ret_*`` arguments, and ``parser.captured_return_shapes``.

        Raises:
            ValueError: if the argument count does not match the signature.

        Also stores the structure layout in ``self._last_structure_member_info``.
        """
        sig = inspect.signature(self.func)

        # Handle return type - always void for SDFG, output args used for returns
        return_type = Scalar(PrimitiveType.Void)
        infer_return_type, explicit_returns = self._parse_return_annotation(
            sig.return_annotation
        )

        builder = StructuredSDFGBuilder(f"{self.name}_sdfg", return_type)

        # Add pre-defined return arguments if we know them
        if not infer_return_type:
            for i, ret_type in enumerate(explicit_returns):
                # Scalar -> Pointer(Scalar); array types are already pointers.
                arg_type = Pointer(ret_type) if isinstance(ret_type, Scalar) else ret_type
                builder.add_container(f"_docc_ret_{i}", arg_type, is_argument=True)

        # Register structure types for any class arguments
        structures_to_register, structure_member_info = self._collect_structures(
            args, arg_types
        )
        # Store structure_member_info for later use in CompiledSDFG
        self._last_structure_member_info = structure_member_info
        for struct_name, member_types in structures_to_register.items():
            builder.add_structure(struct_name, member_types)

        # Register arguments
        params = list(sig.parameters.items())
        if len(params) != len(arg_types):
            raise ValueError(
                f"Argument count mismatch: expected {len(params)}, got {len(arg_types)}"
            )

        array_info = {}
        for i, ((name, _param), dtype, arg) in enumerate(zip(params, arg_types, args)):
            builder.add_container(name, dtype, is_argument=True)
            # Arrays: remember which shape symbol describes each dimension.
            if isinstance(arg, np.ndarray):
                shapes = [
                    f"_s{arg_shape_mapping[(i, dim_idx)]}" for dim_idx in range(arg.ndim)
                ]
                array_info[name] = {"ndim": arg.ndim, "shapes": shapes}

        # Add unified shape arguments
        for i in range(num_unique_shapes):
            builder.add_container(
                f"_s{i}", Scalar(PrimitiveType.Int64), is_argument=True
            )

        # Create symbol table for parser
        symbol_table = {name: dtype for (name, _param), dtype in zip(params, arg_types)}
        for i in range(num_unique_shapes):
            symbol_table[f"_s{i}"] = Scalar(PrimitiveType.Int64)

        # Parse AST; shift line numbers so they match the source file.
        source_lines, start_line = inspect.getsourcelines(self.func)
        tree = ast.parse(textwrap.dedent("".join(source_lines)))
        ast.increment_lineno(tree, start_line - 1)
        func_def = tree.body[0]

        parser = ASTParser(
            builder,
            array_info,
            symbol_table,
            inspect.getsourcefile(self.func),
            self.func.__name__,
            infer_return_type=infer_return_type,
            globals_dict=self.func.__globals__,
            structure_member_info=structure_member_info,
        )
        for node in func_def.body:
            parser.visit(node)

        sdfg = builder.move()
        # Mark return arguments metadata
        out_args = [name for name in sdfg.arguments if name.startswith("_docc_ret_")]

        return sdfg, out_args, parser.captured_return_shapes
477

478

479
def program(
    func=None,
    *,
    target="none",
    category="desktop",
    instrumentation_mode=None,
    capture_args=None,
):
    """Decorator that turns a Python function into a JIT-compiled ``DoccProgram``.

    Works both bare (``@program``) and with keyword options
    (``@program(target=..., category=...)``); in the latter form ``func`` is
    ``None`` and a decorator is returned.

    Args:
        func: Function to wrap, or ``None`` when options are supplied.
        target: Scheduling target forwarded to ``DoccProgram``.
        category: Target category forwarded to ``DoccProgram``.
        instrumentation_mode: Default instrumentation mode for the program.
        capture_args: Default argument-capture flag for the program.

    Returns:
        A ``DoccProgram`` when ``func`` is given, otherwise a decorator that
        produces one.
    """

    def make_program(fn):
        # Bind the decorator options to the function being wrapped.
        return DoccProgram(
            fn,
            target=target,
            category=category,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
        )

    return make_program if func is None else make_program(func)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc