• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.84
/python/docc/python/types/__init__.py
1
import ast
4✔
2
import inspect
4✔
3
import numpy as np
4✔
4

5
from typing import get_origin, get_args
4✔
6
from docc.sdfg import (
4✔
7
    PrimitiveType,
8
    Scalar,
9
    Pointer,
10
    Array,
11
    Structure,
12
    Tensor,
13
    Type,
14
)
15

16

17
def sdfg_type_from_type(python_type):
    """Map a Python/NumPy type or type annotation to an SDFG type.

    Supported inputs:

    * an SDFG ``Type`` instance, returned unchanged;
    * ``np.ndarray`` (bare) or ``np.ndarray[Shape, dtype]``, mapped to a
      ``Pointer`` to the element type (void element type when the dtype
      is not given);
    * ``np.dtype[ScalarType]``, mapped like ``ScalarType`` itself;
    * ``float``, ``bool``, ``int`` and the NumPy scalar types listed
      below, mapped to the corresponding ``Scalar``;
    * any other class, mapped to a ``Pointer`` to a ``Structure`` named
      after the class.

    Raises:
        ValueError: if ``python_type`` matches none of the cases above.
    """
    # Already an SDFG type: nothing to translate.
    if isinstance(python_type, Type):
        return python_type

    # A bare ``np.ndarray`` has no typing origin (``get_origin`` returns
    # None for it), so it must be matched by identity. Without this check
    # it would fall through to the generic class branch at the bottom and
    # be turned into a Structure named "ndarray".
    if python_type is np.ndarray:
        return Pointer(Scalar(PrimitiveType.Void))

    origin = get_origin(python_type)

    # Handle numpy.ndarray[Shape, python_type] type annotations
    if origin is np.ndarray:
        args = get_args(python_type)
        if len(args) >= 2:
            return Pointer(sdfg_type_from_type(args[1]))
        # ndarray alias without an element type defaults to void pointer
        return Pointer(Scalar(PrimitiveType.Void))

    # Handle np.dtype[ScalarType] annotations
    if origin is np.dtype:
        return sdfg_type_from_type(get_args(python_type)[0])

    # Compared by identity (not looked up in a dict) so that unhashable
    # inputs still end in the ValueError below instead of a TypeError.
    scalar_types = (
        (float, PrimitiveType.Double),
        (np.float64, PrimitiveType.Double),
        (np.float32, PrimitiveType.Float),
        (bool, PrimitiveType.Bool),
        (np.bool_, PrimitiveType.Bool),
        (int, PrimitiveType.Int64),
        (np.int64, PrimitiveType.Int64),
        (np.int32, PrimitiveType.Int32),
        (np.int16, PrimitiveType.Int16),
        (np.int8, PrimitiveType.Int8),
        (np.uint64, PrimitiveType.UInt64),
        (np.uint32, PrimitiveType.UInt32),
        (np.uint16, PrimitiveType.UInt16),
        (np.uint8, PrimitiveType.UInt8),
    )
    for candidate, primitive in scalar_types:
        if python_type is candidate:
            return Scalar(primitive)

    # Handle Python classes - map to Structure type
    if inspect.isclass(python_type):
        return Pointer(Structure(python_type.__name__))

    raise ValueError(f"Cannot map type to SDFG type: {python_type}")
×
62

63

64
def element_type_from_sdfg_type(sdfg_type: Type):
    """Return the scalar element type of an SDFG type.

    A ``Scalar`` is returned as-is. For a ``Pointer``, ``Array`` or
    ``Tensor`` a new ``Scalar`` is built from its ``primitive_type``.

    Raises:
        ValueError: if ``sdfg_type`` is none of the types above.
    """
    if isinstance(sdfg_type, Scalar):
        return sdfg_type

    container_types = (Pointer, Array, Tensor)
    if not isinstance(sdfg_type, container_types):
        raise ValueError(
            f"Unsupported SDFG type for element type extraction: {sdfg_type}"
        )

    return Scalar(sdfg_type.primitive_type)
73

74

75
def element_type_from_ast_node(ast_node, symbol_table=None):
    """Derive a scalar SDFG element type from a dtype-like AST expression.

    Recognised forms:

    * ``None``: defaults to double;
    * the names ``float``, ``int`` and ``bool``;
    * ``np.<name>`` / ``numpy.<name>`` for the NumPy scalar names in the
      table below;
    * ``<var>.dtype`` where ``<var>`` is a key of ``symbol_table``; the
      element type is taken from that variable's SDFG type.

    Raises:
        ValueError: if the node matches none of the forms above.
    """
    # Default to double
    if ast_node is None:
        return Scalar(PrimitiveType.Double)

    # Handle python built-in types
    if isinstance(ast_node, ast.Name):
        builtin_primitives = {
            "float": PrimitiveType.Double,
            "int": PrimitiveType.Int64,
            "bool": PrimitiveType.Bool,
        }
        if ast_node.id in builtin_primitives:
            return Scalar(builtin_primitives[ast_node.id])

    # Handle complex types
    if isinstance(ast_node, ast.Attribute):
        owner = ast_node.value
        attr_name = ast_node.attr
        owner_is_name = isinstance(owner, ast.Name)

        # Handle numpy types like np.float64, np.int32, etc.
        if owner_is_name and owner.id in ("numpy", "np"):
            numpy_primitives = {
                "float64": PrimitiveType.Double,
                "float32": PrimitiveType.Float,
                "int64": PrimitiveType.Int64,
                "int32": PrimitiveType.Int32,
                "int16": PrimitiveType.Int16,
                "int8": PrimitiveType.Int8,
                "uint64": PrimitiveType.UInt64,
                "uint32": PrimitiveType.UInt32,
                "uint16": PrimitiveType.UInt16,
                "uint8": PrimitiveType.UInt8,
                "bool_": PrimitiveType.Bool,
            }
            if attr_name in numpy_primitives:
                return Scalar(numpy_primitives[attr_name])

        # Handle arr.dtype - get element type from array's type in symbol table
        if attr_name == "dtype" and symbol_table is not None:
            if owner_is_name and owner.id in symbol_table:
                return element_type_from_sdfg_type(symbol_table[owner.id])

    raise ValueError(f"Cannot map AST node to SDFG type: {ast.dump(ast_node)}")
×
128

129

130
def promote_element_types(left_element_type, right_element_type):
    """Pick the higher-ranked of two element types.

    Ranking is float > int, wider > narrower, and signed > unsigned at
    equal width. This approximates NumPy's promotion rules: the result is
    always one of the two inputs, so combinations for which NumPy would
    choose a third, wider type (for example int32 with uint32) are not
    widened here.

    Args:
        left_element_type: object exposing a ``primitive_type`` attribute.
        right_element_type: object exposing a ``primitive_type`` attribute.

    Returns:
        ``left_element_type`` if its rank is greater than or equal to the
        rank of ``right_element_type`` (ties favour the left operand),
        otherwise ``right_element_type``.
    """
    # Every integer width is ranked explicitly. Leaving the narrow and
    # unsigned types unranked would make them all tie, so the left operand
    # would win regardless of width (e.g. Int8 combined with Int16 would
    # yield Int8).
    priority = {
        PrimitiveType.Double: 10,
        PrimitiveType.Float: 9,
        PrimitiveType.Int64: 8,
        PrimitiveType.UInt64: 7,
        PrimitiveType.Int32: 6,
        PrimitiveType.UInt32: 5,
        PrimitiveType.Int16: 4,
        PrimitiveType.UInt16: 3,
        PrimitiveType.Int8: 2,
        PrimitiveType.UInt8: 1,
        PrimitiveType.Bool: 0,
    }
    # Primitive types missing from the table get the lowest rank.
    left_prio = priority.get(left_element_type.primitive_type, 0)
    right_prio = priority.get(right_element_type.primitive_type, 0)
    if left_prio >= right_prio:
        return left_element_type
    return right_element_type
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc