• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 10836421982

12 Sep 2024 06:15PM UTC coverage: 87.907% (+0.01%) from 87.893%
10836421982

Pull #1829

github

web-flow
Merge 33e6c5a8f into 8b257733c
Pull Request #1829: Specialize copy_from_numpy_into_usm_ndarray for contig case

3409 of 3922 branches covered (86.92%)

Branch coverage included in aggregate %.

11711 of 13278 relevant lines covered (88.2%)

7086.44 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.41
/dpctl/tensor/_type_utils.py
1
#                      Data Parallel Control (dpctl)
2
#
3
# Copyright 2020-2024 Intel Corporation
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#    http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
from __future__ import annotations
1✔
17

18
import numpy as np
1✔
19

20
import dpctl.tensor as dpt
1✔
21
import dpctl.tensor._tensor_impl as ti
1✔
22

23

24
def _all_data_types(_fp16, _fp64):
    """Return a new list of the dtypes usable on a device with given aspects.

    The list starts with the boolean and integral dtypes, followed by the
    inexact dtypes in the order float16 (only if `_fp16`), float32,
    float64 (only if `_fp64`), complex64, complex128 (only if `_fp64`).
    """
    supported = [
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
    ]
    if _fp16:
        supported.append(dpt.float16)
    supported.append(dpt.float32)
    if _fp64:
        supported.append(dpt.float64)
    supported.append(dpt.complex64)
    if _fp64:
        supported.append(dpt.complex128)
    return supported
64

65

66
def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
    """
    Return True if data type `dt` is one of the widest inexact (real or
    complex floating-point) data types of a device with the given aspects.

    `_fp16` does not influence the result.
    """
    if _fp64:
        widest = (dpt.float64, dpt.complex128)
    else:
        widest = (dpt.float32, dpt.complex64)
    return dt in widest
74

75

76
def _dtype_supported_by_device_impl(
    dt: dpt.dtype, has_fp16: bool, has_fp64: bool
) -> bool:
    """Return False if `dt` needs an aspect the device lacks, else True.

    float16 requires `has_fp16`; float64 and complex128 require `has_fp64`.
    All other dtypes are always considered supported.
    """
    needs_missing_fp16 = not has_fp16 and dt is dpt.float16
    needs_missing_fp64 = not has_fp64 and (
        dt is dpt.float64 or dt is dpt.complex128
    )
    return not (needs_missing_fp16 or needs_missing_fp64)
91

92

93
def _can_cast(
    from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool, casting="safe"
) -> bool:
    """
    Can `from_` be cast to `to_` under the `casting` rule on a device with
    fp16 and fp64 aspects as given?

    A target dtype the device does not support is always rejected. Otherwise
    NumPy's `can_cast` verdict is used, except that on a device lacking fp16
    or fp64 a bool/signed/unsigned integer source may additionally be cast
    to the device's maximal inexact dtype (e.g. int64 -> float32 when fp64
    is unavailable), which NumPy would reject.
    """
    if not _dtype_supported_by_device_impl(to_, _fp16, _fp64):
        return False
    numpy_verdict = np.can_cast(from_, to_, casting=casting)
    if numpy_verdict or (_fp16 and _fp64):
        return numpy_verdict
    return (
        from_.kind in "biu"
        and to_.kind in "fc"
        and _is_maximal_inexact_type(to_, _fp16, _fp64)
    )
114

115

116
def _to_device_supported_dtype_impl(dt, has_fp16, has_fp64):
    """Map `dt` to a dtype available on a device with the given aspects.

    Without fp64, float64 becomes float32 and complex128 becomes complex64.
    Without fp16, float16 becomes float32. Any other dtype is returned as is.
    """
    if not has_fp64:
        if dt is dpt.float64:
            return dpt.float32
        if dt is dpt.complex128:
            return dpt.complex64
    if not has_fp16 and dt is dpt.float16:
        return dpt.float32
    return dt
129

130

131
def _to_device_supported_dtype(dt, dev):
    """Return `dt`, or its substitute when device `dev` cannot represent it.

    Uses the `has_aspect_fp16` and `has_aspect_fp64` attributes of `dev`.
    """
    return _to_device_supported_dtype_impl(
        dt, dev.has_aspect_fp16, dev.has_aspect_fp64
    )
136

137

138
def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
1✔
139
    return True
1✔
140

141

142
def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
1✔
143
    # if the kind of result is different from the kind of input, we use the
144
    # default floating-point dtype for the resulting kind. This guarantees
145
    # alignment of reciprocal and divide output types.
146
    if buf_dt.kind != arg_dtype.kind:
1✔
147
        default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
1✔
148
        if res_dt == default_dt:
1✔
149
            return True
1✔
150
        else:
151
            return False
1✔
152
    else:
153
        return True
1✔
154

155

156
def _acceptance_fn_negative(arg_dtype, buf_dt, res_dt, sycl_dev):
1✔
157
    # negative is not defined for boolean data type
158
    if arg_dtype.char == "?":
1✔
159
        raise ValueError(
1✔
160
            "The `negative` function, the `-` operator, is not supported "
161
            "for inputs of data type bool, use the `~` operator or the "
162
            "`logical_not` function instead"
163
        )
164
    else:
165
        return True
1✔
166

167

168
def _acceptance_fn_subtract(
1✔
169
    arg1_dtype, arg2_dtype, buf1_dt, buf2_dt, res_dt, sycl_dev
170
):
171
    # subtract is not defined for boolean data type
172
    if arg1_dtype.char == "?" and arg2_dtype.char == "?":
1✔
173
        raise ValueError(
1✔
174
            "The `subtract` function, the `-` operator, is not supported "
175
            "for inputs of data type bool, use the `^` operator,  the "
176
            "`bitwise_xor`, or the `logical_xor` function instead"
177
        )
178
    else:
179
        return True
1✔
180

181

182
def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
1✔
183
    res_dt = query_fn(arg_dtype)
1✔
184
    if res_dt:
1✔
185
        return None, res_dt
1✔
186

187
    _fp16 = sycl_dev.has_aspect_fp16
1✔
188
    _fp64 = sycl_dev.has_aspect_fp64
1✔
189
    all_dts = _all_data_types(_fp16, _fp64)
1✔
190
    for buf_dt in all_dts:
1✔
191
        if _can_cast(arg_dtype, buf_dt, _fp16, _fp64):
1✔
192
            res_dt = query_fn(buf_dt)
1✔
193
            if res_dt:
1✔
194
                acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)
1✔
195
                if acceptable:
1✔
196
                    return buf_dt, res_dt
1✔
197
                else:
198
                    continue
1✔
199

200
    return None, None
1✔
201

202

203
def _get_device_default_dtype(dt_kind, sycl_dev):
1✔
204
    if dt_kind == "b":
1✔
205
        return dpt.dtype(ti.default_device_bool_type(sycl_dev))
1✔
206
    elif dt_kind == "i":
1✔
207
        return dpt.dtype(ti.default_device_int_type(sycl_dev))
1✔
208
    elif dt_kind == "u":
1✔
209
        return dpt.dtype(ti.default_device_uint_type(sycl_dev))
1✔
210
    elif dt_kind == "f":
1✔
211
        return dpt.dtype(ti.default_device_fp_type(sycl_dev))
1✔
212
    elif dt_kind == "c":
1✔
213
        return dpt.dtype(ti.default_device_complex_type(sycl_dev))
1✔
214
    raise RuntimeError
1✔
215

216

217
def _acceptance_fn_default_binary(
1✔
218
    arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
219
):
220
    return True
1✔
221

222

223
def _acceptance_fn_divide(
1✔
224
    arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
225
):
226
    # both are being promoted, if the kind of result is
227
    # different than the kind of original input dtypes,
228
    # we use default dtype for the resulting kind.
229
    # This covers, e.g. (array_dtype_i1 / array_dtype_u1)
230
    # result of which in divide is double (in NumPy), but
231
    # regular type promotion rules peg at float16
232
    if (ret_buf1_dt.kind != arg1_dtype.kind) and (
1✔
233
        ret_buf2_dt.kind != arg2_dtype.kind
234
    ):
235
        default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
1✔
236
        if res_dt == default_dt:
1✔
237
            return True
1✔
238
        else:
239
            return False
1✔
240
    else:
241
        return True
1✔
242

243

244
def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
1✔
245
    res_dt = query_fn(arg1_dtype, arg2_dtype)
1✔
246
    if res_dt:
1✔
247
        return None, None, res_dt
1✔
248

249
    _fp16 = sycl_dev.has_aspect_fp16
1✔
250
    _fp64 = sycl_dev.has_aspect_fp64
1✔
251
    all_dts = _all_data_types(_fp16, _fp64)
1✔
252
    for buf1_dt in all_dts:
1✔
253
        for buf2_dt in all_dts:
1✔
254
            if _can_cast(arg1_dtype, buf1_dt, _fp16, _fp64) and _can_cast(
1✔
255
                arg2_dtype, buf2_dt, _fp16, _fp64
256
            ):
257
                res_dt = query_fn(buf1_dt, buf2_dt)
1✔
258
                if res_dt:
1✔
259
                    ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt
1✔
260
                    ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt
1✔
261
                    if ret_buf1_dt is None or ret_buf2_dt is None:
1✔
262
                        return ret_buf1_dt, ret_buf2_dt, res_dt
1✔
263
                    else:
264
                        acceptable = acceptance_fn(
1✔
265
                            arg1_dtype,
266
                            arg2_dtype,
267
                            ret_buf1_dt,
268
                            ret_buf2_dt,
269
                            res_dt,
270
                            sycl_dev,
271
                        )
272
                        if acceptable:
1✔
273
                            return ret_buf1_dt, ret_buf2_dt, res_dt
1✔
274
                        else:
275
                            continue
1✔
276

277
    return None, None, None
1✔
278

279

280
def _find_buf_dtype_in_place_op(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
1✔
281
    res_dt = query_fn(arg1_dtype, arg2_dtype)
1✔
282
    if res_dt:
1✔
283
        return None, res_dt
1✔
284

285
    _fp16 = sycl_dev.has_aspect_fp16
1✔
286
    _fp64 = sycl_dev.has_aspect_fp64
1✔
287
    if _can_cast(arg2_dtype, arg1_dtype, _fp16, _fp64, casting="same_kind"):
1✔
288
        res_dt = query_fn(arg1_dtype, arg1_dtype)
1✔
289
        if res_dt:
1✔
290
            return arg1_dtype, res_dt
1✔
291

292
    return None, None
1✔
293

294

295
class WeakBooleanType:
    """Python type representing the weak type of Python boolean objects.

    Holds the original Python object so it can be retrieved with `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, returned unchanged by `get`
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
303

304

305
class WeakIntegralType:
    """Python type representing the weak type of Python integral objects.

    Holds the original Python object so it can be retrieved with `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, returned unchanged by `get`
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
313

314

315
class WeakFloatingType:
    """Python type representing the weak type of Python floating point objects.

    Holds the original Python object so it can be retrieved with `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, returned unchanged by `get`
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
323

324

325
class WeakComplexType:
    """Python type representing the weak type of Python complex objects.

    Holds the original Python object so it can be retrieved with `get`.
    """

    def __init__(self, o):
        # the wrapped Python object, returned unchanged by `get`
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        wrapped = self.o_
        return wrapped
333

334

335
def _weak_type_num_kind(o):
    """Return the numeric kind rank of a weak type object `o`.

    Ranks: `WeakBooleanType` 0, `WeakIntegralType` 1, `WeakFloatingType` 2,
    `WeakComplexType` 3 (the same scale `_strong_dtype_num_kind` uses for
    strong dtypes).

    Raises:
        TypeError: if `o` is not an instance of one of the weak types.
    """
    _map = {"?": 0, "i": 1, "f": 2, "c": 3}
    if isinstance(o, WeakBooleanType):
        return _map["?"]
    if isinstance(o, WeakIntegralType):
        return _map["i"]
    if isinstance(o, WeakFloatingType):
        return _map["f"]
    if isinstance(o, WeakComplexType):
        return _map["c"]
    raise TypeError(
        f"Unexpected type {o} while expecting "
        "`WeakBooleanType`, `WeakIntegralType`, "
        "`WeakFloatingType`, or `WeakComplexType`."
    )
350

351

352
def _strong_dtype_num_kind(o):
    """Return the numeric kind rank of the strong dtype `o`.

    Ranks: bool 0, signed and unsigned integer 1, real floating-point 2,
    complex floating-point 3.

    Raises:
        TypeError: if `o` is not a `dpt.dtype` instance.
        ValueError: if the kind of `o` is not one of the kinds above.
    """
    _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
    if not isinstance(o, dpt.dtype):
        raise TypeError(f"Expected `dpt.dtype` type, got {type(o)}.")
    k = o.kind
    if k in _map:
        return _map[k]
    raise ValueError(f"Unrecognized kind {k} for dtype {o}")
360

361

362
def _is_weak_dtype(dtype):
    """Return True if `dtype` is an instance of one of the weak type wrappers."""
    return any(
        isinstance(dtype, weak_cls)
        for weak_cls in (
            WeakBooleanType,
            WeakIntegralType,
            WeakFloatingType,
            WeakComplexType,
        )
    )
367

368

369
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
    """Resolve weak data types per NEP-0050.

    At most one of `o1_dtype`, `o2_dtype` may be a weak type; the other must
    be a strong `dpt.dtype`. When the weak type has a higher kind rank than
    the strong dtype it is replaced by a concrete dtype:

    * `WeakIntegralType` -> the device default integer type;
    * `WeakComplexType` -> complex64 when the strong dtype is float16 or
      float32, otherwise complex128 adjusted for `dev`;
    * `WeakFloatingType` -> float64 adjusted for `dev`.

    Otherwise the weak type takes the strong dtype. If neither argument is
    weak, both are returned unchanged.

    Returns:
        tuple: the resolved `(o1_dtype, o2_dtype)` pair.

    Raises:
        ValueError: if both arguments are weak types.
    """
    if _is_weak_dtype(o1_dtype):
        if _is_weak_dtype(o2_dtype):
            raise ValueError(
                "At most one of the arguments may be a weak type, got two"
            )
        o1_kind_num = _weak_type_num_kind(o1_dtype)
        o2_kind_num = _strong_dtype_num_kind(o2_dtype)
        if o1_kind_num > o2_kind_num:
            if isinstance(o1_dtype, WeakIntegralType):
                return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
            if isinstance(o1_dtype, WeakComplexType):
                if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
                    return dpt.complex64, o2_dtype
                return (
                    _to_device_supported_dtype(dpt.complex128, dev),
                    o2_dtype,
                )
            return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
        else:
            return o2_dtype, o2_dtype
    elif _is_weak_dtype(o2_dtype):
        o1_kind_num = _strong_dtype_num_kind(o1_dtype)
        o2_kind_num = _weak_type_num_kind(o2_dtype)
        if o2_kind_num > o1_kind_num:
            if isinstance(o2_dtype, WeakIntegralType):
                return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
            if isinstance(o2_dtype, WeakComplexType):
                if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
                    return o1_dtype, dpt.complex64
                return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
            return (
                o1_dtype,
                _to_device_supported_dtype(dpt.float64, dev),
            )
        else:
            return o1_dtype, o1_dtype
    else:
        return o1_dtype, o2_dtype
407

408

409
def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev):
    """
    Resolves weak data type per NEP-0050 for comparisons and divide, where
    the result type is known and special behavior is needed to handle mixed
    integer kinds and Python integers without overflow.

    Follows the same rules as `_resolve_weak_types`, with one addition: when
    a `WeakIntegralType` meets a strong signed or unsigned integer dtype and
    its value lies outside that dtype's `iinfo` range, the weak operand is
    resolved to `np.min_scalar_type(value)` instead of the strong dtype.

    Returns:
        tuple: the resolved `(o1_dtype, o2_dtype)` pair.

    Raises:
        ValueError: if both arguments are weak types.
    """
    if _is_weak_dtype(o1_dtype):
        if _is_weak_dtype(o2_dtype):
            raise ValueError(
                "At most one of the arguments may be a weak type, got two"
            )
        o1_kind_num = _weak_type_num_kind(o1_dtype)
        o2_kind_num = _strong_dtype_num_kind(o2_dtype)
        if o1_kind_num > o2_kind_num:
            if isinstance(o1_dtype, WeakIntegralType):
                return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
            if isinstance(o1_dtype, WeakComplexType):
                if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
                    return dpt.complex64, o2_dtype
                return (
                    _to_device_supported_dtype(dpt.complex128, dev),
                    o2_dtype,
                )
            return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
        else:
            if o1_kind_num == o2_kind_num and isinstance(
                o1_dtype, WeakIntegralType
            ):
                # keep a Python int that does not fit the strong integer
                # dtype from being silently narrowed to it
                o1_val = o1_dtype.get()
                o2_iinfo = dpt.iinfo(o2_dtype)
                if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max):
                    return dpt.dtype(np.min_scalar_type(o1_val)), o2_dtype
            return o2_dtype, o2_dtype
    elif _is_weak_dtype(o2_dtype):
        o1_kind_num = _strong_dtype_num_kind(o1_dtype)
        o2_kind_num = _weak_type_num_kind(o2_dtype)
        if o2_kind_num > o1_kind_num:
            if isinstance(o2_dtype, WeakIntegralType):
                return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
            if isinstance(o2_dtype, WeakComplexType):
                if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
                    return o1_dtype, dpt.complex64
                return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
            return (
                o1_dtype,
                _to_device_supported_dtype(dpt.float64, dev),
            )
        else:
            if o1_kind_num == o2_kind_num and isinstance(
                o2_dtype, WeakIntegralType
            ):
                # keep a Python int that does not fit the strong integer
                # dtype from being silently narrowed to it
                o2_val = o2_dtype.get()
                o1_iinfo = dpt.iinfo(o1_dtype)
                if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max):
                    return o1_dtype, dpt.dtype(np.min_scalar_type(o2_val))
            return o1_dtype, o1_dtype
    else:
        return o1_dtype, o2_dtype
464

465

466
def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
    """
    Resolves weak data types per NEP-0050, where the second and third
    arguments are permitted to be weak types.

    `st_dtype` must be a strong dtype.

    * If both `dtype1` and `dtype2` are weak, each is resolved independently
      against `st_dtype`.
    * If only one is weak, it is resolved against the larger of
      `(kind rank, dtype)` among `st_dtype` and the other, strong, operand.
    * If neither is weak, both are returned unchanged.

    A weak type whose kind rank exceeds that of the dtype it is resolved
    against becomes the device default integer type (`WeakIntegralType`),
    complex64 when resolved against float16/float32 and otherwise complex128
    adjusted for `dev` (`WeakComplexType`), or float64 adjusted for `dev`
    (`WeakFloatingType`). Otherwise it takes the dtype it is resolved against.

    Returns:
        tuple: the resolved `(dtype1, dtype2)` pair.

    Raises:
        ValueError: if `st_dtype` is a weak type.
    """
    if _is_weak_dtype(st_dtype):
        raise ValueError("The first argument must be a strong data type")
    if _is_weak_dtype(dtype1):
        if _is_weak_dtype(dtype2):
            kind_num1 = _weak_type_num_kind(dtype1)
            kind_num2 = _weak_type_num_kind(dtype2)
            st_kind_num = _strong_dtype_num_kind(st_dtype)

            if kind_num1 > st_kind_num:
                if isinstance(dtype1, WeakIntegralType):
                    ret_dtype1 = dpt.dtype(ti.default_device_int_type(dev))
                elif isinstance(dtype1, WeakComplexType):
                    # `else` is required: without it complex64 would always
                    # be overwritten by complex128
                    if st_dtype is dpt.float16 or st_dtype is dpt.float32:
                        ret_dtype1 = dpt.complex64
                    else:
                        ret_dtype1 = _to_device_supported_dtype(
                            dpt.complex128, dev
                        )
                else:
                    ret_dtype1 = _to_device_supported_dtype(dpt.float64, dev)
            else:
                ret_dtype1 = st_dtype

            if kind_num2 > st_kind_num:
                if isinstance(dtype2, WeakIntegralType):
                    ret_dtype2 = dpt.dtype(ti.default_device_int_type(dev))
                elif isinstance(dtype2, WeakComplexType):
                    if st_dtype is dpt.float16 or st_dtype is dpt.float32:
                        ret_dtype2 = dpt.complex64
                    else:
                        ret_dtype2 = _to_device_supported_dtype(
                            dpt.complex128, dev
                        )
                else:
                    ret_dtype2 = _to_device_supported_dtype(dpt.float64, dev)
            else:
                ret_dtype2 = st_dtype

            return ret_dtype1, ret_dtype2

        max_dt_num_kind, max_dtype = max(
            [
                (_strong_dtype_num_kind(st_dtype), st_dtype),
                (_strong_dtype_num_kind(dtype2), dtype2),
            ]
        )
        dt1_kind_num = _weak_type_num_kind(dtype1)
        if dt1_kind_num > max_dt_num_kind:
            if isinstance(dtype1, WeakIntegralType):
                return dpt.dtype(ti.default_device_int_type(dev)), dtype2
            if isinstance(dtype1, WeakComplexType):
                if max_dtype is dpt.float16 or max_dtype is dpt.float32:
                    return dpt.complex64, dtype2
                return (
                    _to_device_supported_dtype(dpt.complex128, dev),
                    dtype2,
                )
            return _to_device_supported_dtype(dpt.float64, dev), dtype2
        else:
            return max_dtype, dtype2
    elif _is_weak_dtype(dtype2):
        max_dt_num_kind, max_dtype = max(
            [
                (_strong_dtype_num_kind(st_dtype), st_dtype),
                (_strong_dtype_num_kind(dtype1), dtype1),
            ]
        )
        dt2_kind_num = _weak_type_num_kind(dtype2)
        if dt2_kind_num > max_dt_num_kind:
            if isinstance(dtype2, WeakIntegralType):
                return dtype1, dpt.dtype(ti.default_device_int_type(dev))
            if isinstance(dtype2, WeakComplexType):
                if max_dtype is dpt.float16 or max_dtype is dpt.float32:
                    return dtype1, dpt.complex64
                return (
                    dtype1,
                    _to_device_supported_dtype(dpt.complex128, dev),
                )
            return dtype1, _to_device_supported_dtype(dpt.float64, dev)
        else:
            return dtype1, max_dtype
    else:
        # both are strong dtypes
        # return unmodified
        return dtype1, dtype2
549

550

551
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
    """Resolves one weak data type with one strong data type per NEP-0050.

    If `dtype` is weak and its kind rank exceeds that of `st_dtype`, it
    becomes the device default integer type (`WeakIntegralType`), complex64
    when `st_dtype` is float16/float32 and otherwise complex128 adjusted for
    `dev` (`WeakComplexType`), or float64 adjusted for `dev`
    (`WeakFloatingType`). A weak `dtype` of lower or equal rank resolves to
    `st_dtype`. A strong `dtype` is returned unchanged.

    Returns:
        the resolved data type for `dtype`.

    Raises:
        ValueError: if `st_dtype` is a weak type.
    """
    if _is_weak_dtype(st_dtype):
        raise ValueError("The first argument must be a strong data type")
    if _is_weak_dtype(dtype):
        st_kind_num = _strong_dtype_num_kind(st_dtype)
        kind_num = _weak_type_num_kind(dtype)
        if kind_num > st_kind_num:
            if isinstance(dtype, WeakIntegralType):
                return dpt.dtype(ti.default_device_int_type(dev))
            if isinstance(dtype, WeakComplexType):
                if st_dtype is dpt.float16 or st_dtype is dpt.float32:
                    return dpt.complex64
                return _to_device_supported_dtype(dpt.complex128, dev)
            return _to_device_supported_dtype(dpt.float64, dev)
        else:
            return st_dtype
    else:
        return dtype
570

571

572
class finfo_object:
    """
    Wrapper around a `numpy.finfo` instance (it is not a subclass) that
    returns Python scalars: `int` for `bits`, and `float` for `eps`,
    `epsneg`, `max`, `min`, `smallest_normal`, `tiny`, `resolution` and
    `precision`.
    """

    def __init__(self, dtype):
        # raises ValueError for data types dpctl does not support
        _supported_dtype([dpt.dtype(dtype)])
        self._finfo = np.finfo(dtype)

    @property
    def bits(self):
        """
        number of bits occupied by the real-valued floating-point data type.
        """
        return int(self._finfo.bits)

    @property
    def smallest_normal(self):
        """
        smallest positive real-valued floating-point number with full
        precision.
        """
        return float(self._finfo.smallest_normal)

    @property
    def tiny(self):
        """an alias for `smallest_normal`"""
        return float(self._finfo.tiny)

    @property
    def eps(self):
        """
        difference between 1.0 and the next smallest representable real-valued
        floating-point number larger than 1.0 according to the IEEE-754
        standard.
        """
        return float(self._finfo.eps)

    @property
    def epsneg(self):
        """
        difference between 1.0 and the next smallest representable real-valued
        floating-point number smaller than 1.0 according to the IEEE-754
        standard.
        """
        return float(self._finfo.epsneg)

    @property
    def min(self):
        """smallest representable real-valued number."""
        return float(self._finfo.min)

    @property
    def max(self):
        """largest representable real-valued number."""
        return float(self._finfo.max)

    @property
    def resolution(self):
        """the approximate decimal resolution of this type."""
        return float(self._finfo.resolution)

    @property
    def precision(self):
        """
        the approximate number of decimal digits to which this kind of
        floating point type is precise, returned as a Python float.
        """
        return float(self._finfo.precision)

    @property
    def dtype(self):
        """
        the dtype for which finfo returns information. For complex input, the
        returned dtype is the associated floating point dtype for its real and
        complex components.
        """
        return self._finfo.dtype

    def __str__(self):
        return self._finfo.__str__()

    def __repr__(self):
        return self._finfo.__repr__()
657

658

659
def can_cast(from_, to, /, *, casting="safe") -> bool:
    """can_cast(from, to, casting="safe")

    Determines if one data type can be cast to another data type according
    to Type Promotion Rules.

    Args:
        from_ (Union[usm_ndarray, dtype]):
            source data type. If `from_` is an array, the type promotion
            rules specific to the array's device apply.
        to (dtype):
            target data type.
        casting (Optional[str]):
            controls what kind of data casting may occur.

                * "no" means data types should not be cast at all.
                * "safe" means only casts that preserve values are allowed.
                * "same_kind" means only safe casts and casts within a kind,
                  like `float64` to `float32`, are allowed.
                * "unsafe" means any data conversion can be done.

            Default: `"safe"`.

    Returns:
        bool:
            `True` if the cast can occur according to the casting rule.

    Raises:
        TypeError: if `to` is a `usm_ndarray` instead of a data type.
        ValueError: if `to`, or a non-array `from_`, is a data type dpctl
            does not support.

    Device-specific type promotion rules take into account which data types
    are and are not supported by a specific device. When `from_` is not an
    array, casting is evaluated as if every data type were supported.
    """
    if isinstance(to, dpt.usm_ndarray):
        raise TypeError(f"Expected `dpt.dtype` type, got {type(to)}.")

    target = dpt.dtype(to)
    _supported_dtype([target])

    if not isinstance(from_, dpt.usm_ndarray):
        source = dpt.dtype(from_)
        _supported_dtype([source])
        return _can_cast(source, target, True, True, casting=casting)

    source = from_.dtype
    device = from_.sycl_device
    return _can_cast(
        source,
        target,
        device.has_aspect_fp16,
        device.has_aspect_fp64,
        casting=casting,
    )
709

710

711
def result_type(*arrays_and_dtypes):
    """
    result_type(*arrays_and_dtypes)

    Returns the dtype that results from applying the Type Promotion Rules to \
        the arguments.

    Args:
        arrays_and_dtypes (Union[usm_ndarray, dtype, bool, int, float, complex]):
            An arbitrary length sequence of usm_ndarray objects, dtypes, or
            Python scalars. Python scalars are treated as weak types.

    Returns:
        dtype:
            The dtype resulting from an operation involving the
            input arrays and dtypes.

    Raises:
        ValueError: if input arrays reside on devices with different fp16 or
            fp64 support, if a dtype argument is not supported by dpctl, or
            if a dtype is not supported by the arrays' device.
    """
    dtypes = []
    devices = []
    weak_dtypes = []
    for arg_i in arrays_and_dtypes:
        if isinstance(arg_i, dpt.usm_ndarray):
            devices.append(arg_i.sycl_device)
            dtypes.append(arg_i.dtype)
        # `bool` must be tested before `int`: `bool` is a subclass of `int`,
        # so testing `int` first would classify Python bools as integers.
        elif isinstance(arg_i, bool):
            weak_dtypes.append(WeakBooleanType(arg_i))
        elif isinstance(arg_i, int):
            weak_dtypes.append(WeakIntegralType(arg_i))
        elif isinstance(arg_i, float):
            weak_dtypes.append(WeakFloatingType(arg_i))
        elif isinstance(arg_i, complex):
            weak_dtypes.append(WeakComplexType(arg_i))
        else:
            dt = dpt.dtype(arg_i)
            _supported_dtype([dt])
            dtypes.append(dt)

    has_fp16 = True
    has_fp64 = True
    target_dev = None
    if devices:
        # all devices must agree on fp16/fp64 support; use the first one
        target_dev = devices[0]
        has_fp16 = target_dev.has_aspect_fp16
        has_fp64 = target_dev.has_aspect_fp64
        for d in devices[1:]:
            unsame_fp16_support = d.has_aspect_fp16 != has_fp16
            unsame_fp64_support = d.has_aspect_fp64 != has_fp64
            if unsame_fp16_support or unsame_fp64_support:
                raise ValueError(
                    "Input arrays reside on devices "
                    "with different device supports; "
                    "unable to determine which "
                    "device-specific type promotion rules "
                    "to use."
                )

    if not (has_fp16 and has_fp64):
        for dt in dtypes:
            if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):
                raise ValueError(
                    f"Argument {dt} is not supported by the device"
                )
        res_dt = np.result_type(*dtypes)
        res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
        for wdt in weak_dtypes:
            pair = _resolve_weak_types(wdt, res_dt, target_dev)
            res_dt = np.result_type(*pair)
            res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
    else:
        res_dt = np.result_type(*dtypes)
        if weak_dtypes:
            weak_dt_obj = [wdt.get() for wdt in weak_dtypes]
            res_dt = np.result_type(res_dt, *weak_dt_obj)

    return res_dt
789

790

791
def iinfo(dtype, /):
    """iinfo(dtype)

    Returns machine limits for integer data types.

    Args:
        dtype (dtype, usm_ndarray):
            an integer dtype, or an array whose dtype is an integer dtype.

    Returns:
        numpy.iinfo:
            An object with the following attributes:

            * bits: int
                number of bits occupied by the data type
            * max: int
                largest representable number.
            * min: int
                smallest representable number.
            * dtype: dtype
                integer data type.

    Raises:
        ValueError: if the data type is not supported by dpctl.
    """
    target = dtype.dtype if isinstance(dtype, dpt.usm_ndarray) else dtype
    _supported_dtype([dpt.dtype(target)])
    return np.iinfo(target)
818

819

820
def finfo(dtype, /):
    """finfo(dtype)

    Returns machine limits for floating-point data types.

    Args:
        dtype (dtype, usm_ndarray): floating-point data type, or an
            array whose data type is a floating-point type.
            If complex, the information is about its component
            data type.

    Returns:
        finfo_object:
            an object with the following attributes:

                * bits: int
                    number of bits occupied by dtype.
                * eps: float
                    difference between 1.0 and the next smallest representable
                    real-valued floating-point number larger than 1.0 according
                    to the IEEE-754 standard.
                * max: float
                    largest representable real-valued number.
                * min: float
                    smallest representable real-valued number.
                * smallest_normal: float
                    smallest positive real-valued floating-point number with
                    full precision.
                * dtype: dtype
                    real-valued floating-point data type.

    Raises:
        ValueError:
            if the data type is not one supported by dpctl.
    """
    target = dtype
    if isinstance(target, dpt.usm_ndarray):
        # Arrays are accepted as a convenience: inspect their element type.
        target = target.dtype
    # Reject data types dpctl does not support before building the result.
    _supported_dtype([dpt.dtype(target)])
    return finfo_object(target)
1✔
856

857

858
def _supported_dtype(dtypes):
    """Validate that every data type in ``dtypes`` is supported by dpctl.

    A data type is accepted when its one-character ``char`` code is one of
    the boolean, integer, real floating-point or complex floating-point
    codes listed below.

    Args:
        dtypes: iterable of ``numpy.dtype`` objects.

    Returns:
        bool: ``True`` when all data types are supported.

    Raises:
        ValueError: naming the first unsupported data type encountered.
    """
    supported_codes = "?bBhHiIlLqQefdFD"
    offending = next(
        (dt for dt in dtypes if dt.char not in supported_codes), None
    )
    if offending is not None:
        raise ValueError(f"Dpctl doesn't support dtype {offending}.")
    return True
1✔
863

864

865
def isdtype(dtype, kind):
    """isdtype(dtype, kind)

    Returns a boolean indicating whether a provided `dtype` is
    of a specified data type `kind`.

    Args:
        dtype (dtype): data type to inspect.
        kind (dtype, str, tuple): a data type (compared for equality),
            one of the kind names ``"bool"``, ``"signed integer"``,
            ``"unsigned integer"``, ``"integral"``, ``"real floating"``,
            ``"complex floating"``, ``"numeric"``, or a tuple of such
            specifiers (true if any of them matches).

    Returns:
        bool: whether `dtype` matches `kind`.

    Raises:
        TypeError: if `dtype` is not a data type object, or `kind` is
            of an unsupported type.
        ValueError: if `kind` is an unrecognized kind name.

    See the array API specification
    (https://data-apis.org/array-api/latest/) for more information.
    """
    if not isinstance(dtype, np.dtype):
        raise TypeError(f"Expected instance of `dpt.dtype`, got {dtype}")

    if isinstance(kind, np.dtype):
        return dtype == kind

    if isinstance(kind, tuple):
        # A tuple matches when any of its members matches (recursively).
        return any(isdtype(dtype, k) for k in kind)

    if not isinstance(kind, str):
        raise TypeError(f"Unsupported data type kind: {kind}")

    kind_predicates = {
        "bool": lambda dt: dt == np.dtype("bool"),
        "signed integer": lambda dt: dt.kind == "i",
        "unsigned integer": lambda dt: dt.kind == "u",
        "integral": lambda dt: dt.kind in "iu",
        "real floating": lambda dt: dt.kind == "f",
        "complex floating": lambda dt: dt.kind == "c",
        "numeric": lambda dt: dt.kind in "iufc",
    }
    predicate = kind_predicates.get(kind)
    if predicate is None:
        raise ValueError(f"Unrecognized data type kind: {kind}")
    return predicate(dtype)
1✔
905

906

907
def _default_accumulation_dtype(inp_dt, q):
    """Gives default output data type for given input data
    type `inp_dt` when accumulation is performed on queue `q`.

    * boolean and signed integer inputs use the type reported by
      ``ti.default_device_int_type(q)``, unless `inp_dt` is wider, in
      which case `inp_dt` itself is used;
    * unsigned integer inputs use the type reported by
      ``ti.default_device_uint_type(q)``, unless its range does not cover
      the range of `inp_dt`, in which case `inp_dt` itself is used;
    * real and complex floating-point inputs keep `inp_dt`.

    Raises:
        ValueError: if `inp_dt` is not a boolean, integer, real or
            complex floating-point data type.
    """
    inp_kind = inp_dt.kind
    if inp_kind in "bi":
        res_dt = dpt.dtype(ti.default_device_int_type(q))
        # Never narrow: keep the input type when it is wider.
        if inp_dt.itemsize > res_dt.itemsize:
            res_dt = inp_dt
    elif inp_kind == "u":
        res_dt = dpt.dtype(ti.default_device_uint_type(q))
        res_ii = dpt.iinfo(res_dt)
        inp_ii = dpt.iinfo(inp_dt)
        # Fall back to the input type when the default cannot represent
        # every value of the input type.
        if not (res_ii.min <= inp_ii.min and inp_ii.max <= res_ii.max):
            res_dt = inp_dt
    elif inp_kind in "fc":
        res_dt = inp_dt
    else:
        # Previously this path fell through and failed with an
        # UnboundLocalError on `res_dt`.
        raise ValueError(
            f"Accumulation is not defined for data type {inp_dt}"
        )

    return res_dt
1✔
928

929

930
def _default_accumulation_dtype_fp_types(inp_dt, q):
    """Gives default output data type for given input data
    type `inp_dt` when accumulation is performed on queue `q`
    and the accumulation supports only floating-point data types.

    * boolean and integer inputs use the type reported by
      ``ti.default_device_fp_type(q)``. If ``dpt.can_cast`` reports that
      `inp_dt` cannot be cast to it, float64 is used when the queue's
      device has the fp64 aspect, otherwise float32;
    * real floating-point inputs keep `inp_dt`.

    Raises:
        ValueError: if `inp_dt` is complex, or is not a boolean, integer
            or floating-point data type.
    """
    inp_kind = inp_dt.kind
    if inp_kind in "biu":
        res_dt = dpt.dtype(ti.default_device_fp_type(q))
        if not dpt.can_cast(inp_dt, res_dt):
            _fp64 = q.sycl_device.has_aspect_fp64
            res_dt = dpt.float64 if _fp64 else dpt.float32
    elif inp_kind == "f":
        res_dt = inp_dt
    elif inp_kind == "c":
        raise ValueError("function not defined for complex types")
    else:
        # Previously this path fell through and failed with an
        # UnboundLocalError on `res_dt`.
        raise ValueError(f"function not defined for data type {inp_dt}")

    return res_dt
1✔
948

949

950
# Names exported from this module, assembled from groups in the
# original listing order.
__all__ = (
    # buffer data type resolution helpers
    [
        "_find_buf_dtype",
        "_find_buf_dtype2",
        "_to_device_supported_dtype",
    ]
    # acceptance functions
    + [
        "_acceptance_fn_default_unary",
        "_acceptance_fn_reciprocal",
        "_acceptance_fn_default_binary",
        "_acceptance_fn_divide",
        "_acceptance_fn_negative",
        "_acceptance_fn_subtract",
    ]
    # weak/strong type resolution helpers
    + [
        "_resolve_one_strong_one_weak_types",
        "_resolve_one_strong_two_weak_types",
        "_resolve_weak_types",
        "_resolve_weak_types_all_py_ints",
        "_weak_type_num_kind",
        "_strong_dtype_num_kind",
    ]
    # public data type inspection API
    + [
        "can_cast",
        "finfo",
        "iinfo",
        "isdtype",
        "result_type",
    ]
    # weak type wrappers
    + [
        "WeakBooleanType",
        "WeakIntegralType",
        "WeakFloatingType",
        "WeakComplexType",
    ]
    # accumulation and in-place operation helpers
    + [
        "_default_accumulation_dtype",
        "_default_accumulation_dtype_fp_types",
        "_find_buf_dtype_in_place_op",
    ]
)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc