• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 17628132898

10 Sep 2025 10:11PM UTC coverage: 85.13% (-0.8%) from 85.908%
17628132898

Pull #2144

github

web-flow
Merge 16361be1b into 52ab45228
Pull Request #2144: Technical debt clean-up in test suite

3190 of 3880 branches covered (82.22%)

Branch coverage included in aggregate %.

12142 of 14130 relevant lines covered (85.93%)

3671.33 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.17
/dpctl/tensor/_type_utils.py
1
#                      Data Parallel Control (dpctl)
2
#
3
# Copyright 2020-2025 Intel Corporation
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#    http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
from __future__ import annotations
1✔
17

18
import numpy as np
1✔
19

20
import dpctl.tensor as dpt
1✔
21
import dpctl.tensor._tensor_impl as ti
1✔
22

23

24
def _all_data_types(_fp16, _fp64):
    """
    Return the data types available on a device with the given support.

    Args:
        _fp16: whether the device supports half precision (`float16`).
        _fp64: whether the device supports double precision (`float64`
            and `complex128`).

    Returns:
        list: a new list ordered as bool, the integral types by increasing
        size (signed before unsigned), the real floating-point types by
        increasing size, and finally the complex types by increasing size.
    """
    real_fp_types = [dpt.float32]
    complex_fp_types = [dpt.complex64]
    if _fp64:
        real_fp_types.append(dpt.float64)
        complex_fp_types.append(dpt.complex128)
    if _fp16:
        # half precision is the narrowest real type, so it goes first
        real_fp_types.insert(0, dpt.float16)

    exact_types = [
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
    ]
    return exact_types + real_fp_types + complex_fp_types
64

65

66
def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
    """
    Return True if data type `dt` is the widest inexact (real or complex
    floating-point) data type available given the `_fp64` support flag.

    `_fp16` is accepted for signature symmetry but does not affect the
    result.
    """
    widest = (
        (dpt.float64, dpt.complex128) if _fp64 else (dpt.float32, dpt.complex64)
    )
    return dt in widest
74

75

76
def _dtype_supported_by_device_impl(
1✔
77
    dt: dpt.dtype, has_fp16: bool, has_fp64: bool
78
) -> bool:
79
    if has_fp64:
1✔
80
        if not has_fp16:
1✔
81
            if dt is dpt.float16:
1✔
82
                return False
1✔
83
    else:
84
        if dt is dpt.float64:
1✔
85
            return False
1✔
86
        elif dt is dpt.complex128:
1✔
87
            return False
1✔
88
        if not has_fp16 and dt is dpt.float16:
1✔
89
            return False
1✔
90
    return True
1✔
91

92

93
def _can_cast(
    from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool, casting="safe"
) -> bool:
    """
    Can `from_` be cast to `to_` under the `casting` rule on a device with
    fp16 and fp64 aspects as given?

    The target must be supported by the device, and NumPy's casting rules
    are used as the baseline. On devices lacking fp16 or fp64 support,
    boolean and integral types may additionally be cast to the maximal
    inexact type the device supports (e.g. `int64` to `float32` without
    fp64). This relaxation is not applied for `casting="no"` or
    `casting="equiv"`, since those modes never permit a change of type.

    Args:
        from_: source data type.
        to_: target data type.
        _fp16: whether the device supports half precision.
        _fp64: whether the device supports double precision.
        casting: NumPy casting mode, `"safe"` by default.

    Returns:
        bool: True if the cast is permitted.
    """
    if not _dtype_supported_by_device_impl(to_, _fp16, _fp64):
        return False
    can_cast_v = np.can_cast(from_, to_, casting=casting)  # ask NumPy
    if can_cast_v or (_fp16 and _fp64):
        # NumPy's answer is final when it permits the cast, or when the
        # device supports every data type
        return can_cast_v
    # Without fp64 (or fp16) the widest available inexact type may be
    # narrower than NumPy's safe target for an integral source, so allow
    # the cast to the maximal supported inexact type -- except for modes
    # that forbid any change of data type.
    return (
        casting not in ("no", "equiv")
        and from_.kind in "biu"
        and to_.kind in "fc"
        and _is_maximal_inexact_type(to_, _fp16, _fp64)
    )
114

115

116
def _to_device_supported_dtype_impl(dt, has_fp16, has_fp64):
    """
    Map data type `dt` to a data type usable with the given fp16/fp64
    support.

    Without fp64, `float64` maps to `float32` and `complex128` maps to
    `complex64`. Without fp16, `float16` maps to `float32`. Any other
    input is returned as is.
    """
    if not has_fp64:
        if dt is dpt.float64:
            return dpt.float32
        if dt is dpt.complex128:
            return dpt.complex64
    if not has_fp16 and dt is dpt.float16:
        return dpt.float32
    return dt
129

130

131
def _to_device_supported_dtype(dt, dev):
    """
    Return `dt`, or a substitute data type supported by device `dev`.

    Reads the `has_aspect_fp16` and `has_aspect_fp64` properties of `dev`
    and delegates to `_to_device_supported_dtype_impl`.
    """
    return _to_device_supported_dtype_impl(
        dt, dev.has_aspect_fp16, dev.has_aspect_fp64
    )
136

137

138
def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
    """
    Default acceptance predicate for unary buffer-dtype search.

    Accepts every candidate; all arguments are ignored.
    """
    accepted = True
    return accepted
140

141

142
def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
    """
    Acceptance predicate for the `reciprocal` buffer-dtype search.

    A candidate buffer of the same kind as the input is always accepted.
    When the buffer changes the kind, the candidate is accepted only if
    `res_dt` is the device's default data type for the result kind. This
    keeps the output types of `reciprocal` and `divide` aligned.
    """
    if buf_dt.kind == arg_dtype.kind:
        return True
    return bool(res_dt == _get_device_default_dtype(res_dt.kind, sycl_dev))
154

155

156
def _acceptance_fn_negative(arg_dtype, buf_dt, res_dt, sycl_dev):
    """
    Acceptance predicate for `negative`.

    Returns True for any non-boolean input data type.

    Raises:
        ValueError: if `arg_dtype` is boolean, since negation is not
            defined for booleans.
    """
    if arg_dtype.char != "?":
        return True
    raise ValueError(
        "The `negative` function, the `-` operator, is not supported "
        "for inputs of data type bool, use the `~` operator or the "
        "`logical_not` function instead"
    )
166

167

168
def _acceptance_fn_subtract(
    arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev
):
    """
    Acceptance predicate for `subtract`.

    Returns True unless both operands are boolean.

    Raises:
        ValueError: if both `arg1_dtype` and `arg2_dtype` are boolean,
            since subtraction is not defined for booleans.
    """
    both_bool = arg1_dtype.char == "?" and arg2_dtype.char == "?"
    if not both_bool:
        return True
    raise ValueError(
        "The `subtract` function, the `-` operator, is not supported "
        "for inputs of data type bool, use the `^` operator,  the "
        "`bitwise_xor`, or the `logical_xor` function instead"
    )
180

181

182
def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
    """
    Find a buffer data type for the operand of a unary function.

    Args:
        arg_dtype: data type of the operand.
        query_fn: callable `query_fn(dt)` giving the result data type for
            input type `dt`; a falsy return is treated as "no match".
        sycl_dev: device whose fp16/fp64 support limits the candidates.
        acceptance_fn: predicate
            `acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)` that can
            reject a candidate buffer data type.

    Returns:
        tuple: `(None, res_dt)` if `arg_dtype` is handled directly,
        `(buf_dt, res_dt)` for the first accepted castable candidate, or
        `(None, None)` if none is found.
    """
    direct_res_dt = query_fn(arg_dtype)
    if direct_res_dt:
        return None, direct_res_dt

    has_fp16 = sycl_dev.has_aspect_fp16
    has_fp64 = sycl_dev.has_aspect_fp64
    for candidate in _all_data_types(has_fp16, has_fp64):
        if not _can_cast(arg_dtype, candidate, has_fp16, has_fp64):
            continue
        candidate_res_dt = query_fn(candidate)
        if candidate_res_dt and acceptance_fn(
            arg_dtype, candidate, candidate_res_dt, sycl_dev
        ):
            return candidate, candidate_res_dt

    return None, None
201

202

203
def _get_device_default_dtype(dt_kind, sycl_dev):
    """
    Return the device's default data type of a given kind.

    Args:
        dt_kind: NumPy kind character: "b" (boolean), "i" (signed
            integer), "u" (unsigned integer), "f" (real floating-point)
            or "c" (complex floating-point).
        sycl_dev: device whose default data type is queried.

    Returns:
        dpt.dtype: the default data type of kind `dt_kind` for `sycl_dev`.

    Raises:
        RuntimeError: if `dt_kind` is not one of the kinds listed above.
    """
    if dt_kind == "b":
        query = ti.default_device_bool_type
    elif dt_kind == "i":
        query = ti.default_device_int_type
    elif dt_kind == "u":
        query = ti.default_device_uint_type
    elif dt_kind == "f":
        query = ti.default_device_fp_type
    elif dt_kind == "c":
        query = ti.default_device_complex_type
    else:
        raise RuntimeError(
            f"Unsupported data type kind {dt_kind!r}; "
            "expected one of 'b', 'i', 'u', 'f', or 'c'"
        )
    return dpt.dtype(query(sycl_dev))
215

216

217
def _acceptance_fn_default_binary(
    arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
):
    """
    Default acceptance predicate for binary buffer-dtype search.

    Accepts every candidate pair; all arguments are ignored.
    """
    accepted = True
    return accepted
221

222

223
def _acceptance_fn_divide(
    arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
):
    """
    Acceptance predicate for the `divide` buffer-dtype search.

    If at least one buffer keeps the kind of its operand, the candidate is
    accepted. When both operands are promoted to a different kind, the
    candidate is accepted only if `res_dt` is the device default data type
    for the result kind. This covers e.g. `int8 / uint8`, which NumPy
    evaluates in double precision while regular promotion rules would stop
    at `float16`.
    """
    if (
        ret_buf1_dt.kind == arg1_dtype.kind
        or ret_buf2_dt.kind == arg2_dtype.kind
    ):
        return True
    return bool(res_dt == _get_device_default_dtype(res_dt.kind, sycl_dev))
242

243

244
def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
    """
    Find buffer data types for the operands of a binary function.

    Args:
        arg1_dtype, arg2_dtype: data types of the two operands.
        query_fn: callable `query_fn(dt1, dt2)` giving the result data type
            for the input type pair; a falsy return is treated as
            "no match".
        sycl_dev: device whose fp16/fp64 support limits the candidates.
        acceptance_fn: predicate with signature
            `(arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev)`
            consulted only for candidates in which both operands are cast.

    Returns:
        tuple: `(buf1_dt, buf2_dt, res_dt)`, where a buffer data type is
        `None` if that operand needs no cast, or `(None, None, None)` if no
        suitable combination is found.
    """
    res_dt = query_fn(arg1_dtype, arg2_dtype)
    if res_dt:
        return None, None, res_dt

    _fp16 = sycl_dev.has_aspect_fp16
    _fp64 = sycl_dev.has_aspect_fp64
    all_dts = _all_data_types(_fp16, _fp64)
    # `_can_cast` depends only on the (operand, candidate) pair, so compute
    # each operand's castable candidates once rather than re-checking them
    # for every pair in the nested loop. Candidate order is preserved.
    buf1_candidates = [
        dt for dt in all_dts if _can_cast(arg1_dtype, dt, _fp16, _fp64)
    ]
    buf2_candidates = [
        dt for dt in all_dts if _can_cast(arg2_dtype, dt, _fp16, _fp64)
    ]
    for buf1_dt in buf1_candidates:
        for buf2_dt in buf2_candidates:
            res_dt = query_fn(buf1_dt, buf2_dt)
            if not res_dt:
                continue
            ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt
            ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt
            # when at most one operand is cast, accept without consulting
            # the acceptance predicate
            if ret_buf1_dt is None or ret_buf2_dt is None:
                return ret_buf1_dt, ret_buf2_dt, res_dt
            if acceptance_fn(
                arg1_dtype,
                arg2_dtype,
                ret_buf1_dt,
                ret_buf2_dt,
                res_dt,
                sycl_dev,
            ):
                return ret_buf1_dt, ret_buf2_dt, res_dt

    return None, None, None
278

279

280
def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
    """
    Find the buffer data type for the second operand of an in-place
    binary operation.

    Because the first operand is written in place, the only candidate
    buffer type is `arg1_dtype`, provided `arg2_dtype` can be cast to it
    under the "same_kind" rule.

    Returns:
        tuple: `(None, res_dt)` if the pair is handled directly,
        `(arg1_dtype, res_dt)` if `arg2` must be cast to `arg1_dtype`,
        or `(None, None)` otherwise.
    """
    direct_res_dt = query_fn(arg1_dtype, arg2_dtype)
    if direct_res_dt:
        return None, direct_res_dt

    castable = _can_cast(
        arg2_dtype,
        arg1_dtype,
        sycl_dev.has_aspect_fp16,
        sycl_dev.has_aspect_fp64,
        casting="same_kind",
    )
    if not castable:
        return None, None
    same_type_res_dt = query_fn(arg1_dtype, arg1_dtype)
    if same_type_res_dt:
        return arg1_dtype, same_type_res_dt
    return None, None
293

294

295
class WeakBooleanType:
    """
    Weak data type wrapper for a Python boolean scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, kept as given
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
303

304

305
class WeakIntegralType:
    """
    Weak data type wrapper for a Python integral scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, kept as given
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
313

314

315
class WeakFloatingType:
    """
    Weak data type wrapper for a Python floating-point scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, kept as given
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
323

324

325
class WeakComplexType:
    """
    Weak data type wrapper for a Python complex floating-point scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, kept as given
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
333

334

335
def _weak_type_num_kind(o):
    """
    Return the numeric kind rank of weak type `o`.

    Ranks are bool (0) < integral (1) < floating (2) < complex (3), the
    same scale used by `_strong_dtype_num_kind`, so the two are directly
    comparable.

    Raises:
        TypeError: if `o` is not one of the weak type wrappers.
    """
    for weak_cls, kind_num in (
        (WeakBooleanType, 0),
        (WeakIntegralType, 1),
        (WeakFloatingType, 2),
        (WeakComplexType, 3),
    ):
        if isinstance(o, weak_cls):
            return kind_num
    raise TypeError(
        f"Unexpected type {o} while expecting "
        "`WeakBooleanType`, `WeakIntegralType`, "
        "`WeakFloatingType`, or `WeakComplexType`."
    )
350

351

352
def _strong_dtype_num_kind(o):
    """
    Return the numeric kind rank of data type `o`.

    Ranks are bool (0) < integral, signed or unsigned (1) < real
    floating-point (2) < complex floating-point (3).

    Raises:
        TypeError: if `o` is not a `dpt.dtype` instance.
        ValueError: if the kind of `o` is none of the kinds above.
    """
    _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
    if not isinstance(o, dpt.dtype):
        raise TypeError(f"Expected a `dpt.dtype` instance, got {type(o)}")
    k = o.kind
    if k in _map:
        return _map[k]
    raise ValueError(f"Unrecognized kind {k} for dtype {o}")
360

361

362
def _is_weak_dtype(dtype):
    """Return True if `dtype` is one of the weak Python-scalar type wrappers."""
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    return isinstance(dtype, weak_types)
367

368

369
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
    """
    Resolves weak data type per NEP-0050.

    If exactly one operand is weak, it is replaced by a strong data type:
    the other operand's type when its kind rank is at least as high,
    otherwise the device default integer type (weak integers), `complex64`
    (weak complex with `float16`/`float32`), a device-supported
    `complex128` (other weak complex), or a device-supported `float64`
    (weak floats). Two strong operands are returned unchanged.

    Raises:
        ValueError: if both operands are weak.
    """

    def _as_strong(weak_dt, strong_dt):
        # strong data type that `weak_dt` resolves to next to `strong_dt`
        if _weak_type_num_kind(weak_dt) <= _strong_dtype_num_kind(strong_dt):
            return strong_dt
        if isinstance(weak_dt, WeakIntegralType):
            return dpt.dtype(ti.default_device_int_type(dev))
        if isinstance(weak_dt, WeakComplexType):
            if strong_dt is dpt.float16 or strong_dt is dpt.float32:
                return dpt.complex64
            return _to_device_supported_dtype(dpt.complex128, dev)
        return _to_device_supported_dtype(dpt.float64, dev)

    if _is_weak_dtype(o1_dtype):
        if _is_weak_dtype(o2_dtype):
            raise ValueError
        return _as_strong(o1_dtype, o2_dtype), o2_dtype
    if _is_weak_dtype(o2_dtype):
        return o1_dtype, _as_strong(o2_dtype, o1_dtype)
    return o1_dtype, o2_dtype
407

408

409
def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev):
    """
    Resolves weak data type per NEP-0050 for comparisons and divide, where
    result type is known and special behavior is needed to handle mixed
    integer kinds and Python integers without overflow.

    Resolution follows the same rules as `_resolve_weak_types`, with one
    addition: a weak integer paired with an integral data type whose range
    cannot hold its value resolves to `np.min_scalar_type` of that value.

    Raises:
        ValueError: if both operands are weak.
    """

    def _as_strong(weak_dt, strong_dt):
        # strong data type that `weak_dt` resolves to next to `strong_dt`
        weak_kind_num = _weak_type_num_kind(weak_dt)
        strong_kind_num = _strong_dtype_num_kind(strong_dt)
        if weak_kind_num > strong_kind_num:
            if isinstance(weak_dt, WeakIntegralType):
                return dpt.dtype(ti.default_device_int_type(dev))
            if isinstance(weak_dt, WeakComplexType):
                if strong_dt is dpt.float16 or strong_dt is dpt.float32:
                    return dpt.complex64
                return _to_device_supported_dtype(dpt.complex128, dev)
            return _to_device_supported_dtype(dpt.float64, dev)
        if weak_kind_num == strong_kind_num and isinstance(
            weak_dt, WeakIntegralType
        ):
            # a Python int outside the range of the strong integral type
            # gets the smallest NumPy type able to represent it
            val = weak_dt.get()
            strong_iinfo = dpt.iinfo(strong_dt)
            if (val < strong_iinfo.min) or (val > strong_iinfo.max):
                return dpt.dtype(np.min_scalar_type(val))
        return strong_dt

    if _is_weak_dtype(o1_dtype):
        if _is_weak_dtype(o2_dtype):
            raise ValueError
        return _as_strong(o1_dtype, o2_dtype), o2_dtype
    if _is_weak_dtype(o2_dtype):
        return o1_dtype, _as_strong(o2_dtype, o1_dtype)
    return o1_dtype, o2_dtype
464

465

466
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
    """
    Resolves weak data types per NEP-0050, where the second and third
    arguments are permitted to be weak types.

    Args:
        st_dtype: strong data type; must not be weak.
        dtype1, dtype2: data types, either or both of which may be weak.
        dev: device used for default and device-supported data types.

    Returns:
        tuple: strong data types to use in place of `dtype1` and `dtype2`.
        A weak type is resolved against `st_dtype` (when both are weak) or
        against the higher-kind of `st_dtype` and the other strong type.
        Strong inputs among `dtype1`/`dtype2` are returned unchanged.

    Raises:
        ValueError: if `st_dtype` is a weak type.
    """
    if _is_weak_dtype(st_dtype):
        raise ValueError

    def _promote_weak(weak_dt, ref_dtype, ref_kind_num):
        # strong data type that `weak_dt` resolves to next to strong
        # `ref_dtype`, whose kind rank is `ref_kind_num`
        if _weak_type_num_kind(weak_dt) <= ref_kind_num:
            return ref_dtype
        if isinstance(weak_dt, WeakIntegralType):
            return dpt.dtype(ti.default_device_int_type(dev))
        if isinstance(weak_dt, WeakComplexType):
            # half/single precision reals pair with complex64; otherwise
            # use complex128, downgraded if the device lacks fp64
            if ref_dtype is dpt.float16 or ref_dtype is dpt.float32:
                return dpt.complex64
            return _to_device_supported_dtype(dpt.complex128, dev)
        return _to_device_supported_dtype(dpt.float64, dev)

    if _is_weak_dtype(dtype1):
        if _is_weak_dtype(dtype2):
            st_kind_num = _strong_dtype_num_kind(st_dtype)
            return (
                _promote_weak(dtype1, st_dtype, st_kind_num),
                _promote_weak(dtype2, st_dtype, st_kind_num),
            )
        max_dt_num_kind, max_dtype = max(
            [
                (_strong_dtype_num_kind(st_dtype), st_dtype),
                (_strong_dtype_num_kind(dtype2), dtype2),
            ]
        )
        return _promote_weak(dtype1, max_dtype, max_dt_num_kind), dtype2
    if _is_weak_dtype(dtype2):
        max_dt_num_kind, max_dtype = max(
            [
                (_strong_dtype_num_kind(st_dtype), st_dtype),
                (_strong_dtype_num_kind(dtype1), dtype1),
            ]
        )
        return dtype1, _promote_weak(dtype2, max_dtype, max_dt_num_kind)
    # both are strong dtypes
    # return unmodified
    return dtype1, dtype2
549

550

551
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
    """
    Resolves one weak data type with one strong data type per NEP-0050.

    Returns a strong data type to use in place of `dtype`. If `dtype` is
    already strong it is returned unchanged. A weak `dtype` resolves to
    `st_dtype` when the kind rank of `st_dtype` is at least as high.
    Otherwise it resolves to the device default integer type (weak
    integers), `complex64` (weak complex with `float16`/`float32`), a
    device-supported `complex128` (other weak complex), or a
    device-supported `float64` (weak floats).

    Raises:
        ValueError: if `st_dtype` is a weak type.
    """
    if _is_weak_dtype(st_dtype):
        raise ValueError
    if not _is_weak_dtype(dtype):
        return dtype

    st_kind_num = _strong_dtype_num_kind(st_dtype)
    if _weak_type_num_kind(dtype) <= st_kind_num:
        return st_dtype
    if isinstance(dtype, WeakIntegralType):
        return dpt.dtype(ti.default_device_int_type(dev))
    if isinstance(dtype, WeakComplexType):
        narrow_real = st_dtype is dpt.float16 or st_dtype is dpt.float32
        if narrow_real:
            return dpt.complex64
        return _to_device_supported_dtype(dpt.complex128, dev)
    return _to_device_supported_dtype(dpt.float64, dev)
570

571

572
class finfo_object:
    """
    Wrapper around `numpy.finfo` that returns Python scalars.

    `bits` is returned as an `int`. `eps`, `epsneg`, `max`, `min`,
    `precision`, `resolution`, `smallest_normal` and `tiny` are returned
    as `float`. `dtype`, `str()` and `repr()` are forwarded to the wrapped
    `numpy.finfo` object.
    """

    def __init__(self, dtype):
        # rejects data types dpctl does not support (raises ValueError)
        _supported_dtype([dpt.dtype(dtype)])
        self._finfo = np.finfo(dtype)

    @property
    def bits(self):
        """
        number of bits occupied by the real-valued floating-point data type.
        """
        return int(self._finfo.bits)

    @property
    def smallest_normal(self):
        """
        smallest positive real-valued floating-point number with full
        precision.
        """
        return float(self._finfo.smallest_normal)

    @property
    def tiny(self):
        """an alias for `smallest_normal`"""
        return float(self._finfo.tiny)

    @property
    def eps(self):
        """
        difference between 1.0 and the smallest representable real-valued
        floating-point number larger than 1.0 according to the IEEE-754
        standard.
        """
        return float(self._finfo.eps)

    @property
    def epsneg(self):
        """
        difference between 1.0 and the largest representable real-valued
        floating-point number smaller than 1.0 according to the IEEE-754
        standard.
        """
        return float(self._finfo.epsneg)

    @property
    def min(self):
        """smallest representable real-valued number."""
        return float(self._finfo.min)

    @property
    def max(self):
        """largest representable real-valued number."""
        return float(self._finfo.max)

    @property
    def resolution(self):
        """the approximate decimal resolution of this type."""
        return float(self._finfo.resolution)

    @property
    def precision(self):
        """
        the approximate number of decimal digits to which this kind of
        floating point type is precise, returned as a `float`.
        """
        return float(self._finfo.precision)

    @property
    def dtype(self):
        """
        the dtype for which finfo returns information. For complex input, the
        returned dtype is the associated floating point dtype of its real and
        imaginary components.
        """
        return self._finfo.dtype

    def __str__(self):
        return self._finfo.__str__()

    def __repr__(self):
        return self._finfo.__repr__()
657

658

659
def can_cast(from_, to, /, *, casting="safe") -> bool:
    """ can_cast(from, to, casting="safe")

    Determines if one data type can be cast to another data type according \
    to Type Promotion Rules.

    Args:
       from_ (Union[usm_ndarray, dtype]):
           source data type. If `from_` is an array, device-specific type
           promotion rules apply.
       to (dtype):
           target data type
       casting (Optional[str]):
            controls what kind of data casting may occur.

                * "no" means data types should not be cast at all.
                * "safe" means only casts that preserve values are allowed.
                * "same_kind" means only safe casts and casts within a kind,
                  like `float64` to `float32`, are allowed.
                * "unsafe" means any data conversion can be done.

            Default: `"safe"`.

    Returns:
        bool:
            Gives `True` if cast can occur according to the casting rule.

    Raises:
        TypeError:
            if `to` is a `usm_ndarray` instead of a data type.
        ValueError:
            if `to`, or `from_` when it is not an array, is a data type
            not supported by dpctl.

    Device-specific type promotion rules take into account which data types
    are and are not supported by a specific device.
    """
    if isinstance(to, dpt.usm_ndarray):
        raise TypeError(f"Expected `dpt.dtype` type, got {type(to)}.")

    dtype_to = dpt.dtype(to)
    _supported_dtype([dtype_to])

    if isinstance(from_, dpt.usm_ndarray):
        dev = from_.sycl_device
        return _can_cast(
            from_.dtype,
            dtype_to,
            dev.has_aspect_fp16,
            dev.has_aspect_fp64,
            casting=casting,
        )

    dtype_from = dpt.dtype(from_)
    _supported_dtype([dtype_from])
    # query casting as if all dtypes are supported
    return _can_cast(dtype_from, dtype_to, True, True, casting=casting)
709

710

711
def result_type(*arrays_and_dtypes):
    """
    result_type(*arrays_and_dtypes)

    Returns the dtype that results from applying the Type Promotion Rules to \
        the arguments.

    Args:
        arrays_and_dtypes (Union[usm_ndarray, dtype]):
            An arbitrary length sequence of usm_ndarray objects or dtypes.
            Python `bool`, `int`, `float` and `complex` scalars are also
            accepted and participate as weakly-typed operands.

    Returns:
        dtype:
            The dtype resulting from an operation involving the
            input arrays and dtypes.

    Raises:
        ValueError:
            if input arrays reside on devices with different fp16/fp64
            support, if a dtype is not supported by dpctl, or if a dtype
            is not supported by the device of the input arrays.
    """
    dtypes = []
    devices = []
    weak_dtypes = []
    for arg_i in arrays_and_dtypes:
        if isinstance(arg_i, dpt.usm_ndarray):
            devices.append(arg_i.sycl_device)
            dtypes.append(arg_i.dtype)
        # `bool` is a subclass of `int`, so it must be tested first;
        # otherwise Python booleans would be classified as weak integers
        elif isinstance(arg_i, bool):
            weak_dtypes.append(WeakBooleanType(arg_i))
        elif isinstance(arg_i, int):
            weak_dtypes.append(WeakIntegralType(arg_i))
        elif isinstance(arg_i, float):
            weak_dtypes.append(WeakFloatingType(arg_i))
        elif isinstance(arg_i, complex):
            weak_dtypes.append(WeakComplexType(arg_i))
        else:
            dt = dpt.dtype(arg_i)
            _supported_dtype([dt])
            dtypes.append(dt)

    has_fp16 = True
    has_fp64 = True
    target_dev = None
    if devices:
        # all devices must agree on fp16/fp64 support; the first one
        # determines the device-specific promotion rules
        target_dev = devices[0]
        has_fp16 = target_dev.has_aspect_fp16
        has_fp64 = target_dev.has_aspect_fp64
        for d in devices[1:]:
            if d.has_aspect_fp16 != has_fp16 or d.has_aspect_fp64 != has_fp64:
                raise ValueError(
                    "Input arrays reside on devices "
                    "with different device supports; "
                    "unable to determine which "
                    "device-specific type promotion rules "
                    "to use."
                )

    if not dtypes and weak_dtypes:
        dtypes.append(weak_dtypes[0].get())

    if not (has_fp16 and has_fp64):
        for dt in dtypes:
            if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):
                raise ValueError(
                    f"Argument {dt} is not supported by the device"
                )
        res_dt = np.result_type(*dtypes)
        res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
        for wdt in weak_dtypes:
            pair = _resolve_weak_types(wdt, res_dt, target_dev)
            res_dt = np.result_type(*pair)
            res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
    else:
        res_dt = np.result_type(*dtypes)
        if weak_dtypes:
            weak_dt_obj = [wdt.get() for wdt in weak_dtypes]
            res_dt = np.result_type(res_dt, *weak_dt_obj)

    return res_dt
792

793

794
def iinfo(dtype, /):
    """iinfo(dtype)

    Returns machine limits for integer data types.

    Args:
        dtype (dtype, usm_ndarray):
            integer dtype or
            an array with integer dtype.

    Returns:
        iinfo_object:
            An object with the following attributes:

            * bits: int
                number of bits occupied by the data type
            * max: int
                largest representable number.
            * min: int
                smallest representable number.
            * dtype: dtype
                integer data type.

    Raises:
        ValueError:
            if the data type is not supported by dpctl.
    """
    target = dtype.dtype if isinstance(dtype, dpt.usm_ndarray) else dtype
    _supported_dtype([dpt.dtype(target)])
    return np.iinfo(target)
821

822

823
def finfo(dtype, /):
    """finfo(dtype)

    Returns machine limits for floating-point data types.

    Args:
        dtype (dtype, usm_ndarray):
            floating-point dtype, or an array with a floating-point
            data type. If complex, the information is about its
            component data type.

    Returns:
        finfo_object:
            An object having the following attributes:

            * bits: int
                number of bits occupied by dtype.
            * eps: float
                difference between 1.0 and the next smallest representable
                real-valued floating-point number larger than 1.0 according
                to the IEEE-754 standard.
            * max: float
                largest representable real-valued number.
            * min: float
                smallest representable real-valued number.
            * smallest_normal: float
                smallest positive real-valued floating-point number with
                full precision.
            * dtype: dtype
                real-valued floating-point data type.

    Raises:
        ValueError:
            if the data type is not supported by dpctl.
    """
    # An array is accepted as a convenience: only its element type is used.
    if isinstance(dtype, dpt.usm_ndarray):
        dtype = dtype.dtype
    # Reject data types dpctl cannot represent before building the result.
    _supported_dtype([dpt.dtype(dtype)])
    return finfo_object(dtype)
1✔
859

860

861
def _supported_dtype(dtypes):
    """Validate that every data type in `dtypes` is supported by dpctl.

    Supported data types are those whose single-character code
    (``dtype.char``) is one of ``?bBhHiIlLqQefdFD``: boolean, signed and
    unsigned integers, half/single/double precision reals, and
    single/double precision complex numbers.

    Args:
        dtypes (iterable of dtype): data types to check.

    Returns:
        bool: always ``True`` when no unsupported data type is found.

    Raises:
        ValueError: on the first unsupported data type encountered.
    """
    allowed_codes = "?bBhHiIlLqQefdFD"
    for dt in dtypes:
        if dt.char in allowed_codes:
            continue
        raise ValueError(f"Dpctl doesn't support dtype {dt}.")
    return True
1✔
866

867

868
def isdtype(dtype, kind):
    """isdtype(dtype, kind)

    Returns a boolean indicating whether a provided `dtype` is
    of a specified data type `kind`.

    Args:
        dtype (dtype):
            data type to inspect.
        kind (dtype, str, tuple):
            data type kind to test against. One of:

            * a dtype, in which case the result is ``dtype == kind``;
            * one of the strings ``"bool"``, ``"signed integer"``,
              ``"unsigned integer"``, ``"integral"``,
              ``"real floating"``, ``"complex floating"`` or
              ``"numeric"``;
            * a tuple of the above, in which case the result is ``True``
              if `dtype` matches any element (``False`` for an empty
              tuple).

    Returns:
        bool:
            ``True`` if `dtype` is of the specified `kind`,
            ``False`` otherwise.

    Raises:
        TypeError:
            if `dtype` is not a dtype instance, or `kind` is not a
            dtype, string, or tuple.
        ValueError:
            if `kind` is a string naming an unrecognized kind.

    See the `array API specification
    <https://data-apis.org/array-api/latest/>`_ for more information.
    """

    if not isinstance(dtype, np.dtype):
        raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype}")

    if isinstance(kind, np.dtype):
        return dtype == kind

    elif isinstance(kind, str):
        # Named kinds map onto NumPy's single-character `dtype.kind` codes.
        if kind == "bool":
            return dtype == np.dtype("bool")
        elif kind == "signed integer":
            return dtype.kind == "i"
        elif kind == "unsigned integer":
            return dtype.kind == "u"
        elif kind == "integral":
            return dtype.kind in "iu"
        elif kind == "real floating":
            return dtype.kind == "f"
        elif kind == "complex floating":
            return dtype.kind == "c"
        elif kind == "numeric":
            return dtype.kind in "iufc"
        else:
            raise ValueError(f"Unrecognized data type kind: {kind}")

    elif isinstance(kind, tuple):
        # A tuple matches when any of its members matches.
        return any(isdtype(dtype, k) for k in kind)

    else:
        raise TypeError(f"Unsupported data type kind: {kind}")
1✔
908

909

910
def _default_accumulation_dtype(inp_dt, q):
    """Gives default output data type for given input data
    type `inp_dt` when accumulation is performed on queue `q`.

    Args:
        inp_dt (dtype):
            data type of the input being accumulated.
        q:
            queue passed to ``ti.default_device_int_type`` /
            ``ti.default_device_uint_type`` to obtain the device's
            default integer types.

    Returns:
        dtype:
            * boolean and signed integer inputs: the device default
              integer type, or `inp_dt` if its itemsize is larger;
            * unsigned integer inputs: the device default unsigned
              integer type, or `inp_dt` if its value range does not fit
              in the default type;
            * real and complex floating-point inputs: `inp_dt`.

    Raises:
        ValueError:
            if `inp_dt` is not a boolean, integer, real or complex
            floating-point data type.
    """
    inp_kind = inp_dt.kind
    if inp_kind in ("b", "i"):
        res_dt = dpt.dtype(ti.default_device_int_type(q))
        # Never narrow: keep the input type when it is wider than the default.
        if inp_dt.itemsize > res_dt.itemsize:
            res_dt = inp_dt
    elif inp_kind == "u":
        res_dt = dpt.dtype(ti.default_device_uint_type(q))
        res_ii = dpt.iinfo(res_dt)
        inp_ii = dpt.iinfo(inp_dt)
        # Fall back to the input type when the default cannot hold its range.
        if not (inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max):
            res_dt = inp_dt
    elif inp_kind in ("f", "c"):
        res_dt = inp_dt
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            f"Unsupported input data type for accumulation: {inp_dt}"
        )

    return res_dt
1✔
931

932

933
def _default_accumulation_dtype_fp_types(inp_dt, q):
    """Gives default output data type for given input data
    type `inp_dt` when accumulation is performed on queue `q`
    and the accumulation supports only floating-point data types.

    Args:
        inp_dt (dtype):
            data type of the input being accumulated.
        q:
            queue passed to ``ti.default_device_fp_type``; its
            ``sycl_device.has_aspect_fp64`` is consulted as a fallback.

    Returns:
        dtype:
            * boolean and integer inputs: the device default
              floating-point type if ``dpt.can_cast(inp_dt, default)``
              holds, otherwise ``float64`` when the device has the fp64
              aspect and ``float32`` when it does not;
            * real floating-point inputs: `inp_dt`.

    Raises:
        ValueError:
            if `inp_dt` is a complex data type, or is not a boolean,
            integer, or real floating-point data type.
    """
    inp_kind = inp_dt.kind
    if inp_kind in ("b", "i", "u"):
        res_dt = dpt.dtype(ti.default_device_fp_type(q))
        if not dpt.can_cast(inp_dt, res_dt):
            # The default type cannot take the input; use the widest real
            # floating-point type available on the device instead.
            _fp64 = q.sycl_device.has_aspect_fp64
            res_dt = dpt.float64 if _fp64 else dpt.float32
    elif inp_kind == "f":
        res_dt = inp_dt
    elif inp_kind == "c":
        raise ValueError("function not defined for complex types")
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            f"Unsupported input data type for accumulation: {inp_dt}"
        )

    return res_dt
1✔
951

952

953
__all__ = [
    # Data type resolution helpers
    "_find_buf_dtype",
    "_find_buf_dtype2",
    "_to_device_supported_dtype",
    # Acceptance functions
    "_acceptance_fn_default_unary",
    "_acceptance_fn_reciprocal",
    "_acceptance_fn_default_binary",
    "_acceptance_fn_divide",
    "_acceptance_fn_negative",
    "_acceptance_fn_subtract",
    # Weak type resolution helpers
    "_resolve_one_strong_one_weak_types",
    "_resolve_one_strong_two_weak_types",
    "_resolve_weak_types",
    "_resolve_weak_types_all_py_ints",
    "_weak_type_num_kind",
    "_strong_dtype_num_kind",
    # Public data type functions
    "can_cast",
    "finfo",
    "iinfo",
    "isdtype",
    "result_type",
    # Weak types
    "WeakBooleanType",
    "WeakIntegralType",
    "WeakFloatingType",
    "WeakComplexType",
    # Accumulation and in-place operation helpers
    "_default_accumulation_dtype",
    "_default_accumulation_dtype_fp_types",
    "_find_buf_dtype_in_place_op",
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc