• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 21533102315

30 Jan 2026 10:38PM UTC coverage: 66.576% (+0.007%) from 66.569%
21533102315

Pull #478

github

web-flow
Merge 20c453b3e into 77627c047
Pull Request #478: add initial workflow npbench

16 of 16 new or added lines in 3 files covered. (100.0%)

14 existing lines in 1 file now uncovered.

23058 of 34634 relevant lines covered (66.58%)

381.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.08
/python/docc/__init__.py
1
import inspect
4✔
2
import shutil
4✔
3
import textwrap
4✔
4
import ast
4✔
5
import os
4✔
6
import getpass
4✔
7
import hashlib
4✔
8
import numpy as np
4✔
9
from typing import Annotated, get_origin, get_args
4✔
10
from ._sdfg import *
4✔
11
from .compiled_sdfg import CompiledSDFG
4✔
12
from .ast_parser import ASTParser
4✔
13

14

15
def _compile_wrapper(self, output_folder=None):
    """Compile this SDFG and wrap the resulting library in a CompiledSDFG.

    Args:
        output_folder: Forwarded unchanged to ``self._compile``.

    Returns:
        A ``CompiledSDFG`` built from the library path returned by
        ``self._compile`` and from ``self``.
    """
    return CompiledSDFG(self._compile(output_folder), self)
×
18

19

20
# Install the wrapper as the `compile` method of StructuredSDFG so that
# `sdfg.compile(...)` returns a CompiledSDFG instead of the raw library path
# produced by `_compile`.
StructuredSDFG.compile = _compile_wrapper
4✔
21

22
# Global RPC context for scheduling SDFGs
23
# Passed as the third argument to `sdfg.schedule(...)` in DoccProgram.compile
# whenever a target other than "none" is selected; stays None unless reassigned.
sdfg_rpc_context = None
4✔
24

25

26
def _map_python_type(dtype):
    """Translate a Python/NumPy type annotation into an SDFG type.

    Handled forms, checked in this order:

    * an object that already is an SDFG ``Type`` is returned unchanged;
    * ``Annotated[np.ndarray, shape, dtype]`` becomes a ``Pointer`` to the
      mapped ``dtype`` (``Double`` when no dtype metadata is given);
    * ``np.ndarray[Shape, DType]`` becomes a ``Pointer`` to the mapped
      ``DType``;
    * ``float``/``int``/``bool`` and the matching NumPy scalar types become
      ``Scalar`` types;
    * any other class becomes a ``Pointer`` to a ``Structure`` named after
      the class.

    Anything else is returned as-is, so callers can detect an unmapped
    annotation by checking that the result is not a ``Type``.
    """
    if isinstance(dtype, Type):
        return dtype

    origin = get_origin(dtype)

    if origin is Annotated:
        base_type, *metadata = get_args(dtype)
        if base_type is np.ndarray:
            # Convention: Annotated[np.ndarray, shape, dtype]. Only the
            # optional dtype entry influences the resulting pointer type.
            if len(metadata) > 1:
                return Pointer(_map_python_type(metadata[1]))
            return Pointer(Scalar(PrimitiveType.Double))

    if origin is np.ndarray:
        # numpy.ndarray[Shape, DType]: the second argument is the dtype.
        type_args = get_args(dtype)
        if len(type_args) >= 2:
            return Pointer(_map_python_type(type_args[1]))

    # Identity-based lookup of the plain Python / NumPy scalar types.
    primitive_by_type = (
        ((float, np.float64), PrimitiveType.Double),
        ((int, np.int64), PrimitiveType.Int64),
        ((bool, np.bool_), PrimitiveType.Bool),
        ((np.float32,), PrimitiveType.Float),
        ((np.int32,), PrimitiveType.Int32),
    )
    for candidates, primitive in primitive_by_type:
        if any(dtype is candidate for candidate in candidates):
            return Scalar(primitive)

    if inspect.isclass(dtype):
        # Arbitrary classes are passed as pointers to a structure that
        # carries the class name.
        return Pointer(Structure(dtype.__name__))

    return dtype
×
74

75

76
class DoccProgram:
    """Callable wrapper that JIT-compiles a Python function into an SDFG.

    Calling the instance compiles the wrapped function for the given
    arguments (reusing a cached build when possible), runs the compiled
    artifact and converts the result to a NumPy array when a return shape
    is known.

    Attributes set in ``__init__``:
        func, name, target, category, instrumentation_mode, capture_args:
            the constructor arguments (``name`` is ``func.__name__``).
        last_sdfg: the SDFG produced by the most recent build, else None.
        cache: maps ``(signature, instrumentation_mode, capture_args)`` to
            the compiled artifact for builds using the default folder.
    """

    def __init__(
        self,
        func,
        target="none",
        category="server",
        instrumentation_mode=None,
        capture_args=None,
    ):
        self.func = func
        self.name = func.__name__
        self.target = target
        self.category = category
        self.instrumentation_mode = instrumentation_mode
        self.capture_args = capture_args
        self.last_sdfg = None
        self.cache = {}
        # Filled by _build_sdfg; initialised here so the attribute always exists.
        self._last_structure_member_info = {}
        print(
            f"Created DoccProgram with category '{self.category}' and name '{self.name}'"
        )

    def __call__(self, *args):
        """Compile (or reuse) the program for ``args``, run it, return the result.

        The raw result is converted with ``np.ctypeslib.as_array`` when a
        shape is available, first from an ``Annotated[np.ndarray, shape, ...]``
        return annotation and then from ``compiled.get_return_shape``. If no
        conversion succeeds the raw result is returned.
        """
        compiled = self.compile(*args)
        res = compiled(*args)

        shape = self._annotated_return_shape()
        if shape is not None:
            converted = self._try_as_array(res, shape)
            if converted is not None:
                return converted

        # Try to infer return shape from metadata
        if hasattr(compiled, "get_return_shape"):
            shape = compiled.get_return_shape(*args)
            if shape is not None:
                converted = self._try_as_array(res, shape)
                if converted is not None:
                    return converted

        return res

    def _annotated_return_shape(self):
        """Return the shape entry of an Annotated ndarray return annotation, or None."""
        annotation = inspect.signature(self.func).return_annotation
        if annotation is inspect.Signature.empty:
            return None
        if get_origin(annotation) is not Annotated:
            return None
        type_args = get_args(annotation)
        if len(type_args) >= 2 and type_args[0] is np.ndarray:
            return type_args[1]
        return None

    @staticmethod
    def _try_as_array(res, shape):
        """Best-effort conversion of ``res`` to an ndarray; None when it fails."""
        try:
            return np.ctypeslib.as_array(res, shape=shape)
        except Exception:
            # Deliberately best-effort: the shape may not be usable for this
            # result, in which case the caller falls back to the next source.
            return None

    def _resolve_options(self, instrumentation_mode, capture_args):
        """Resolve build options: call argument, then instance, then DOCC_CI, then default.

        ``DOCC_CI=regions`` enables "ols" instrumentation, ``arg-capture``
        enables argument capture, and any other non-empty value enables both.
        Options that are already set are never overridden by the environment.
        """
        if instrumentation_mode is None:
            instrumentation_mode = self.instrumentation_mode
        if capture_args is None:
            capture_args = self.capture_args

        docc_ci = os.environ.get("DOCC_CI", "")
        if docc_ci:
            if docc_ci != "arg-capture" and instrumentation_mode is None:
                instrumentation_mode = "ols"
            if docc_ci != "regions" and capture_args is None:
                capture_args = True

        if instrumentation_mode is None:
            instrumentation_mode = ""
        if capture_args is None:
            capture_args = False
        return instrumentation_mode, capture_args

    def _default_output_folder(self, signature, instrumentation_mode, capture_args):
        """Derive a stable per-build folder under $DOCC_TMP or /tmp/<user>/DOCC."""
        filename = inspect.getsourcefile(self.func)
        # The resolved options are part of the hash so builds with different
        # options never share (and delete) each other's folder.
        hash_input = f"{filename}|{self.name}|{self.target}|{self.category}|{capture_args}|{instrumentation_mode}|{signature}".encode(
            "utf-8"
        )
        stable_id = hashlib.sha256(hash_input).hexdigest()[:16]

        docc_tmp = os.environ.get("DOCC_TMP")
        if docc_tmp:
            return f"{docc_tmp}/{self.name}-{stable_id}"
        user = getpass.getuser()
        return f"/tmp/{user}/DOCC/{self.name}-{stable_id}"

    def compile(
        self, *args, output_folder=None, instrumentation_mode=None, capture_args=None
    ):
        """Build, compile and wrap the SDFG specialised for ``args``.

        Args:
            *args: Example call arguments; their types and array shapes
                determine the specialisation.
            output_folder: Build directory. When omitted a stable folder is
                derived and the result is cached on the instance.
            instrumentation_mode: Overrides ``self.instrumentation_mode``.
            capture_args: Overrides ``self.capture_args``.

        Returns:
            The ``CompiledSDFG`` for this specialisation.

        Raises:
            ValueError: If an argument has an unsupported type or the
                argument count does not match the function signature.
        """
        original_output_folder = output_folder
        instrumentation_mode, capture_args = self._resolve_options(
            instrumentation_mode, capture_args
        )

        # 1. Analyze arguments and shapes
        # First pass: scalar integer arguments, value -> name (first one wins)
        sig = inspect.signature(self.func)
        scalar_int_params = {}
        for name, arg in zip(sig.parameters, args):
            if isinstance(arg, (int, np.integer)) and not isinstance(
                arg, (bool, np.bool_)
            ):
                scalar_int_params.setdefault(int(arg), name)

        arg_types = []
        shape_values = []  # unique shape values, in order of first appearance
        shape_sources = []  # (arg_idx, dim_idx) of each unique shape value
        shape_index = {}  # shape value -> index into shape_values
        arg_shape_mapping = {}  # (arg_idx, dim_idx) -> unique shape index
        for i, arg in enumerate(args):
            arg_types.append(self._infer_type(arg))
            if not isinstance(arg, np.ndarray):
                continue
            for dim_idx, dim_val in enumerate(arg.shape):
                u_idx = shape_index.get(dim_val)
                if u_idx is None:
                    u_idx = len(shape_values)
                    shape_index[dim_val] = u_idx
                    shape_values.append(dim_val)
                    shape_sources.append((i, dim_idx))
                arg_shape_mapping[(i, dim_idx)] = u_idx

        # Shape indices whose value equals a scalar integer argument reuse
        # that argument's name instead of a synthetic _sX symbol.
        shape_to_scalar = {
            s_idx: scalar_int_params[s_val]
            for s_idx, s_val in enumerate(shape_values)
            if s_val in scalar_int_params
        }

        # 2. Signature - include scalar-shape equivalences for correct caching
        mapping_sig = sorted(arg_shape_mapping.items())
        equiv_sig = sorted(shape_to_scalar.items())
        type_sig = self._get_signature(arg_types)
        signature = f"{type_sig}|{mapping_sig}|{equiv_sig}"

        # The resolved options are part of the key: an artifact built with
        # different instrumentation/capture settings must not be reused.
        cache_key = (signature, instrumentation_mode, capture_args)
        if original_output_folder is None:
            if cache_key in self.cache:
                return self.cache[cache_key]
            output_folder = self._default_output_folder(
                signature, instrumentation_mode, capture_args
            )

        # 3. Build SDFG
        if os.path.exists(output_folder):
            # Multiple python processes running the same code?
            shutil.rmtree(output_folder)
        sdfg, out_args, out_shapes = self._build_sdfg(
            arg_types, args, arg_shape_mapping, len(shape_values), shape_to_scalar
        )
        sdfg.validate()
        sdfg.expand()
        sdfg.simplify()

        if self.target != "none":
            sdfg.normalize()

        sdfg.dump(output_folder)

        # Schedule if target is specified
        if self.target != "none":
            sdfg.schedule(self.target, self.category, sdfg_rpc_context)

        self.last_sdfg = sdfg

        # 4. Compile to a shared library
        lib_path = sdfg._compile(
            output_folder=output_folder,
            target=self.target,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
        )

        # 5. Create CompiledSDFG
        compiled = CompiledSDFG(
            lib_path,
            sdfg,
            shape_sources,
            self._last_structure_member_info,
            out_args,
            out_shapes,
        )

        # Cache if using default output folder
        if original_output_folder is None:
            self.cache[cache_key] = compiled

        return compiled

    def _get_signature(self, arg_types):
        """Return the comma-separated string form of ``arg_types``."""
        return ", ".join(self._type_to_str(t) for t in arg_types)

    def _type_to_str(self, t):
        """Render an SDFG type as a string, recursing into element/pointee types."""
        if isinstance(t, Scalar):
            return f"Scalar({t.primitive_type})"
        elif isinstance(t, Array):
            return f"Array({self._type_to_str(t.element_type)}, {t.num_elements})"
        elif isinstance(t, Pointer):
            return f"Pointer({self._type_to_str(t.pointee_type)})"
        elif isinstance(t, Structure):
            return f"Structure({t.name})"
        return str(t)

    def _infer_type(self, arg):
        """Infer the SDFG type of a runtime argument.

        Python/NumPy scalars map to ``Scalar`` types, ndarrays to a
        ``Pointer`` to their element type, and other instances to a
        ``Pointer`` to a ``Structure`` named after their class.

        Raises:
            ValueError: For strings, unsupported NumPy scalars or array
                dtypes, and class objects.
        """
        # bool is checked first because it is a subclass of int.
        if isinstance(arg, (bool, np.bool_)):
            return Scalar(PrimitiveType.Bool)
        elif isinstance(arg, (int, np.int64)):
            return Scalar(PrimitiveType.Int64)
        elif isinstance(arg, (float, np.float64)):
            return Scalar(PrimitiveType.Double)
        elif isinstance(arg, np.int32):
            return Scalar(PrimitiveType.Int32)
        elif isinstance(arg, np.float32):
            return Scalar(PrimitiveType.Float)
        elif isinstance(arg, np.ndarray):
            # Map dtype
            if arg.dtype == np.float64:
                elem_type = Scalar(PrimitiveType.Double)
            elif arg.dtype == np.float32:
                elem_type = Scalar(PrimitiveType.Float)
            elif arg.dtype == np.int64:
                elem_type = Scalar(PrimitiveType.Int64)
            elif arg.dtype == np.int32:
                elem_type = Scalar(PrimitiveType.Int32)
            elif arg.dtype == np.bool_:
                elem_type = Scalar(PrimitiveType.Bool)
            else:
                raise ValueError(f"Unsupported numpy dtype: {arg.dtype}")

            return Pointer(elem_type)
        elif isinstance(arg, (str, np.generic)):
            # Strings are not supported, and neither are the remaining NumPy
            # scalar types (e.g. int16, uint8, float16); without this check
            # they would be mistaken for structure instances below.
            raise ValueError(f"Unsupported argument type: {type(arg)}")
        elif not isinstance(arg, type):
            # It's an instance of a class, return pointer to Structure
            return Pointer(Structure(arg.__class__.__name__))
        raise ValueError(f"Unsupported argument type: {type(arg)}")

    @staticmethod
    def _normalize_return_annotation(annotation):
        """Map a return annotation to a list of types, one per returned value.

        A ``tuple[...]`` annotation yields one entry per type argument (an
        empty list without arguments); anything else yields a single entry.
        Entries that could not be mapped are not SDFG ``Type`` instances.
        """
        if get_origin(annotation) is tuple:
            return [_map_python_type(t) for t in get_args(annotation)]
        return [_map_python_type(annotation)]

    @staticmethod
    def _infer_member_type(value):
        """Return the Scalar type of a structure member value, or None if unsupported."""
        # Check bool before int since bool is subclass of int
        if isinstance(value, bool):
            return Scalar(PrimitiveType.Bool)
        if isinstance(value, (int, np.int64)):
            return Scalar(PrimitiveType.Int64)
        if isinstance(value, (float, np.float64)):
            return Scalar(PrimitiveType.Double)
        if isinstance(value, np.int32):
            return Scalar(PrimitiveType.Int32)
        if isinstance(value, np.float32):
            return Scalar(PrimitiveType.Float)
        # TODO: Consider using np.integer and np.floating abstract types
        # for more comprehensive numpy type coverage
        # TODO: Add support for nested structures and arrays
        return None

    def _build_sdfg(
        self,
        arg_types,
        args,
        arg_shape_mapping,
        num_unique_shapes,
        shape_to_scalar=None,
    ):
        """Build the SDFG of the wrapped function for one specialisation.

        Args:
            arg_types: SDFG type of each positional argument.
            args: The runtime arguments (used for array ranks and for
                introspecting structure members).
            arg_shape_mapping: ``(arg_idx, dim_idx) -> unique shape index``.
            num_unique_shapes: Number of unique shape values.
            shape_to_scalar: ``unique shape index -> scalar parameter name``
                for shapes that equal a scalar integer argument.

        Returns:
            ``(sdfg, out_args, captured_return_shapes)`` where ``out_args``
            lists the SDFG arguments whose name starts with ``_docc_ret_``.

        Raises:
            ValueError: If the number of arguments does not match the
                number of parameters of the wrapped function.
        """
        if shape_to_scalar is None:
            shape_to_scalar = {}
        sig = inspect.signature(self.func)

        # Handle return type - always void for SDFG, output args used for returns
        return_type = Scalar(PrimitiveType.Void)

        # Use the return annotation to declare output arguments up front; fall
        # back to inference when any part of it cannot be mapped to a Type.
        explicit_returns = []
        infer_return_type = True
        if sig.return_annotation is not inspect.Signature.empty:
            explicit_returns = self._normalize_return_annotation(
                sig.return_annotation
            )
            infer_return_type = not all(
                isinstance(rt, Type) for rt in explicit_returns
            )
            if infer_return_type:
                explicit_returns = []

        builder = StructuredSDFGBuilder(f"{self.name}_sdfg", return_type)

        # Add pre-defined return arguments if we know them
        if not infer_return_type:
            for i, dtype in enumerate(explicit_returns):
                # Scalar -> Pointer(Scalar)
                # Array -> Already Pointer(Scalar). Keep it.
                arg_type = Pointer(dtype) if isinstance(dtype, Scalar) else dtype
                builder.add_container(f"_docc_ret_{i}", arg_type, is_argument=True)

        # Discover structure types among the arguments and record, for each,
        # the member name -> (index, type) mapping.
        structures_to_register = {}
        structure_member_info = {}
        for arg, dtype in zip(args, arg_types):
            if not (isinstance(dtype, Pointer) and dtype.has_pointee_type()):
                continue
            pointee = dtype.pointee_type
            if not isinstance(pointee, Structure):
                continue
            struct_name = pointee.name
            if struct_name in structures_to_register or not hasattr(arg, "__dict__"):
                continue

            # Only public instance attributes are considered. They are sorted
            # by name to ensure consistent ordering; this alphabetical order
            # defines the structure layout and must match the order expected
            # by the backend code generation.
            members = []
            for attr_name, attr_value in sorted(arg.__dict__.items()):
                if attr_name.startswith("_"):
                    continue
                member_type = self._infer_member_type(attr_value)
                if member_type is not None:
                    members.append((attr_name, member_type))

            if members:
                structures_to_register[struct_name] = [
                    mtype for _, mtype in members
                ]
                structure_member_info[struct_name] = {
                    name: (idx, mtype) for idx, (name, mtype) in enumerate(members)
                }

        # Store structure_member_info for later use in CompiledSDFG
        self._last_structure_member_info = structure_member_info

        # Register all discovered structures with the builder
        for struct_name, member_types in structures_to_register.items():
            builder.add_structure(struct_name, member_types)

        # Register arguments
        param_names = list(sig.parameters)
        if len(param_names) != len(arg_types):
            raise ValueError(
                f"Argument count mismatch: expected {len(param_names)}, got {len(arg_types)}"
            )

        array_info = {}
        symbol_table = {}

        # Add regular arguments
        for i, (name, dtype, arg) in enumerate(zip(param_names, arg_types, args)):
            builder.add_container(name, dtype, is_argument=True)
            symbol_table[name] = dtype

            # If it's an array, prepare shape info
            if isinstance(arg, np.ndarray):
                shapes = []
                for dim_idx in range(arg.ndim):
                    u_idx = arg_shape_mapping[(i, dim_idx)]
                    # Use scalar parameter name if there's an equivalence, otherwise _sX
                    shapes.append(shape_to_scalar.get(u_idx, f"_s{u_idx}"))
                array_info[name] = {"ndim": arg.ndim, "shapes": shapes}

        # Add unified shape arguments only for shapes without scalar equivalents
        for i in range(num_unique_shapes):
            if i not in shape_to_scalar:
                shape_type = Scalar(PrimitiveType.Int64)
                builder.add_container(f"_s{i}", shape_type, is_argument=True)
                symbol_table[f"_s{i}"] = Scalar(PrimitiveType.Int64)

        # Parse AST
        source_lines, start_line = inspect.getsourcelines(self.func)
        source = textwrap.dedent("".join(source_lines))
        tree = ast.parse(source)
        # Shift line numbers so they refer to positions in the original file.
        ast.increment_lineno(tree, start_line - 1)
        func_def = tree.body[0]

        filename = inspect.getsourcefile(self.func)
        function_name = self.func.__name__

        parser = ASTParser(
            builder,
            array_info,
            symbol_table,
            filename,
            function_name,
            infer_return_type=infer_return_type,
            globals_dict=self.func.__globals__,
            structure_member_info=structure_member_info,
        )
        for node in func_def.body:
            parser.visit(node)

        sdfg = builder.move()
        # Mark return arguments metadata
        out_args = [name for name in sdfg.arguments if name.startswith("_docc_ret_")]

        return sdfg, out_args, parser.captured_return_shapes
4✔
520

521

522
def program(
    func=None,
    *,
    target="none",
    category="desktop",
    instrumentation_mode=None,
    capture_args=None,
):
    """Decorator that turns a function into a ``DoccProgram``.

    Usable both bare (``@program``) and with keyword options
    (``@program(target=...)``). In the bare form ``func`` is the decorated
    function and a ``DoccProgram`` is returned directly; otherwise a
    one-argument callable is returned that builds the ``DoccProgram`` once
    it receives the function.

    Args:
        func: The function to wrap, or None when called with options only.
        target, category, instrumentation_mode, capture_args: Forwarded
            unchanged to ``DoccProgram``.
    """

    def _wrap(f):
        return DoccProgram(
            f,
            target=target,
            category=category,
            instrumentation_mode=instrumentation_mode,
            capture_args=capture_args,
        )

    if func is None:
        return _wrap
    return _wrap(func)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc