• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 5617494979

pending completion
5617494979

Pull #1282

github

web-flow
Merge f3d9276ad into e86ba871e
Pull Request #1282: Add python 3.11 support

2276 of 2790 branches covered (81.58%)

Branch coverage included in aggregate %.

8291 of 9908 relevant lines covered (83.68%)

6195.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.54
/dpctl/tensor/_elementwise_common.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2023 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import numbers
1✔
18

19
import numpy as np
1✔
20

21
import dpctl
1✔
22
import dpctl.memory as dpm
1✔
23
import dpctl.tensor as dpt
1✔
24
import dpctl.tensor._tensor_impl as ti
1✔
25
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
1✔
26
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
1✔
27
from dpctl.utils import ExecutionPlacementError
1✔
28

29
from ._type_utils import (
1✔
30
    _empty_like_orderK,
31
    _empty_like_pair_orderK,
32
    _find_buf_dtype,
33
    _find_buf_dtype2,
34
    _find_inplace_dtype,
35
    _to_device_supported_dtype,
36
)
37

38

39
class UnaryElementwiseFunc:
    """
    Class that implements unary element-wise functions.

    Calling an instance as ``func(x, out=None, order="K")`` applies the
    function element-wise to the ``usm_ndarray`` ``x``. When the input data
    type is not handled by the kernel directly, ``x`` is first copied into a
    temporary buffer of the type selected by ``result_type_resolver_fn``.
    """

    def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
        self.__name__ = "UnaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        # launcher called as unary_fn_(src, dst, sycl_queue=..., depends=...)
        # and returning a pair whose first element is waited on
        self.unary_fn_ = unary_dp_impl_fn
        self.__doc__ = docs

    def __str__(self):
        return f"<UnaryElementwiseFunc '{self.name_}'>"

    def __repr__(self):
        return f"<UnaryElementwiseFunc '{self.name_}'>"

    @staticmethod
    def _empty_like_in_order(arr, dt, order):
        """Allocate an uninitialized array like `arr` with dtype `dt`.

        Order "K" mimics the memory layout of `arr`; any other order
        ("C" or "F") is passed to ``dpt.empty_like`` as is.
        """
        if order == "K":
            return _empty_like_orderK(arr, dt)
        return dpt.empty_like(arr, dtype=dt, order=order)

    def __call__(self, x, out=None, order="K"):
        if not isinstance(x, dpt.usm_ndarray):
            raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

        if order not in ["C", "F", "K", "A"]:
            order = "K"
        buf_dt, res_dt = _find_buf_dtype(
            x.dtype, self.result_type_resolver_fn_, x.sycl_device
        )
        if res_dt is None:
            raise TypeError(
                f"function '{self.name_}' does not support input type "
                f"({x.dtype}), "
                "and the input could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        orig_out = out
        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if out.shape != x.shape:
                raise TypeError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {x.shape}, got {out.shape}"
                )

            if res_dt != out.dtype:
                raise TypeError(
                    f"Output array of type {res_dt} is needed,"
                    f" got {out.dtype}"
                )

            # Validate placement before allocating any temporary.
            if (
                dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
                is None
            ):
                raise TypeError(
                    "Input and output allocation queues are not compatible"
                )

            if (
                buf_dt is None
                and ti._array_overlap(x, out)
                and not ti._same_logical_tensors(x, out)
            ):
                # Allocate a temporary buffer to avoid memory overlapping;
                # the result is copied into `orig_out` afterwards.
                # If `buf_dt` is not None, `x` is copied into a fresh buffer
                # first, so the array overlap check isn't needed.
                out = dpt.empty_like(out)

        if order == "A":
            order = "F" if x.flags.f_contiguous else "C"

        exec_q = x.sycl_queue
        if buf_dt is None:
            if out is None:
                out = self._empty_like_in_order(x, res_dt, order)

            ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)

            if orig_out is not None and orig_out is not out:
                # Copy the result from the temporary buffer to `orig_out`.
                ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
                )
                ht_copy_ev.wait()
                out = orig_out

            ht_unary_ev.wait()
            return out

        # Input type is not handled by the kernel: cast `x` into `buf` first.
        buf = self._empty_like_in_order(x, buf_dt, order)
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=buf, sycl_queue=exec_q
        )
        if out is None:
            out = self._empty_like_in_order(buf, res_dt, order)

        ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
        ht_copy_ev.wait()
        ht.wait()

        return out
149

150

151
def _get_queue_usm_type(o):
    """Return the SYCL queue and USM type of the allocation backing `o`.

    Returns a ``(sycl_queue, usm_type)`` pair for a ``usm_ndarray``, or for
    an object exposing ``__sycl_usm_array_interface__`` that can be viewed
    as USM memory. Returns ``(None, None)`` for any other object, including
    objects whose interface could not be interpreted.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    if hasattr(o, "__sycl_usm_array_interface__"):
        try:
            m = dpm.as_usm_memory(o)
            return m.sycl_queue, m.get_usm_type()
        except Exception:
            # Best effort: an unusable interface is reported the same way
            # as a host object (no queue, no USM type).
            return None, None
    return None, None
162

163

164
class WeakBooleanType:
    """Python type representing type of Python boolean objects.

    Wraps a Python ``bool`` operand whose data type is weak, i.e. is still
    to be resolved against the data type of the other operand.
    """

    def __init__(self, o):
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_

    def __repr__(self):
        return f"{type(self).__name__}({self.o_!r})"
172

173

174
class WeakIntegralType:
    """Python type representing type of Python integral objects.

    Wraps a Python ``int`` operand whose data type is weak, i.e. is still
    to be resolved against the data type of the other operand.
    """

    def __init__(self, o):
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_

    def __repr__(self):
        return f"{type(self).__name__}({self.o_!r})"
182

183

184
class WeakFloatingType:
    """Python type representing type of Python floating point objects.

    Wraps a Python ``float`` operand whose data type is weak, i.e. is still
    to be resolved against the data type of the other operand.
    """

    def __init__(self, o):
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_

    def __repr__(self):
        return f"{type(self).__name__}({self.o_!r})"
192

193

194
class WeakComplexType:
    """Python type representing type of Python complex floating point objects.

    Wraps a Python ``complex`` operand whose data type is weak, i.e. is
    still to be resolved against the data type of the other operand.
    """

    def __init__(self, o):
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_

    def __repr__(self):
        return f"{type(self).__name__}({self.o_!r})"
202

203

204
def _get_dtype(o, dev):
    """Infer the data type of operand `o` for computation on device `dev`.

    * ``usm_ndarray``: its own dtype.
    * objects exposing ``__sycl_usm_array_interface__``: the dtype of
      ``dpt.asarray(o)``.
    * buffer-protocol objects and objects with a ``dtype`` attribute: the
      host dtype mapped to one supported by `dev`.
    * Python ``bool``/``int``/``float``/``complex``: a weak-type wrapper
      around the scalar.
    * anything else: ``np.object_``.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        return _to_device_supported_dtype(np.array(o).dtype, dev)
    if hasattr(o, "dtype"):
        return _to_device_supported_dtype(o.dtype, dev)
    # `bool` must be tested before `int` since bool subclasses int.
    weak_wrappers = (
        (bool, WeakBooleanType),
        (int, WeakIntegralType),
        (float, WeakFloatingType),
        (complex, WeakComplexType),
    )
    for py_type, wrapper in weak_wrappers:
        if isinstance(o, py_type):
            return wrapper(o)
    return np.object_
225

226

227
def _validate_dtype(dt) -> bool:
    """Return True if `dt` is a weak scalar type or a supported array dtype.

    Supported array dtypes are boolean, signed/unsigned integers of 8 to
    64 bits, float16/32/64 and complex64/128.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_types):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    supported = (
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    )
    return dt in supported
251

252

253
def _weak_type_num_kind(o):
    """Return the numeric kind rank of weak type `o`.

    Ranks are 0 for boolean, 1 for integral, 2 for floating and 3 for
    complex weak types.

    Raises:
        TypeError: if `o` is not one of the weak type wrappers.
    """
    ranked_weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    for kind_num, weak_cls in enumerate(ranked_weak_types):
        if isinstance(o, weak_cls):
            return kind_num
    raise TypeError(
        f"Unexpected type {o} while expecting "
        "`WeakBooleanType`, `WeakIntegralType`, "
        "`WeakFloatingType`, or `WeakComplexType`."
    )
268

269

270
def _strong_dtype_num_kind(o):
    """Return the numeric kind rank of data type `o`.

    Ranks are 0 for boolean, 1 for signed or unsigned integral, 2 for
    floating and 3 for complex floating data types.

    Raises:
        TypeError: if `o` is not a ``dpt.dtype`` instance.
        ValueError: if the kind of `o` is none of the above.
    """
    _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
    if not isinstance(o, dpt.dtype):
        raise TypeError(f"Expected a data type object, got {type(o)}")
    k = o.kind
    if k in _map:
        return _map[k]
    raise ValueError(f"Unrecognized kind {k} for dtype {o}")
278

279

280
def _resolve_weak_type_against(weak_dt, strong_dt, dev):
    """Return the strong dtype that weak type `weak_dt` takes next to `strong_dt`.

    The weak operand adopts `strong_dt` unless its kind ranks higher; in
    that case it maps to the default dtype of its own kind (int64, float64,
    complex64 next to float16/float32, complex128 otherwise), restricted to
    what device `dev` supports.
    """
    if _weak_type_num_kind(weak_dt) <= _strong_dtype_num_kind(strong_dt):
        return strong_dt
    if isinstance(weak_dt, WeakBooleanType):
        return dpt.bool
    if isinstance(weak_dt, WeakIntegralType):
        return dpt.int64
    if isinstance(weak_dt, WeakComplexType):
        # compare by equality, not identity: equal dtypes need not be the
        # same object
        if strong_dt in (dpt.float16, dpt.float32):
            return dpt.complex64
        return _to_device_supported_dtype(dpt.complex128, dev)
    return _to_device_supported_dtype(dpt.float64, dev)


def _resolve_weak_types(o1_dtype, o2_dtype, dev):
    """Resolves weak data type per NEP-0050.

    Returns the pair ``(o1_dtype, o2_dtype)`` with any weak-type wrapper
    replaced by a strong data type; strong dtypes are returned unchanged.

    Raises:
        ValueError: if both operand types are weak.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    o1_is_weak = isinstance(o1_dtype, weak_types)
    o2_is_weak = isinstance(o2_dtype, weak_types)
    if o1_is_weak and o2_is_weak:
        raise ValueError(
            "Weak data types can only be resolved against a strong data type"
        )
    if o1_is_weak:
        return _resolve_weak_type_against(o1_dtype, o2_dtype, dev), o2_dtype
    if o2_is_weak:
        return o1_dtype, _resolve_weak_type_against(o2_dtype, o1_dtype, dev)
    return o1_dtype, o2_dtype
336

337

338
def _get_shape(o):
    """Return the shape of operand `o`.

    ``usm_ndarray`` and buffer-protocol objects report their own shape,
    Python numbers are treated as 0-d (empty tuple), and any other object
    reports its ``shape`` attribute, defaulting to an empty tuple.
    """
    if isinstance(o, dpt.usm_ndarray):
        shape = o.shape
    elif _is_buffer(o):
        shape = memoryview(o).shape
    elif isinstance(o, numbers.Number):
        shape = ()
    else:
        shape = getattr(o, "shape", ())
    return shape
346

347

348
class BinaryElementwiseFunc:
    """
    Class that implements binary element-wise functions.

    Calling an instance as ``func(o1, o2, out=None, order="K")`` evaluates
    ``o1 <op> o2`` element-wise with broadcasting. At least one operand must
    be a USM allocation (a ``usm_ndarray`` or an object exposing
    ``__sycl_usm_array_interface__``); the other may be a Python scalar or
    a buffer-protocol object. Operands are cast into temporaries when the
    types selected by ``result_type_resolver_fn`` require it.
    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        binary_dp_impl_fn,
        docs,
        binary_inplace_fn=None,
    ):
        self.__name__ = "BinaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        # launcher called as binary_fn_(src1=, src2=, dst=, sycl_queue=,
        # depends=) and returning a pair whose first element is waited on
        self.binary_fn_ = binary_dp_impl_fn
        # optional launcher called as binary_inplace_fn_(lhs=, rhs=,
        # sycl_queue=, depends=); used when the first operand is `out`
        self.binary_inplace_fn_ = binary_inplace_fn
        self.__doc__ = docs

    def __str__(self):
        return f"<BinaryElementwiseFunc '{self.name_}'>"

    def __repr__(self):
        return f"<BinaryElementwiseFunc '{self.name_}'>"

    @staticmethod
    def _cast_to_buffer(src, buf_dt, order, exec_q):
        """Copy `src` into a new array of dtype `buf_dt`.

        Returns ``(buf, host_task_event, copy_event)``; dependent kernels
        must depend on ``copy_event`` and the caller must wait on
        ``host_task_event``.
        """
        if order == "K":
            buf = _empty_like_orderK(src, buf_dt)
        else:
            buf = dpt.empty_like(src, dtype=buf_dt, order=order)
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src, dst=buf, sycl_queue=exec_q
        )
        return buf, ht_copy_ev, copy_ev

    def _call_into_second_operand(self, o1, out, order):
        """Evaluate ``o1 <op> out`` and store the result into `out`.

        `out` is read as the second operand while being written, so the
        result is computed into a temporary and then copied into `out`.
        """
        if not isinstance(out, dpt.usm_ndarray):
            raise TypeError(
                f"output array must be of usm_ndarray type, got {type(out)}"
            )
        res = self.__call__(o1, out, order=order)
        if res.shape != out.shape:
            raise TypeError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {res.shape}, got {out.shape}"
            )
        if res.dtype != out.dtype:
            raise TypeError(
                f"Output array of type {res.dtype} is needed, got {out.dtype}"
            )
        ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=res, dst=out, sycl_queue=res.sycl_queue
        )
        ht_copy_ev.wait()
        return out

    def __call__(self, o1, o2, out=None, order="K"):
        # FIXME: replace with check against base array
        # when views can be identified
        if out is not None:
            if o1 is out:
                return self._inplace(o1, o2)
            if o2 is out:
                # `self._inplace(o2, o1)` would compute `o2 <op> o1`, which
                # is wrong for non-commutative functions (e.g. subtract).
                return self._call_into_second_operand(o1, o2, order)

        if order not in ["K", "C", "F", "A"]:
            order = "K"
        q1, o1_usm_type = _get_queue_usm_type(o1)
        q2, o2_usm_type = _get_queue_usm_type(o2)
        if q1 is None and q2 is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments. "
                "One of the arguments must represent USM allocation and "
                "expose `__sycl_usm_array_interface__` property"
            )
        if q1 is None:
            exec_q = q2
            res_usm_type = o2_usm_type
        elif q2 is None:
            exec_q = q1
            res_usm_type = o1_usm_type
        else:
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
            if exec_q is None:
                raise ExecutionPlacementError(
                    "Execution placement can not be unambiguously inferred "
                    "from input arguments."
                )
            res_usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    o1_usm_type,
                    o2_usm_type,
                )
            )
        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
        o1_shape = _get_shape(o1)
        o2_shape = _get_shape(o2)
        if not all(
            isinstance(s, (tuple, list))
            for s in (
                o1_shape,
                o2_shape,
            )
        ):
            raise TypeError(
                "Shape of arguments can not be inferred. "
                "Arguments are expected to be "
                "lists, tuples, or both"
            )
        try:
            res_shape = _broadcast_shape_impl(
                [
                    o1_shape,
                    o2_shape,
                ]
            )
        except ValueError:
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{o1_shape} and {o2_shape}"
            )
        sycl_dev = exec_q.sycl_device
        o1_dtype = _get_dtype(o1, sycl_dev)
        o2_dtype = _get_dtype(o2, sycl_dev)
        if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
            raise ValueError("Operands of unsupported types")

        o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)

        buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
            o1_dtype, o2_dtype, self.result_type_resolver_fn_, sycl_dev
        )

        if res_dt is None:
            raise TypeError(
                f"function '{self.name_}' does not support input types "
                f"({o1_dtype}, {o2_dtype}), "
                "and the inputs could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        if out is not None:
            # Validate `out` fully before any kernel is submitted.
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if out.shape != res_shape:
                raise TypeError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {res_shape}, got {out.shape}"
                )

            # `exec_q` is already compatible with every USM operand; scalar
            # operands have no `sycl_queue`, so check against `exec_q`.
            if (
                dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
                is None
            ):
                raise TypeError(
                    "Input and output allocation queues are not compatible"
                )

            if res_dt != out.dtype:
                raise TypeError(
                    f"Output array of type {res_dt} is needed, got {out.dtype}"
                )

        if isinstance(o1, dpt.usm_ndarray):
            src1 = o1
        else:
            src1 = dpt.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q)
        if isinstance(o2, dpt.usm_ndarray):
            src2 = o2
        else:
            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

        # Checked on the converted operands: only USM arrays can alias `out`.
        if out is not None and (
            ti._array_overlap(src1, out) or ti._array_overlap(src2, out)
        ):
            raise TypeError("Input and output arrays have memory overlap")

        if order == "A":
            order = (
                "F"
                if src1.flags.f_contiguous and src2.flags.f_contiguous
                else "C"
            )
        elif order == "K" and buf1_dt is not None and buf2_dt is not None:
            # Both operands are copied anyway: when they share a contiguous
            # layout, allocate the temporaries and the result in that layout.
            if src1.flags.f_contiguous and src2.flags.f_contiguous:
                order = "F"
            elif src1.flags.c_contiguous and src2.flags.c_contiguous:
                order = "C"

        host_tasks = []
        copy_events = []
        if buf1_dt is not None:
            src1, ht_copy_ev, copy_ev = self._cast_to_buffer(
                src1, buf1_dt, order, exec_q
            )
            host_tasks.append(ht_copy_ev)
            copy_events.append(copy_ev)
        if buf2_dt is not None:
            src2, ht_copy_ev, copy_ev = self._cast_to_buffer(
                src2, buf2_dt, order, exec_q
            )
            host_tasks.append(ht_copy_ev)
            copy_events.append(copy_ev)

        if out is None:
            if order == "K":
                out = _empty_like_pair_orderK(
                    src1, src2, res_dt, res_shape, res_usm_type, exec_q
                )
            else:
                out = dpt.empty(
                    res_shape,
                    dtype=res_dt,
                    usm_type=res_usm_type,
                    sycl_queue=exec_q,
                    order=order,
                )

        ht_, _ = self.binary_fn_(
            src1=dpt.broadcast_to(src1, res_shape),
            src2=dpt.broadcast_to(src2, res_shape),
            dst=out,
            sycl_queue=exec_q,
            depends=copy_events,
        )
        host_tasks.append(ht_)
        dpctl.SyclEvent.wait_for(host_tasks)
        return out

    def _inplace(self, lhs, val):
        """Compute ``lhs <op> val`` into `lhs` and return `lhs`.

        Raises:
            ValueError: if no in-place launcher was provided, if the
                broadcast shape differs from the shape of `lhs`, or if
                `val` is of an unsupported type.
            TypeError: if `lhs` is not a ``usm_ndarray`` or the types can
                not be safely combined in place.
            ExecutionPlacementError: if `lhs` and `val` live on
                incompatible queues.
        """
        if self.binary_inplace_fn_ is None:
            raise ValueError(
                f"In-place operation not supported for ufunc '{self.name_}'"
            )
        if not isinstance(lhs, dpt.usm_ndarray):
            raise TypeError(
                f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
            )
        q1, lhs_usm_type = _get_queue_usm_type(lhs)
        q2, val_usm_type = _get_queue_usm_type(val)
        if q2 is None:
            exec_q = q1
            usm_type = lhs_usm_type
        else:
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
            if exec_q is None:
                raise ExecutionPlacementError(
                    "Execution placement can not be unambiguously inferred "
                    "from input arguments."
                )
            usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    lhs_usm_type,
                    val_usm_type,
                )
            )
        dpctl.utils.validate_usm_type(usm_type, allow_none=False)
        lhs_shape = _get_shape(lhs)
        val_shape = _get_shape(val)
        if not all(
            isinstance(s, (tuple, list))
            for s in (
                lhs_shape,
                val_shape,
            )
        ):
            raise TypeError(
                "Shape of arguments can not be inferred. "
                "Arguments are expected to be "
                "lists, tuples, or both"
            )
        try:
            res_shape = _broadcast_shape_impl(
                [
                    lhs_shape,
                    val_shape,
                ]
            )
        except ValueError:
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{lhs_shape} and {val_shape}"
            )
        if res_shape != lhs_shape:
            raise ValueError(
                f"output shape {lhs_shape} does not match "
                f"broadcast shape {res_shape}"
            )
        sycl_dev = exec_q.sycl_device
        lhs_dtype = lhs.dtype
        val_dtype = _get_dtype(val, sycl_dev)
        if not _validate_dtype(val_dtype):
            raise ValueError("Input operand of unsupported type")

        lhs_dtype, val_dtype = _resolve_weak_types(
            lhs_dtype, val_dtype, sycl_dev
        )

        buf_dt = _find_inplace_dtype(
            lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
        )

        if buf_dt is None:
            raise TypeError(
                f"In-place '{self.name_}' does not support input types "
                f"({lhs_dtype}, {val_dtype}), "
                "and the inputs could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        if isinstance(val, dpt.usm_ndarray):
            rhs = val
            overlap = ti._array_overlap(lhs, rhs)
        else:
            rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
            overlap = False

        if buf_dt == val_dtype and not overlap:
            rhs = dpt.broadcast_to(rhs, res_shape)
            ht_, _ = self.binary_inplace_fn_(
                lhs=lhs, rhs=rhs, sycl_queue=exec_q
            )
            ht_.wait()

        else:
            # Cast `rhs` (or break its aliasing with `lhs`) via a copy.
            buf = dpt.empty_like(rhs, dtype=buf_dt)
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=rhs, dst=buf, sycl_queue=exec_q
            )

            buf = dpt.broadcast_to(buf, res_shape)
            ht_, _ = self.binary_inplace_fn_(
                lhs=lhs,
                rhs=buf,
                sycl_queue=exec_q,
                depends=[copy_ev],
            )
            ht_copy_ev.wait()
            ht_.wait()

        return lhs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc