• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 23514597273

24 Mar 2026 10:08PM UTC coverage: 64.344% (+0.05%) from 64.295%
23514597273

Pull #611

github

web-flow
Merge 89f7b18f0 into e56781552
Pull Request #611: Updates rules to handle casts between numpy arrays and scalars

79 of 85 new or added lines in 3 files covered. (92.94%)

2 existing lines in 1 file now uncovered.

26715 of 41519 relevant lines covered (64.34%)

405.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.57
/python/docc/python/types/__init__.py
1
import ast
4✔
2
import inspect
4✔
3
import numpy as np
4✔
4

5
from typing import get_origin, get_args
4✔
6
from docc.sdfg import (
4✔
7
    PrimitiveType,
8
    Scalar,
9
    Pointer,
10
    Array,
11
    Structure,
12
    Tensor,
13
    Type,
14
)
15

16

17
# Python / NumPy scalar types and the SDFG primitive each maps to.
# Matched by identity (``is``) rather than hashing or subclass checks, so that
# e.g. ``np.float64`` (a ``float`` subclass) and ``float`` each hit their own
# entry exactly as listed.
_SCALAR_TYPE_PRIMITIVES = (
    (float, PrimitiveType.Double),
    (np.float64, PrimitiveType.Double),
    (np.float32, PrimitiveType.Float),
    (np.float16, PrimitiveType.Half),
    (bool, PrimitiveType.Bool),
    (np.bool_, PrimitiveType.Bool),
    (int, PrimitiveType.Int64),
    (np.int64, PrimitiveType.Int64),
    (np.int32, PrimitiveType.Int32),
    (np.int16, PrimitiveType.Int16),
    (np.int8, PrimitiveType.Int8),
    (np.uint64, PrimitiveType.UInt64),
    (np.uint32, PrimitiveType.UInt32),
    (np.uint16, PrimitiveType.UInt16),
    (np.uint8, PrimitiveType.UInt8),
)


def sdfg_type_from_type(python_type):
    """Map a Python / NumPy type annotation to an SDFG type.

    Accepted inputs:
    - an SDFG ``Type`` instance, returned unchanged;
    - ``np.ndarray[Shape, dtype]``: a ``Pointer`` to the mapped element
      type. An unparameterized ``np.ndarray`` gives a ``Pointer`` to ``Void``;
    - ``np.dtype[ScalarType]`` annotations and concrete ``np.dtype``
      instances (e.g. ``np.dtype("float32")``): the matching ``Scalar``;
    - builtin ``float`` / ``int`` / ``bool`` and the NumPy scalar types in
      ``_SCALAR_TYPE_PRIMITIVES``: the matching ``Scalar``;
    - any other class: ``Pointer(Structure(<class name>))``.

    Raises:
        ValueError: if ``python_type`` cannot be mapped.
    """
    if isinstance(python_type, Type):
        return python_type

    origin = get_origin(python_type)

    # numpy.ndarray[Shape, dtype] annotations -> pointer to element type.
    if origin is np.ndarray:
        args = get_args(python_type)
        if len(args) >= 2:
            return Pointer(sdfg_type_from_type(args[1]))
        # Unparameterized ndarray defaults to void pointer.
        return Pointer(Scalar(PrimitiveType.Void))

    # np.dtype[ScalarType] annotations -> the scalar type itself.
    if origin is np.dtype:
        return sdfg_type_from_type(get_args(python_type)[0])

    # A concrete dtype object (np.dtype("float32")) is resolved through its
    # scalar type. Only primitive scalars are supported for dtype objects;
    # anything else (structured, object, ...) falls through to ValueError.
    is_dtype_instance = isinstance(python_type, np.dtype)
    scalar_type = python_type.type if is_dtype_instance else python_type

    for known_type, primitive in _SCALAR_TYPE_PRIMITIVES:
        if scalar_type is known_type:
            return Scalar(primitive)

    # Remaining Python classes map to an opaque Structure pointer.
    # NOTE(review): unsupported NumPy scalar classes (e.g. np.complex128)
    # still land here and become Structure pointers; confirm that is wanted.
    if not is_dtype_instance and inspect.isclass(python_type):
        return Pointer(Structure(python_type.__name__))

    raise ValueError(f"Cannot map type to SDFG type: {python_type}")
62

63

64
def element_type_from_sdfg_type(sdfg_type: Type):
    """Return the scalar element type carried by an SDFG type.

    A ``Scalar`` is its own element type. For ``Pointer``, ``Array`` and
    ``Tensor`` the element type is a ``Scalar`` built from the container's
    ``primitive_type``.

    Raises:
        ValueError: for any other SDFG type.
    """
    if isinstance(sdfg_type, Scalar):
        return sdfg_type

    container_kinds = (Pointer, Array, Tensor)
    if not isinstance(sdfg_type, container_kinds):
        raise ValueError(
            f"Unsupported SDFG type for element type extraction: {sdfg_type}"
        )

    return Scalar(sdfg_type.primitive_type)
73

74

75
# Builtin Python type names usable as dtype annotations (``float``, ...).
_BUILTIN_NAME_PRIMITIVES = {
    "float": PrimitiveType.Double,
    "int": PrimitiveType.Int64,
    "bool": PrimitiveType.Bool,
}

# Scalar attribute names accepted after ``np.`` / ``numpy.``.
_NUMPY_ATTR_PRIMITIVES = {
    "float64": PrimitiveType.Double,
    "float32": PrimitiveType.Float,
    "float16": PrimitiveType.Half,
    "int64": PrimitiveType.Int64,
    "int32": PrimitiveType.Int32,
    "int16": PrimitiveType.Int16,
    "int8": PrimitiveType.Int8,
    "uint64": PrimitiveType.UInt64,
    "uint32": PrimitiveType.UInt32,
    "uint16": PrimitiveType.UInt16,
    "uint8": PrimitiveType.UInt8,
    "bool_": PrimitiveType.Bool,
    "bool": PrimitiveType.Bool,  # NumPy 2 alias of np.bool_
}


def element_type_from_ast_node(ast_node, container_table=None):
    """Resolve a dtype expression from Python source to an SDFG ``Scalar``.

    Supported forms:
    - ``None``: defaults to ``Double``;
    - builtin names ``float`` / ``int`` / ``bool``;
    - ``np.<scalar>`` / ``numpy.<scalar>`` (see ``_NUMPY_ATTR_PRIMITIVES``);
    - ``<name>.dtype``, when ``container_table`` maps ``<name>`` to an SDFG
      type. The element type of that type is returned.

    Args:
        ast_node: The AST expression node, or ``None``.
        container_table: Optional mapping from variable name to SDFG type,
            used to resolve ``arr.dtype``.

    Raises:
        ValueError: if the node does not match any supported form.
    """
    # Default to double.
    if ast_node is None:
        return Scalar(PrimitiveType.Double)

    # Python built-in types.
    if isinstance(ast_node, ast.Name):
        primitive = _BUILTIN_NAME_PRIMITIVES.get(ast_node.id)
        if primitive is not None:
            return Scalar(primitive)

    if isinstance(ast_node, ast.Attribute):
        owner = ast_node.value

        # NumPy scalar types such as np.float64, np.int32, ...
        if isinstance(owner, ast.Name) and owner.id in ("numpy", "np"):
            primitive = _NUMPY_ATTR_PRIMITIVES.get(ast_node.attr)
            if primitive is not None:
                return Scalar(primitive)

        # arr.dtype: element type of the array's type in the symbol table.
        if ast_node.attr == "dtype" and container_table is not None:
            if isinstance(owner, ast.Name) and owner.id in container_table:
                return element_type_from_sdfg_type(container_table[owner.id])

    raise ValueError(f"Cannot map AST node to SDFG type: {ast.dump(ast_node)}")
128

129

130
def promote_element_types(left_element_type, right_element_type):
    """
    Promote two dtypes following NumPy rules for array-array operations.

    Rules (NumPy-inspired, with some deliberate simplifications):
    - float + float -> wider float. Same half type stays (Half+Half -> Half,
      BFloat+BFloat -> BFloat); mixed Half/BFloat -> float32.
    - int + int -> an int that holds both. This is simplified:
      - any 64-bit operand (signed or unsigned) -> Int64;
      - Int32 + UInt32 -> Int64, because Int32 cannot hold every UInt32 value;
      - otherwise Int32 if either side is Int32;
      - else Int64.
    - float + int -> a float that can represent both:
      - Double stays Double;
      - a 32/64-bit int (signed or unsigned) forces Double, because neither
        float32 nor the half types represent it exactly;
      - smaller ints give float32.

    Args:
        left_element_type: Element type of the left operand (Scalar).
        right_element_type: Element type of the right operand (Scalar).

    Returns:
        Result element type (Scalar).
    """
    left_pt = left_element_type.primitive_type
    right_pt = right_element_type.primitive_type
    operand_pts = {left_pt, right_pt}

    # Floating-point types, including the half-precision variants.
    float_types = {
        PrimitiveType.Double,
        PrimitiveType.Float,
        PrimitiveType.Half,
        PrimitiveType.BFloat,
    }
    # Integer types that float32 (and therefore half types) cannot represent
    # exactly.
    wide_int_types = {
        PrimitiveType.Int32,
        PrimitiveType.Int64,
        PrimitiveType.UInt32,
        PrimitiveType.UInt64,
    }

    left_is_float = left_pt in float_types
    right_is_float = right_pt in float_types

    # Both floats: return the wider one.
    if left_is_float and right_is_float:
        if PrimitiveType.Double in operand_pts:
            return Scalar(PrimitiveType.Double)
        if PrimitiveType.Float in operand_pts:
            return Scalar(PrimitiveType.Float)
        # Half-precision types: same type stays, mixed promotes to Float.
        if left_pt == right_pt:
            return Scalar(left_pt)
        return Scalar(PrimitiveType.Float)

    # Both non-float (ints, and anything else non-float): simplified widening.
    if not left_is_float and not right_is_float:
        if (
            PrimitiveType.Int64 in operand_pts
            or PrimitiveType.UInt64 in operand_pts
        ):
            # UInt64 is promoted to signed for safety.
            return Scalar(PrimitiveType.Int64)
        if PrimitiveType.UInt32 in operand_pts:
            # Int32 cannot hold the full UInt32 range (NumPy gives int64).
            return Scalar(PrimitiveType.Int64)
        if PrimitiveType.Int32 in operand_pts:
            return Scalar(PrimitiveType.Int32)
        return Scalar(PrimitiveType.Int64)  # Default

    # Mixed float + int: need a float that can represent the int.
    float_type = left_pt if left_is_float else right_pt
    int_type = right_pt if left_is_float else left_pt

    # If float is already double, use double.
    if float_type == PrimitiveType.Double:
        return Scalar(PrimitiveType.Double)

    # float32 / half + (int32 or larger) -> float64.
    if int_type in wide_int_types:
        return Scalar(PrimitiveType.Double)

    # float32 / half + (int16 or smaller) -> float32.
    return Scalar(PrimitiveType.Float)
210

211

212
def numpy_promote_types(left_type, left_is_array, right_type, right_is_array):
    """
    Pick the result element type of a binary op using NumPy-style promotion.

    Scalars adapt to arrays, never the other way round:
    - exactly one operand is an array: that array's element type is
      returned, and the scalar operand is cast to it;
    - both are arrays, or both are scalars: the two element types are
      combined with ``promote_element_types``.

    Args:
        left_type: Element type of left operand (Scalar)
        left_is_array: Truthy if left operand is an array
        right_type: Element type of right operand (Scalar)
        right_is_array: Truthy if right operand is an array

    Returns:
        Result element type (Scalar)
    """
    lhs_array = bool(left_is_array)
    rhs_array = bool(right_is_array)

    if lhs_array != rhs_array:
        # Mixed array/scalar: the array side dictates the dtype.
        return left_type if lhs_array else right_type

    # Same kind on both sides: fall back to standard promotion.
    return promote_element_types(left_type, right_type)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc