• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 10166246008

30 Jul 2024 04:34PM UTC coverage: 86.385% (-1.2%) from 87.609%
10166246008

Pull #1732

github

web-flow
Merge d66971031 into 7b6437491
Pull Request #1732: Implements `dpctl.tensor.count_nonzero` and `dpctl.tensor.diff`

3328 of 3933 branches covered (84.62%)

Branch coverage included in aggregate %.

68 of 229 new or added lines in 5 files covered. (29.69%)

2 existing lines in 1 file now uncovered.

11272 of 12968 relevant lines covered (86.92%)

7273.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.97
/dpctl/tensor/_type_utils.py
1
#                      Data Parallel Control (dpctl)
2
#
3
# Copyright 2020-2024 Intel Corporation
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#    http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
from __future__ import annotations
1✔
17

18
import numpy as np
1✔
19

20
import dpctl.tensor as dpt
1✔
21
import dpctl.tensor._tensor_impl as ti
1✔
22

23

24
def _all_data_types(_fp16, _fp64):
    """Return the dtypes available on a device with the given aspects.

    Args:
        _fp16: truthy if the device supports half precision (`float16`).
        _fp64: truthy if the device supports double precision
            (`float64` and `complex128`).

    Returns:
        list: a freshly built list holding `bool`, the signed/unsigned
        integer dtypes from 8 to 64 bits, then the supported real
        floating-point dtypes (narrowest first), then the supported complex
        dtypes (narrowest first).
    """
    # boolean and integral dtypes do not depend on device aspects
    dtypes = [
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
    ]
    # real floating-point dtypes
    if _fp16:
        dtypes.append(dpt.float16)
    dtypes.append(dpt.float32)
    if _fp64:
        dtypes.append(dpt.float64)
    # complex floating-point dtypes
    dtypes.append(dpt.complex64)
    if _fp64:
        dtypes.append(dpt.complex128)
    return dtypes
64

65

66
def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
    """Report whether `dt` is one of the widest inexact dtypes of a device.

    The widest real/complex pair is `float64`/`complex128` when `_fp64` is
    truthy, and `float32`/`complex64` otherwise. `_fp16` does not affect
    the result.
    """
    if _fp64:
        widest_real, widest_complex = dpt.float64, dpt.complex128
    else:
        widest_real, widest_complex = dpt.float32, dpt.complex64
    return dt in (widest_real, widest_complex)
1✔
74

75

76
def _dtype_supported_by_device_impl(
    dt: dpt.dtype, has_fp16: bool, has_fp64: bool
) -> bool:
    """Return False if `dt` needs a floating-point aspect the device lacks.

    `float16` requires `has_fp16`; `float64` and `complex128` require
    `has_fp64`. Every other dtype is reported as supported. Dtypes are
    matched by identity (`is`) against the `dpt` dtype objects.
    """
    needs_fp16 = dt is dpt.float16
    needs_fp64 = dt is dpt.float64 or dt is dpt.complex128
    if needs_fp16 and not has_fp16:
        return False
    if needs_fp64 and not has_fp64:
        return False
    return True
1✔
91

92

93
def _can_cast(
    from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool, casting="safe"
) -> bool:
    """Decide whether `from_` may be cast to `to_` on a given device.

    Args:
        from_: source dtype.
        to_: target dtype.
        _fp16: the device's half-precision aspect.
        _fp64: the device's double-precision aspect.
        casting: casting rule forwarded to `numpy.can_cast`.

    Returns:
        bool: False when the device cannot represent `to_`. Otherwise
        NumPy's verdict, except that on a device lacking fp16 or fp64 a
        bool/integer source may additionally be cast to the widest inexact
        dtype the device supports.
    """
    if not _dtype_supported_by_device_impl(to_, _fp16, _fp64):
        return False
    numpy_says = np.can_cast(from_, to_, casting=casting)
    # Fully-featured device, or NumPy already permits the cast.
    if (_fp16 and _fp64) or numpy_says:
        return numpy_says
    # On reduced-precision devices the widest available inexact type must
    # still be reachable from bool/integer inputs.
    int_to_widest_inexact = (
        from_.kind in "biu"
        and to_.kind in "fc"
        and _is_maximal_inexact_type(to_, _fp16, _fp64)
    )
    return True if int_to_widest_inexact else numpy_says
1✔
114

115

116
def _to_device_supported_dtype_impl(dt, has_fp16, has_fp64):
    """Downgrade `dt` to a dtype representable given the device aspects.

    Without fp64, `float64` maps to `float32` and `complex128` maps to
    `complex64`. Without fp16, `float16` maps to `float32`. Any other dtype
    is returned as-is. Dtypes are matched by identity (`is`).
    """
    if not has_fp64:
        if dt is dpt.float64:
            return dpt.float32
        if dt is dpt.complex128:
            return dpt.complex64
    if dt is dpt.float16 and not has_fp16:
        return dpt.float32
    return dt
1✔
129

130

131
def _to_device_supported_dtype(dt, dev):
    """Map `dt` to a dtype supported by device `dev`.

    The device's `has_aspect_fp16` and `has_aspect_fp64` attributes are
    forwarded to `_to_device_supported_dtype_impl`.
    """
    return _to_device_supported_dtype_impl(
        dt, dev.has_aspect_fp16, dev.has_aspect_fp64
    )
1✔
136

137

138
def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
1✔
139
    return True
1✔
140

141

142
def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
1✔
143
    # if the kind of result is different from
144
    # the kind of input, use the default data
145
    # we use default dtype for the resulting kind.
146
    # This guarantees alignment of reciprocal and
147
    # divide output types.
148
    if buf_dt.kind != arg_dtype.kind:
1✔
149
        default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
1✔
150
        if res_dt == default_dt:
1✔
151
            return True
1✔
152
        else:
153
            return False
1✔
154
    else:
155
        return True
1✔
156

157

158
def _acceptance_fn_negative(arg_dtype, buf_dt, res_dt, sycl_dev):
1✔
159
    # negative is not defined for boolean data type
160
    if arg_dtype.char == "?":
1✔
161
        raise ValueError(
1✔
162
            "The `negative` function, the `-` operator, is not supported "
163
            "for inputs of data type bool, use the `~` operator or the "
164
            "`logical_not` function instead"
165
        )
166
    else:
167
        return True
1✔
168

169

170
def _acceptance_fn_subtract(
1✔
171
    arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev
172
):
173
    # subtract is not defined for boolean data type
174
    if arg1_dtype.char == "?" and arg2_dtype.char == "?":
1✔
175
        raise ValueError(
1✔
176
            "The `subtract` function, the `-` operator, is not supported "
177
            "for inputs of data type bool, use the `^` operator,  the "
178
            "`bitwise_xor`, or the `logical_xor` function instead"
179
        )
180
    else:
181
        return True
1✔
182

183

184
def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
1✔
185
    res_dt = query_fn(arg_dtype)
1✔
186
    if res_dt:
1✔
187
        return None, res_dt
1✔
188

189
    _fp16 = sycl_dev.has_aspect_fp16
1✔
190
    _fp64 = sycl_dev.has_aspect_fp64
1✔
191
    all_dts = _all_data_types(_fp16, _fp64)
1✔
192
    for buf_dt in all_dts:
1✔
193
        if _can_cast(arg_dtype, buf_dt, _fp16, _fp64):
1✔
194
            res_dt = query_fn(buf_dt)
1✔
195
            if res_dt:
1✔
196
                acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)
1✔
197
                if acceptable:
1✔
198
                    return buf_dt, res_dt
1✔
199
                else:
200
                    continue
1✔
201

202
    return None, None
1✔
203

204

205
def _get_device_default_dtype(dt_kind, sycl_dev):
1✔
206
    if dt_kind == "b":
1✔
207
        return dpt.dtype(ti.default_device_bool_type(sycl_dev))
1✔
208
    elif dt_kind == "i":
1✔
209
        return dpt.dtype(ti.default_device_int_type(sycl_dev))
1✔
210
    elif dt_kind == "u":
1✔
211
        return dpt.dtype(ti.default_device_int_type(sycl_dev).upper())
1✔
212
    elif dt_kind == "f":
1✔
213
        return dpt.dtype(ti.default_device_fp_type(sycl_dev))
1✔
214
    elif dt_kind == "c":
1✔
215
        return dpt.dtype(ti.default_device_complex_type(sycl_dev))
1✔
216
    raise RuntimeError
1✔
217

218

219
def _acceptance_fn_default_binary(
1✔
220
    arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
221
):
222
    return True
1✔
223

224

225
def _acceptance_fn_divide(
1✔
226
    arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
227
):
228
    # both are being promoted, if the kind of result is
229
    # different than the kind of original input dtypes,
230
    # we use default dtype for the resulting kind.
231
    # This covers, e.g. (array_dtype_i1 / array_dtype_u1)
232
    # result of which in divide is double (in NumPy), but
233
    # regular type promotion rules peg at float16
234
    if (ret_buf1_dt.kind != arg1_dtype.kind) and (
1✔
235
        ret_buf2_dt.kind != arg2_dtype.kind
236
    ):
237
        default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
1✔
238
        if res_dt == default_dt:
1✔
239
            return True
1✔
240
        else:
241
            return False
1✔
242
    else:
243
        return True
1✔
244

245

246
def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
1✔
247
    res_dt = query_fn(arg1_dtype, arg2_dtype)
1✔
248
    if res_dt:
1✔
249
        return None, None, res_dt
1✔
250

251
    _fp16 = sycl_dev.has_aspect_fp16
1✔
252
    _fp64 = sycl_dev.has_aspect_fp64
1✔
253
    all_dts = _all_data_types(_fp16, _fp64)
1✔
254
    for buf1_dt in all_dts:
1✔
255
        for buf2_dt in all_dts:
1✔
256
            if _can_cast(arg1_dtype, buf1_dt, _fp16, _fp64) and _can_cast(
1✔
257
                arg2_dtype, buf2_dt, _fp16, _fp64
258
            ):
259
                res_dt = query_fn(buf1_dt, buf2_dt)
1✔
260
                if res_dt:
1✔
261
                    ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt
1✔
262
                    ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt
1✔
263
                    if ret_buf1_dt is None or ret_buf2_dt is None:
1✔
264
                        return ret_buf1_dt, ret_buf2_dt, res_dt
1✔
265
                    else:
266
                        acceptable = acceptance_fn(
1✔
267
                            arg1_dtype,
268
                            arg2_dtype,
269
                            ret_buf1_dt,
270
                            ret_buf2_dt,
271
                            res_dt,
272
                            sycl_dev,
273
                        )
274
                        if acceptable:
1✔
275
                            return ret_buf1_dt, ret_buf2_dt, res_dt
1✔
276
                        else:
277
                            continue
1✔
278

279
    return None, None, None
1✔
280

281

282
class WeakBooleanType:
    """Weak-type wrapper marking a Python `bool` scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
1✔
290

291

292
class WeakIntegralType:
    """Weak-type wrapper marking a Python integral scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
1✔
300

301

302
class WeakFloatingType:
    """Weak-type wrapper marking a Python floating-point scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
1✔
310

311

312
class WeakComplexType:
    """Weak-type wrapper marking a Python complex floating-point scalar.

    The wrapped object is stored unchanged and returned by `get`.
    """

    def __init__(self, o):
        # the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
1✔
320

321

322
def _weak_type_num_kind(o):
    """Return the numeric kind rank of the weak type `o`.

    Ranks: boolean -> 0, integral -> 1, floating -> 2, complex -> 3.

    Raises:
        TypeError: if `o` is not one of the four weak-type wrappers.
    """
    ranked_weak_types = (
        (WeakBooleanType, 0),
        (WeakIntegralType, 1),
        (WeakFloatingType, 2),
        (WeakComplexType, 3),
    )
    for weak_cls, rank in ranked_weak_types:
        if isinstance(o, weak_cls):
            return rank
    raise TypeError(
        f"Unexpected type {o} while expecting "
        "`WeakBooleanType`, `WeakIntegralType`, "
        "`WeakFloatingType`, or `WeakComplexType`."
    )
337

338

339
def _strong_dtype_num_kind(o):
    """Return the numeric kind rank of the strong dtype `o`.

    Ranks: bool -> 0, signed/unsigned integer -> 1, real floating -> 2,
    complex floating -> 3. These match the ranks produced for weak types,
    so weak and strong kinds can be compared directly.

    Raises:
        TypeError: if `o` is not a `dpt.dtype` instance.
        ValueError: if the kind of `o` is not one of those listed above.
    """
    _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
    if not isinstance(o, dpt.dtype):
        raise TypeError(f"Expected `dpt.dtype` type, got {type(o)}.")
    k = o.kind
    if k in _map:
        return _map[k]
    raise ValueError(f"Unrecognized kind {k} for dtype {o}")
1✔
347

348

349
def _is_weak_dtype(dtype):
    """Return True if `dtype` is an instance of one of the weak-type wrappers."""
    weak_wrappers = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    return isinstance(dtype, weak_wrappers)
354

355

356
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
    """Resolve a weak operand type against a strong one per NEP-0050.

    A weak type whose kind does not outrank the strong dtype takes the
    strong dtype. Otherwise:

    * a weak integer becomes the device default integer type,
    * a weak complex becomes `complex64` if the strong dtype is
      `float16`/`float32`, and the device-supported `complex128`
      counterpart otherwise,
    * a weak float becomes the device-supported `float64` counterpart.

    Returns:
        tuple: resolved `(o1_dtype, o2_dtype)`. When neither type is weak,
        both are returned unchanged.

    Raises:
        ValueError: if both operand types are weak.
    """

    def _weak_replacement(weak_dt, strong_dt):
        # dtype that stands in for `weak_dt` when paired with `strong_dt`
        if _weak_type_num_kind(weak_dt) <= _strong_dtype_num_kind(strong_dt):
            return strong_dt
        if isinstance(weak_dt, WeakIntegralType):
            return dpt.dtype(ti.default_device_int_type(dev))
        if isinstance(weak_dt, WeakComplexType):
            if strong_dt is dpt.float16 or strong_dt is dpt.float32:
                return dpt.complex64
            return _to_device_supported_dtype(dpt.complex128, dev)
        return _to_device_supported_dtype(dpt.float64, dev)

    if _is_weak_dtype(o1_dtype):
        if _is_weak_dtype(o2_dtype):
            raise ValueError
        return _weak_replacement(o1_dtype, o2_dtype), o2_dtype
    if _is_weak_dtype(o2_dtype):
        return o1_dtype, _weak_replacement(o2_dtype, o1_dtype)
    return o1_dtype, o2_dtype
1✔
394

395

396
def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
    """Resolve weak data types per NEP-0050 for comparison operations.

    The result type of a comparison is known to be `bool`, so special
    handling is needed for mixed integer kinds. Unlike `_resolve_weak_types`,
    a weak Python integer paired with an unsigned integer dtype is resolved
    to the device default integer type rather than the unsigned dtype,
    because the Python scalar may be negative. This assumes mixed
    signed/unsigned integer comparison loops exist.

    Args:
        o1_dtype, o2_dtype: operand types; at most one may be weak.
        dev: device used to look up default and supported dtypes.

    Returns:
        tuple: resolved `(o1_dtype, o2_dtype)`. When neither type is weak,
        both are returned unchanged.

    Raises:
        ValueError: if both operand types are weak.
    """
    if _is_weak_dtype(o1_dtype):
        if _is_weak_dtype(o2_dtype):
            raise ValueError
        o1_kind_num = _weak_type_num_kind(o1_dtype)
        o2_kind_num = _strong_dtype_num_kind(o2_dtype)
        if o1_kind_num > o2_kind_num:
            if isinstance(o1_dtype, WeakIntegralType):
                return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
            if isinstance(o1_dtype, WeakComplexType):
                if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
                    return dpt.complex64, o2_dtype
                return (
                    _to_device_supported_dtype(dpt.complex128, dev),
                    o2_dtype,
                )
            return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
        else:
            if isinstance(o1_dtype, WeakIntegralType):
                if o2_dtype.kind == "u":
                    # Python scalar may be negative, assumes mixed int loops
                    # exist
                    return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
            return o2_dtype, o2_dtype
    elif _is_weak_dtype(o2_dtype):
        o1_kind_num = _strong_dtype_num_kind(o1_dtype)
        o2_kind_num = _weak_type_num_kind(o2_dtype)
        if o2_kind_num > o1_kind_num:
            if isinstance(o2_dtype, WeakIntegralType):
                return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
            if isinstance(o2_dtype, WeakComplexType):
                if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
                    return o1_dtype, dpt.complex64
                return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
            return (
                o1_dtype,
                _to_device_supported_dtype(dpt.float64, dev),
            )
        else:
            if isinstance(o2_dtype, WeakIntegralType):
                if o1_dtype.kind == "u":
                    # Python scalar may be negative, assumes mixed int loops
                    # exist
                    return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
            return o1_dtype, o1_dtype
    else:
        return o1_dtype, o2_dtype
1✔
446

447

448
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
    """Resolve weak data types per NEP-0050 for a three-operand operation.

    `st_dtype` must be a strong dtype; `dtype1` and `dtype2` are permitted
    to be weak types.

    * If both are weak, each is resolved independently against `st_dtype`.
    * If exactly one is weak, it is resolved against whichever of `st_dtype`
      and the other (strong) dtype ranks higher.
    * If neither is weak, both are returned unchanged.

    A weak type whose kind does not outrank the reference dtype takes the
    reference dtype. Otherwise a weak integer becomes the device default
    integer type; a weak complex becomes `complex64` when the reference is
    `float16`/`float32` and the device-supported `complex128` counterpart
    otherwise; a weak float becomes the device-supported `float64`
    counterpart.

    Returns:
        tuple: resolved `(dtype1, dtype2)`.

    Raises:
        ValueError: if `st_dtype` is weak.
    """
    if _is_weak_dtype(st_dtype):
        raise ValueError

    def _resolve_against(weak_dt, ref_dtype, ref_kind_num):
        # dtype that replaces `weak_dt` when paired with strong `ref_dtype`
        if _weak_type_num_kind(weak_dt) <= ref_kind_num:
            return ref_dtype
        if isinstance(weak_dt, WeakIntegralType):
            return dpt.dtype(ti.default_device_int_type(dev))
        if isinstance(weak_dt, WeakComplexType):
            # Previously the complex64 choice was unconditionally
            # overwritten by complex128 when both operands were weak.
            if ref_dtype is dpt.float16 or ref_dtype is dpt.float32:
                return dpt.complex64
            return _to_device_supported_dtype(dpt.complex128, dev)
        return _to_device_supported_dtype(dpt.float64, dev)

    weak1 = _is_weak_dtype(dtype1)
    weak2 = _is_weak_dtype(dtype2)

    if weak1 and weak2:
        st_kind_num = _strong_dtype_num_kind(st_dtype)
        return (
            _resolve_against(dtype1, st_dtype, st_kind_num),
            _resolve_against(dtype2, st_dtype, st_kind_num),
        )

    if weak1:
        max_dt_num_kind, max_dtype = max(
            [
                (_strong_dtype_num_kind(st_dtype), st_dtype),
                (_strong_dtype_num_kind(dtype2), dtype2),
            ]
        )
        return _resolve_against(dtype1, max_dtype, max_dt_num_kind), dtype2

    if weak2:
        max_dt_num_kind, max_dtype = max(
            [
                (_strong_dtype_num_kind(st_dtype), st_dtype),
                (_strong_dtype_num_kind(dtype1), dtype1),
            ]
        )
        return dtype1, _resolve_against(dtype2, max_dtype, max_dt_num_kind)

    # both are strong dtypes
    # return unmodified
    return dtype1, dtype2
1✔
531

532

533
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
    """Resolve `dtype` (possibly weak) against strong `st_dtype` per NEP-0050.

    A strong `dtype` is returned unchanged. A weak `dtype` whose kind does
    not outrank `st_dtype` resolves to `st_dtype`. Otherwise a weak integer
    becomes the device default integer type; a weak complex becomes
    `complex64` when `st_dtype` is `float16`/`float32` and the
    device-supported `complex128` counterpart otherwise; a weak float
    becomes the device-supported `float64` counterpart.

    Raises:
        ValueError: if `st_dtype` is weak.
    """
    if _is_weak_dtype(st_dtype):
        raise ValueError
    if not _is_weak_dtype(dtype):
        return dtype

    st_kind_num = _strong_dtype_num_kind(st_dtype)
    if _weak_type_num_kind(dtype) <= st_kind_num:
        return st_dtype
    if isinstance(dtype, WeakIntegralType):
        return dpt.dtype(ti.default_device_int_type(dev))
    if isinstance(dtype, WeakComplexType):
        single_precision = st_dtype is dpt.float16 or st_dtype is dpt.float32
        if single_precision:
            return dpt.complex64
        return _to_device_supported_dtype(dpt.complex128, dev)
    return _to_device_supported_dtype(dpt.float64, dev)
1✔
552

553

554
class finfo_object:
    """
    Wrapper around `numpy.finfo` (not a subclass) that exposes its
    attributes as Python scalars: `bits` as `int`; `eps`, `epsneg`, `max`,
    `min`, `smallest_normal`, `tiny`, `resolution` and `precision` as
    `float`. `dtype` is forwarded unchanged.
    """

    def __init__(self, dtype):
        # reject dtypes dpctl does not support before querying NumPy
        _supported_dtype([dpt.dtype(dtype)])
        self._finfo = np.finfo(dtype)

    @property
    def bits(self):
        """
        number of bits occupied by the real-valued floating-point data type.
        """
        return int(self._finfo.bits)

    @property
    def smallest_normal(self):
        """
        smallest positive real-valued floating-point number with full
        precision.
        """
        return float(self._finfo.smallest_normal)

    @property
    def tiny(self):
        """an alias for `smallest_normal`"""
        return float(self._finfo.tiny)

    @property
    def eps(self):
        """
        difference between 1.0 and the closest representable real-valued
        floating-point number larger than 1.0 according to the IEEE-754
        standard.
        """
        return float(self._finfo.eps)

    @property
    def epsneg(self):
        """
        difference between 1.0 and the closest representable real-valued
        floating-point number smaller than 1.0 according to the IEEE-754
        standard.
        """
        return float(self._finfo.epsneg)

    @property
    def min(self):
        """smallest representable real-valued number."""
        return float(self._finfo.min)

    @property
    def max(self):
        """largest representable real-valued number."""
        return float(self._finfo.max)

    @property
    def resolution(self):
        """the approximate decimal resolution of this type."""
        return float(self._finfo.resolution)

    @property
    def precision(self):
        """
        the approximate number of decimal digits to which this kind of
        floating point type is precise, returned as a `float`.
        """
        return float(self._finfo.precision)

    @property
    def dtype(self):
        """
        the dtype for which finfo returns information. For complex input, the
        returned dtype is the associated floating point dtype for its real and
        complex components.
        """
        return self._finfo.dtype

    def __str__(self):
        return self._finfo.__str__()

    def __repr__(self):
        return self._finfo.__repr__()
1✔
639

640

641
def can_cast(from_, to, /, *, casting="safe") -> bool:
    """ can_cast(from, to, casting="safe")

    Determines if one data type can be cast to another data type according \
    to Type Promotion Rules.

    Args:
       from_ (Union[usm_ndarray, dtype]):
           source data type. If `from_` is an array, a device-specific type
           promotion rules apply.
       to (dtype):
           target data type
       casting (Optional[str]):
            controls what kind of data casting may occur.

                * "no" means data types should not be cast at all.
                * "safe" means only casts that preserve values are allowed.
                * "same_kind" means only safe casts and casts within a kind,
                  like `float64` to `float32`, are allowed.
                * "unsafe" means any data conversion can be done.

            Default: `"safe"`.

    Returns:
        bool:
            Gives `True` if cast can occur according to the casting rule.

    Device-specific type promotion rules take into account which data types
    are and are not supported by a specific device.
    """
    if isinstance(to, dpt.usm_ndarray):
        raise TypeError(f"Expected `dpt.dtype` type, got {type(to)}.")

    dtype_to = dpt.dtype(to)
    _supported_dtype([dtype_to])

    if not isinstance(from_, dpt.usm_ndarray):
        dtype_from = dpt.dtype(from_)
        _supported_dtype([dtype_from])
        # no device to consult: query casting as if all dtypes are supported
        return _can_cast(dtype_from, dtype_to, True, True, casting=casting)

    # array input: the array's device decides which dtypes are available
    dtype_from = from_.dtype
    dev = from_.sycl_device
    return _can_cast(
        dtype_from,
        dtype_to,
        dev.has_aspect_fp16,
        dev.has_aspect_fp64,
        casting=casting,
    )
1✔
691

692

693
def result_type(*arrays_and_dtypes):
    """
    result_type(*arrays_and_dtypes)

    Returns the dtype that results from applying the Type Promotion Rules to \
        the arguments.

    Args:
        arrays_and_dtypes (Union[usm_ndarray, dtype]):
            An arbitrary length sequence of usm_ndarray objects or dtypes.
            Python `bool`, `int`, `float` and `complex` scalars are also
            accepted and treated as weakly-typed values per NEP-0050.

    Returns:
        dtype:
            The dtype resulting from an operation involving the
            input arrays and dtypes.

    Raises:
        ValueError:
            if input arrays reside on devices with different fp16/fp64
            support, if a dtype is not supported by dpctl, or if a dtype is
            not supported by the device of the input arrays.
    """
    dtypes = []
    devices = []
    weak_dtypes = []
    for arg_i in arrays_and_dtypes:
        if isinstance(arg_i, dpt.usm_ndarray):
            devices.append(arg_i.sycl_device)
            dtypes.append(arg_i.dtype)
        elif isinstance(arg_i, bool):
            # `bool` is a subclass of `int`, so it must be tested before
            # `int`; otherwise Python booleans are treated as weak integers
            weak_dtypes.append(WeakBooleanType(arg_i))
        elif isinstance(arg_i, int):
            weak_dtypes.append(WeakIntegralType(arg_i))
        elif isinstance(arg_i, float):
            weak_dtypes.append(WeakFloatingType(arg_i))
        elif isinstance(arg_i, complex):
            weak_dtypes.append(WeakComplexType(arg_i))
        else:
            dt = dpt.dtype(arg_i)
            _supported_dtype([dt])
            dtypes.append(dt)

    has_fp16 = True
    has_fp64 = True
    target_dev = None
    if devices:
        # all arrays must agree on fp16/fp64 support so that one set of
        # device-specific promotion rules applies
        target_dev = devices[0]
        has_fp16 = target_dev.has_aspect_fp16
        has_fp64 = target_dev.has_aspect_fp64
        for d in devices[1:]:
            unsame_fp16_support = d.has_aspect_fp16 != has_fp16
            unsame_fp64_support = d.has_aspect_fp64 != has_fp64
            if unsame_fp16_support or unsame_fp64_support:
                raise ValueError(
                    "Input arrays reside on devices "
                    "with different device supports; "
                    "unable to determine which "
                    "device-specific type promotion rules "
                    "to use."
                )

    if not (has_fp16 and has_fp64):
        for dt in dtypes:
            if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):
                raise ValueError(
                    f"Argument {dt} is not supported by the device"
                )
        res_dt = np.result_type(*dtypes)
        res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
        # fold weak scalars in one at a time using NEP-0050 resolution,
        # keeping the running result representable on the device
        for wdt in weak_dtypes:
            pair = _resolve_weak_types(wdt, res_dt, target_dev)
            res_dt = np.result_type(*pair)
            res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
    else:
        res_dt = np.result_type(*dtypes)
        if weak_dtypes:
            weak_dt_obj = [wdt.get() for wdt in weak_dtypes]
            res_dt = np.result_type(res_dt, *weak_dt_obj)

    return res_dt
1✔
771

772

773
def iinfo(dtype, /):
    """iinfo(dtype)

    Returns machine limits for integer data types.

    Args:
        dtype (dtype, usm_ndarray):
            integer dtype or
            an array with integer dtype.

    Returns:
        iinfo_object:
            An object with the following attributes:

            * bits: int
                number of bits occupied by the data type
            * max: int
                largest representable number.
            * min: int
                smallest representable number.
            * dtype: dtype
                integer data type.
    """
    # accept an array in place of a dtype by using the array's dtype
    dt = dtype.dtype if isinstance(dtype, dpt.usm_ndarray) else dtype
    _supported_dtype([dpt.dtype(dt)])
    return np.iinfo(dt)
1✔
800

801

802
def finfo(dtype, /):
    """finfo(dtype)

    Returns machine limits for floating-point data types.

    Args:
        dtype (dtype, usm_ndarray):
            floating-point data type, or an array with a
            floating-point data type. If complex, the information
            is about its component data type.

    Returns:
        finfo_object:
            An object with the following attributes:

            * bits: int
                number of bits occupied by dtype.
            * eps: float
                difference between 1.0 and the smallest representable
                real-valued floating-point number larger than 1.0,
                according to the IEEE-754 standard.
            * max: float
                largest representable real-valued number.
            * min: float
                smallest representable real-valued number.
            * smallest_normal: float
                smallest positive real-valued floating-point number with
                full precision.
            * dtype: dtype
                real-valued floating-point data type.

    Raises:
        ValueError: if the data type is not one supported by dpctl.
    """
    # An array argument is replaced by its data type before validation.
    if isinstance(dtype, dpt.usm_ndarray):
        dtype = dtype.dtype
    _supported_dtype([dpt.dtype(dtype)])
    return finfo_object(dtype)
838

839

840
def _supported_dtype(dtypes):
    """Validate that every data type in `dtypes` is supported by dpctl.

    Args:
        dtypes (Iterable[dtype]): data type objects exposing a ``char``
            (type character code) attribute.

    Returns:
        bool: ``True`` once every element has been validated.

    Raises:
        ValueError: for the first element whose type character code is
            not one of the accepted codes (boolean, the integer codes,
            and the half/single/double precision real and complex
            floating-point codes).
    """
    accepted_codes = "?bBhHiIlLqQefdFD"
    for candidate in dtypes:
        if candidate.char in accepted_codes:
            continue
        raise ValueError(f"Dpctl doesn't support dtype {candidate}.")
    return True
845

846

847
def isdtype(dtype, kind):
    """isdtype(dtype, kind)

    Returns a boolean indicating whether a provided `dtype` is
    of a specified data type `kind`.

    Args:
        dtype (dtype): data type to test; must be a ``numpy.dtype``
            instance.
        kind (dtype, str, tuple):
            * a data type: compared with `dtype` for equality;
            * a string: one of ``"bool"``, ``"signed integer"``,
              ``"unsigned integer"``, ``"integral"``,
              ``"real floating"``, ``"complex floating"``,
              ``"numeric"``;
            * a tuple: each entry is checked recursively, and the
              result is ``True`` if any entry matches.

    Returns:
        bool: whether `dtype` belongs to `kind`.

    Raises:
        TypeError: if `dtype` is not a ``numpy.dtype`` instance, or
            `kind` is of an unsupported type.
        ValueError: if `kind` is an unrecognized string.

    See the array API specification (https://data-apis.org/array-api/latest/)
    for more information.
    """
    if not isinstance(dtype, np.dtype):
        raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype}")

    if isinstance(kind, np.dtype):
        return dtype == kind

    if isinstance(kind, str):
        code = dtype.kind
        # answers for every recognized kind name; "numeric" excludes bool
        membership = {
            "bool": dtype == np.dtype("bool"),
            "signed integer": code == "i",
            "unsigned integer": code == "u",
            "integral": code in "iu",
            "real floating": code == "f",
            "complex floating": code == "c",
            "numeric": code in "iufc",
        }
        if kind not in membership:
            raise ValueError(f"Unrecognized data type kind: {kind}")
        return membership[kind]

    if isinstance(kind, tuple):
        return any(isdtype(dtype, entry) for entry in kind)

    raise TypeError(f"Unsupported data type kind: {kind}")
887

888

889
def _default_accumulation_dtype(inp_dt, q):
    """Gives default output data type for given input data
    type `inp_dt` when accumulation is performed on queue `q`.

    Args:
        inp_dt (dtype): input data type.
        q: queue used to query the device's default integer type via
            ``ti.default_device_int_type``.

    Returns:
        dtype:
            * boolean or signed integer input: the device default
              integer type, or `inp_dt` if it has a larger itemsize;
            * unsigned integer input: the type obtained by upper-casing
              the default integer type code, or `inp_dt` if its value
              range does not fit in that type;
            * real or complex floating-point input: `inp_dt` unchanged.

    Raises:
        ValueError: if `inp_dt` is of any other kind.
    """
    inp_kind = inp_dt.kind
    if inp_kind in "bi":
        res_dt = dpt.dtype(ti.default_device_int_type(q))
        # never narrow a wider input type
        if inp_dt.itemsize > res_dt.itemsize:
            res_dt = inp_dt
    elif inp_kind == "u":
        # NOTE(review): upper-casing the type code presumably selects the
        # unsigned counterpart of the default integer type — confirm
        res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
        res_ii = dpt.iinfo(res_dt)
        inp_ii = dpt.iinfo(inp_dt)
        # fall back to the input type if its range does not fit
        if not (inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max):
            res_dt = inp_dt
    elif inp_kind in "fc":
        res_dt = inp_dt
    else:
        # previously fell through and failed with UnboundLocalError
        raise ValueError(
            f"No default accumulation data type for input data type {inp_dt}"
        )

    return res_dt
910

911

912
def _default_accumulation_dtype_fp_types(inp_dt, q):
    """Gives default output data type for given input data
    type `inp_dt` when accumulation is performed on queue `q`
    and the accumulation supports only floating-point data types.

    Args:
        inp_dt (dtype): input data type.
        q: queue used to query the device's default floating-point type
            and its fp64 aspect.

    Returns:
        dtype:
            * boolean or integer input: the device default floating-point
              type. If `inp_dt` cannot be cast to it, the result is
              ``float64`` when the device has the fp64 aspect, otherwise
              ``float32``;
            * real floating-point input: `inp_dt` unchanged.

    Raises:
        ValueError: if `inp_dt` is complex or of any other non-numeric
            kind.
    """
    inp_kind = inp_dt.kind
    if inp_kind in "biu":
        res_dt = dpt.dtype(ti.default_device_fp_type(q))
        if not dpt.can_cast(inp_dt, res_dt):
            _fp64 = q.sycl_device.has_aspect_fp64
            res_dt = dpt.float64 if _fp64 else dpt.float32
    elif inp_kind == "f":
        res_dt = inp_dt
    elif inp_kind == "c":
        raise ValueError("function not defined for complex types")
    else:
        # previously fell through and failed with UnboundLocalError
        raise ValueError(f"function not defined for data type {inp_dt}")

    return res_dt
930

931

932
# Names exported by ``from dpctl.tensor._type_utils import *``.
# Several underscore-prefixed helpers are listed on purpose so that
# star-imports expose them as well.
__all__ = [
    # buffer data type lookup / device-supported data type helpers
    "_find_buf_dtype",
    "_find_buf_dtype2",
    "_to_device_supported_dtype",
    # _acceptance_fn_* helpers
    "_acceptance_fn_default_unary",
    "_acceptance_fn_reciprocal",
    "_acceptance_fn_default_binary",
    "_acceptance_fn_divide",
    "_acceptance_fn_negative",
    "_acceptance_fn_subtract",
    # weak-type resolution helpers
    "_resolve_one_strong_one_weak_types",
    "_resolve_one_strong_two_weak_types",
    "_resolve_weak_types",
    "_resolve_weak_types_comparisons",
    "_weak_type_num_kind",
    "_strong_dtype_num_kind",
    # public data type utilities
    "can_cast",
    "finfo",
    "iinfo",
    "isdtype",
    "result_type",
    # weak type classes
    "WeakBooleanType",
    "WeakIntegralType",
    "WeakFloatingType",
    "WeakComplexType",
    # default accumulation data types
    "_default_accumulation_dtype",
    "_default_accumulation_dtype_fp_types",
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc