• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 5579681118

pending completion
5579681118

push

github

web-flow
Merge pull request #1281 from IntelPython/unary_out_overlap

Created a temporary copy in case of overlap for unary function

2269 of 2783 branches covered (81.53%)

Branch coverage included in aggregate %.

15 of 16 new or added lines in 1 file covered. (93.75%)

34 existing lines in 1 file now uncovered.

8281 of 9898 relevant lines covered (83.66%)

5828.35 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.54
/dpctl/tensor/_elementwise_common.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2023 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import numbers
1✔
18

19
import numpy as np
1✔
20

21
import dpctl
1✔
22
import dpctl.memory as dpm
1✔
23
import dpctl.tensor as dpt
1✔
24
import dpctl.tensor._tensor_impl as ti
1✔
25
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
1✔
26
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
1✔
27
from dpctl.utils import ExecutionPlacementError
1✔
28

29
from ._type_utils import (
1✔
30
    _empty_like_orderK,
31
    _empty_like_pair_orderK,
32
    _find_buf_dtype,
33
    _find_buf_dtype2,
34
    _find_inplace_dtype,
35
    _to_device_supported_dtype,
36
)
37

38

39
class UnaryElementwiseFunc:
    """
    Class that implements unary element-wise functions.

    Args:
        name: Name of the function; used in string representations and
            error messages.
        result_type_resolver_fn: Callable forwarded to `_find_buf_dtype`
            together with the input dtype and device in order to determine
            the buffer and result data types.
        unary_dp_impl_fn: Callable invoked as
            ``fn(src, dst, sycl_queue=..., depends=...)``; it must return
            a pair of events, the first of which is waited on by the host.
        docs: Text installed as ``__doc__`` of the instance.
    """

    def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
        self.__name__ = "UnaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        self.unary_fn_ = unary_dp_impl_fn
        self.__doc__ = docs

    def __str__(self):
        return f"<UnaryElementwiseFunc '{self.name_}'>"

    def __repr__(self):
        return f"<UnaryElementwiseFunc '{self.name_}'>"

    def __call__(self, x, out=None, order="K"):
        """Apply the function to `x` and return the result array.

        Args:
            x: Input `dpctl.tensor.usm_ndarray`.
            out: Optional `usm_ndarray` to write the result into. It must
                have the shape of `x` and the resolved result data type.
            order: One of "C", "F", "K", "A"; any other value is treated
                as "K". Only used when the output has to be allocated.

        Returns:
            `out` when it was provided, otherwise a newly allocated array.

        Raises:
            TypeError: if `x` or `out` is not a `usm_ndarray`, if `out` has
                an unexpected shape or data type, or if the allocation
                queues of `x` and `out` are not compatible.
            RuntimeError: if no result data type could be resolved.
        """
        if not isinstance(x, dpt.usm_ndarray):
            raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

        if order not in ["C", "F", "K", "A"]:
            order = "K"
        buf_dt, res_dt = _find_buf_dtype(
            x.dtype, self.result_type_resolver_fn_, x.sycl_device
        )
        if res_dt is None:
            raise RuntimeError(
                f"function '{self.name_}' could not resolve a result "
                f"data type for input of type {x.dtype}"
            )

        # Remember the array supplied by the caller: `out` may be replaced
        # by a temporary below if it overlaps with the input.
        orig_out = out
        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if out.shape != x.shape:
                raise TypeError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {x.shape}, got {out.shape}"
                )

            if res_dt != out.dtype:
                raise TypeError(
                    f"Output array of type {res_dt} is needed,"
                    f" got {out.dtype}"
                )

            if (
                buf_dt is None
                and ti._array_overlap(x, out)
                and not ti._same_logical_tensors(x, out)
            ):
                # Allocate a temporary buffer to avoid memory overlapping.
                # Note if `buf_dt` is not None, a temporary copy of `x` will be
                # created, so the array overlap check isn't needed.
                out = dpt.empty_like(out)

            if (
                dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
                is None
            ):
                raise TypeError(
                    "Input and output allocation queues are not compatible"
                )

        exec_q = x.sycl_queue
        if buf_dt is None:
            # The input can be consumed directly, no type-casting copy needed.
            if out is None:
                if order == "K":
                    out = _empty_like_orderK(x, res_dt)
                else:
                    if order == "A":
                        order = "F" if x.flags.f_contiguous else "C"
                    out = dpt.empty_like(x, dtype=res_dt, order=order)

            ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)

            if not (orig_out is None or orig_out is out):
                # Copy the out data from temporary buffer to original memory
                ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
                )
                ht_copy_ev.wait()
                out = orig_out

            ht_unary_ev.wait()
            return out

        # The input has to be cast to `buf_dt` first: copy it into a buffer
        # and apply the function to that buffer.
        if order == "K":
            buf = _empty_like_orderK(x, buf_dt)
        else:
            if order == "A":
                order = "F" if x.flags.f_contiguous else "C"
            buf = dpt.empty_like(x, dtype=buf_dt, order=order)

        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=buf, sycl_queue=exec_q
        )
        if out is None:
            if order == "K":
                out = _empty_like_orderK(buf, res_dt)
            else:
                out = dpt.empty_like(buf, dtype=res_dt, order=order)

        ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
        ht_copy_ev.wait()
        ht.wait()

        return out
144

145

146
def _get_queue_usm_type(o):
    """Return the ``(sycl_queue, usm_type)`` pair describing `o`.

    For a `usm_ndarray` the pair is read from the array itself. For any
    other object exposing ``__sycl_usm_array_interface__`` it is obtained
    through `dpctl.memory.as_usm_memory`. ``(None, None)`` is returned for
    every other object, and also when the conversion to USM memory fails.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    elif hasattr(o, "__sycl_usm_array_interface__"):
        try:
            m = dpm.as_usm_memory(o)
            return m.sycl_queue, m.get_usm_type()
        except Exception:
            # Deliberately best-effort: an object with an unusable
            # interface is treated as not being a USM allocation.
            return None, None
    return None, None
157

158

159
class WeakBooleanType:
    """Python type representing type of Python boolean objects"""

    def __init__(self, o):
        # Keep the wrapped Python object so it can be retrieved later.
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_
167

168

169
class WeakIntegralType:
    """Python type representing type of Python integral objects"""

    def __init__(self, o):
        # Keep the wrapped Python object so it can be retrieved later.
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_
177

178

179
class WeakFloatingType:
    """Python type representing type of Python floating point objects"""

    def __init__(self, o):
        # Keep the wrapped Python object so it can be retrieved later.
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_
187

188

189
class WeakComplexType:
    """Python type representing type of Python complex floating point objects"""

    def __init__(self, o):
        # Keep the wrapped Python object so it can be retrieved later.
        self.o_ = o

    def get(self):
        """Return the wrapped Python object."""
        return self.o_
197

198

199
def _get_dtype(o, dev):
    """Return the data type to associate with operand `o` on device `dev`.

    The checks are applied in this order:

    * `usm_ndarray`: its own dtype;
    * object exposing ``__sycl_usm_array_interface__``: dtype of
      ``dpt.asarray(o)``;
    * object supporting the buffer protocol: dtype of ``np.array(o)`` passed
      through `_to_device_supported_dtype`;
    * object with a ``dtype`` attribute: that dtype passed through
      `_to_device_supported_dtype`;
    * Python `bool`, `int`, `float`, `complex`: an instance of the
      corresponding ``Weak*Type`` wrapper around `o`;
    * anything else: ``np.object_``.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        host_dt = np.array(o).dtype
        dev_dt = _to_device_supported_dtype(host_dt, dev)
        return dev_dt
    if hasattr(o, "dtype"):
        dev_dt = _to_device_supported_dtype(o.dtype, dev)
        return dev_dt
    # `bool` is a subclass of `int`, hence it must be tested first.
    if isinstance(o, bool):
        return WeakBooleanType(o)
    if isinstance(o, int):
        return WeakIntegralType(o)
    if isinstance(o, float):
        return WeakFloatingType(o)
    if isinstance(o, complex):
        return WeakComplexType(o)
    return np.object_
220

221

222
def _validate_dtype(dt) -> bool:
    """Return True if `dt` is a data type usable by element-wise functions.

    Accepted are instances of the ``Weak*Type`` wrappers and instances of
    `dpt.dtype` equal to one of the boolean, integral, floating point or
    complex floating point data types listed below.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_types):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    return dt in [
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    ]
246

247

248
def _weak_type_num_kind(o):
    """Return the numeric rank of the kind of weak type instance `o`.

    The rank is 0 for boolean, 1 for integral, 2 for floating point and
    3 for complex floating point weak types.

    Raises:
        TypeError: if `o` is not an instance of a ``Weak*Type`` class.
    """
    _map = {"?": 0, "i": 1, "f": 2, "c": 3}
    if isinstance(o, WeakBooleanType):
        return _map["?"]
    if isinstance(o, WeakIntegralType):
        return _map["i"]
    if isinstance(o, WeakFloatingType):
        return _map["f"]
    if isinstance(o, WeakComplexType):
        return _map["c"]
    raise TypeError(
        f"Unexpected type {o} while expecting "
        "`WeakBooleanType`, `WeakIntegralType`, "
        "`WeakFloatingType`, or `WeakComplexType`."
    )
263

264

265
def _strong_dtype_num_kind(o):
    """Return the numeric rank of the kind of data type `o`.

    The rank is 0 for boolean, 1 for signed and unsigned integral, 2 for
    floating point and 3 for complex floating point data types, which makes
    it comparable with the value returned by `_weak_type_num_kind`.

    Raises:
        TypeError: if `o` is not an instance of `dpt.dtype`.
        ValueError: if the kind of `o` is not one of the above.
    """
    _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
    if not isinstance(o, dpt.dtype):
        raise TypeError(f"Expected dpctl.tensor.dtype, got {type(o)}")
    k = o.kind
    if k in _map:
        return _map[k]
    raise ValueError(f"Unrecognized kind {k} for dtype {o}")
273

274

275
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
    """Resolves weak data type per NEP-0050

    At most one of `o1_dtype` and `o2_dtype` may be an instance of a
    ``Weak*Type`` class. The weak type is replaced with a concrete data
    type as follows:

    * if the kind of the weak type does not rank higher than the kind of
      the other (strong) data type, the strong data type is used for both;
    * otherwise a weak boolean becomes ``dpt.bool``, a weak integer becomes
      ``dpt.int64``, a weak float becomes the device-supported counterpart
      of ``dpt.float64`` and a weak complex becomes ``dpt.complex64`` when
      the strong type is ``float16`` or ``float32`` and the device-supported
      counterpart of ``dpt.complex128`` otherwise.

    When neither argument is weak, both are returned unchanged.

    Returns:
        Pair of data types ``(o1_dtype, o2_dtype)`` after resolution.

    Raises:
        ValueError: if both data types are weak.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    # Strong floating point types for which a weak complex resolves
    # to single precision complex type.
    small_floats = (dpt.float16, dpt.float32)

    if isinstance(o1_dtype, weak_types):
        if isinstance(o2_dtype, weak_types):
            raise ValueError(
                "Weak types can not be resolved: "
                "at least one data type must be a dpctl.tensor.dtype"
            )
        o1_kind_num = _weak_type_num_kind(o1_dtype)
        o2_kind_num = _strong_dtype_num_kind(o2_dtype)
        if o1_kind_num > o2_kind_num:
            if isinstance(o1_dtype, WeakBooleanType):
                return dpt.bool, o2_dtype
            if isinstance(o1_dtype, WeakIntegralType):
                return dpt.int64, o2_dtype
            if isinstance(o1_dtype, WeakComplexType):
                if o2_dtype in small_floats:
                    return dpt.complex64, o2_dtype
                return (
                    _to_device_supported_dtype(dpt.complex128, dev),
                    o2_dtype,
                )
            return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
        else:
            return o2_dtype, o2_dtype
    elif isinstance(o2_dtype, weak_types):
        o1_kind_num = _strong_dtype_num_kind(o1_dtype)
        o2_kind_num = _weak_type_num_kind(o2_dtype)
        if o2_kind_num > o1_kind_num:
            if isinstance(o2_dtype, WeakBooleanType):
                return o1_dtype, dpt.bool
            if isinstance(o2_dtype, WeakIntegralType):
                return o1_dtype, dpt.int64
            if isinstance(o2_dtype, WeakComplexType):
                if o1_dtype in small_floats:
                    return o1_dtype, dpt.complex64
                return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
            return (
                o1_dtype,
                _to_device_supported_dtype(dpt.float64, dev),
            )
        else:
            return o1_dtype, o1_dtype
    else:
        return o1_dtype, o2_dtype
331

332

333
def _get_shape(o):
    """Return the shape of operand `o`.

    The shape is taken from the array for a `usm_ndarray`, from a
    `memoryview` of `o` for objects supporting the buffer protocol, and is
    the empty tuple for Python numbers. For any other object the value of
    its ``shape`` attribute is returned, or the empty tuple if it has none.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.shape
    if _is_buffer(o):
        return memoryview(o).shape
    if isinstance(o, numbers.Number):
        return tuple()
    return getattr(o, "shape", tuple())
341

342

343
class BinaryElementwiseFunc:
    """
    Class that implements binary element-wise functions.

    Args:
        name: Name of the function; used in string representations and
            error messages.
        result_type_resolver_fn: Callable forwarded to `_find_buf_dtype2`
            and `_find_inplace_dtype` in order to determine the buffer and
            result data types.
        binary_dp_impl_fn: Callable invoked as
            ``fn(src1=..., src2=..., dst=..., sycl_queue=..., depends=...)``;
            it must return a pair of events, the first of which is waited
            on by the host.
        docs: Text installed as ``__doc__`` of the instance.
        binary_inplace_fn: Optional callable invoked as
            ``fn(lhs=..., rhs=..., sycl_queue=..., depends=...)`` to update
            `lhs` in place. When it is None, in-place evaluation is not
            supported.
    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        binary_dp_impl_fn,
        docs,
        binary_inplace_fn=None,
    ):
        self.__name__ = "BinaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        self.binary_fn_ = binary_dp_impl_fn
        self.binary_inplace_fn_ = binary_inplace_fn
        self.__doc__ = docs

    def __str__(self):
        return f"<BinaryElementwiseFunc '{self.name_}'>"

    def __repr__(self):
        return f"<BinaryElementwiseFunc '{self.name_}'>"

    @staticmethod
    def _allocate_out(
        src1, src2, order, res_dt, res_shape, res_usm_type, exec_q
    ):
        """Allocate the result array for operands `src1` and `src2`."""
        if order == "K":
            return _empty_like_pair_orderK(
                src1, src2, res_dt, res_shape, res_usm_type, exec_q
            )
        return dpt.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=exec_q,
            order=order,
        )

    def _call_with_out_as_second_operand(self, o1, out, order):
        """Evaluate the function of ``(o1, out)`` and store result in `out`.

        The result is computed into a newly allocated array and copied into
        `out` afterwards, so that the order of the operands is preserved
        even though the second operand is overwritten.
        """
        if not isinstance(out, dpt.usm_ndarray):
            raise TypeError(
                f"output array must be of usm_ndarray type, got {type(out)}"
            )
        res = self(o1, out, order=order)
        if res.shape != out.shape:
            raise ValueError(
                f"output shape {out.shape} does not match "
                f"broadcast shape {res.shape}"
            )
        if res.dtype != out.dtype:
            raise TypeError(
                f"Output array of type {res.dtype} is needed, got {out.dtype}"
            )
        ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=res, dst=out, sycl_queue=res.sycl_queue
        )
        ht_copy_ev.wait()
        return out

    def __call__(self, o1, o2, out=None, order="K"):
        """Apply the function to `o1` and `o2` and return the result array.

        Args:
            o1: First operand.
            o2: Second operand. At least one of the operands must represent
                a USM allocation.
            out: Optional `usm_ndarray` to write the result into. It must
                have the broadcast shape of the operands and the resolved
                result data type.
            order: One of "K", "C", "F", "A"; any other value is treated
                as "K". Only used when arrays have to be allocated.

        Returns:
            `out` when it was provided, otherwise a newly allocated array.

        Raises:
            ExecutionPlacementError: if the execution queue can not be
                inferred from the operands.
            TypeError: if shapes of the operands can not be inferred, if
                the data types are not supported by the function, or if
                `out` is not a suitable output array.
            ValueError: if the operands can not be broadcast together or
                are of unsupported types.
        """
        # FIXME: replace with check against base array
        # when views can be identified
        if o1 is out:
            return self._inplace(o1, o2)
        elif o2 is out and out is not None:
            # Do not delegate to `self._inplace(o2, o1)`: that would swap
            # the operands, which is wrong for non-commutative functions.
            return self._call_with_out_as_second_operand(o1, out, order)

        if order not in ["K", "C", "F", "A"]:
            order = "K"
        q1, o1_usm_type = _get_queue_usm_type(o1)
        q2, o2_usm_type = _get_queue_usm_type(o2)
        if q1 is None and q2 is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments. "
                "One of the arguments must represent USM allocation and "
                "expose `__sycl_usm_array_interface__` property"
            )
        if q1 is None:
            exec_q = q2
            res_usm_type = o2_usm_type
        elif q2 is None:
            exec_q = q1
            res_usm_type = o1_usm_type
        else:
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
            if exec_q is None:
                raise ExecutionPlacementError(
                    "Execution placement can not be unambiguously inferred "
                    "from input arguments."
                )
            res_usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    o1_usm_type,
                    o2_usm_type,
                )
            )
        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
        o1_shape = _get_shape(o1)
        o2_shape = _get_shape(o2)
        if not all(
            isinstance(s, (tuple, list))
            for s in (
                o1_shape,
                o2_shape,
            )
        ):
            raise TypeError(
                "Shape of arguments can not be inferred. "
                "Arguments are expected to be "
                "lists, tuples, or both"
            )
        try:
            res_shape = _broadcast_shape_impl(
                [
                    o1_shape,
                    o2_shape,
                ]
            )
        except ValueError:
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{o1_shape} and {o2_shape}"
            )
        sycl_dev = exec_q.sycl_device
        o1_dtype = _get_dtype(o1, sycl_dev)
        o2_dtype = _get_dtype(o2, sycl_dev)
        if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
            raise ValueError("Operands of unsupported types")

        o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)

        buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
            o1_dtype, o2_dtype, self.result_type_resolver_fn_, sycl_dev
        )

        if res_dt is None:
            raise TypeError(
                f"function '{self.name_}' does not support input types "
                f"({o1_dtype}, {o2_dtype}), "
                "and the inputs could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if out.shape != res_shape:
                raise TypeError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {res_shape}, got {out.shape}"
                )

            # Operands which are not usm_ndarray (e.g. Python scalars) are
            # converted into new arrays below and can not overlap with `out`.
            if any(
                isinstance(o, dpt.usm_ndarray) and ti._array_overlap(o, out)
                for o in (o1, o2)
            ):
                raise TypeError("Input and output arrays have memory overlap")

            # `exec_q` was derived from the queues of the operands, hence
            # it stands for them in the compatibility check.
            if (
                dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
                is None
            ):
                raise TypeError(
                    "Input and output allocation queues are not compatible"
                )

            # Validate the data type before any task is submitted, so that
            # no submitted task is left without being waited on.
            if res_dt != out.dtype:
                raise TypeError(
                    f"Output array of type {res_dt} is needed, got {out.dtype}"
                )

        if isinstance(o1, dpt.usm_ndarray):
            src1 = o1
        else:
            src1 = dpt.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q)
        if isinstance(o2, dpt.usm_ndarray):
            src2 = o2
        else:
            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

        if buf1_dt is None and buf2_dt is None:
            # Neither operand requires casting.
            if out is None:
                if order == "A":
                    order = (
                        "F"
                        if all(
                            arr.flags.f_contiguous
                            for arr in (
                                src1,
                                src2,
                            )
                        )
                        else "C"
                    )
                out = self._allocate_out(
                    src1, src2, order, res_dt, res_shape, res_usm_type, exec_q
                )

            src1 = dpt.broadcast_to(src1, res_shape)
            src2 = dpt.broadcast_to(src2, res_shape)
            ht_, _ = self.binary_fn_(
                src1=src1, src2=src2, dst=out, sycl_queue=exec_q
            )
            ht_.wait()
            return out
        elif buf1_dt is None:
            # Only the second operand requires casting.
            if order == "K":
                buf2 = _empty_like_orderK(src2, buf2_dt)
            else:
                if order == "A":
                    order = "F" if src1.flags.f_contiguous else "C"
                buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=src2, dst=buf2, sycl_queue=exec_q
            )
            if out is None:
                out = self._allocate_out(
                    src1, buf2, order, res_dt, res_shape, res_usm_type, exec_q
                )

            src1 = dpt.broadcast_to(src1, res_shape)
            buf2 = dpt.broadcast_to(buf2, res_shape)
            ht_, _ = self.binary_fn_(
                src1=src1,
                src2=buf2,
                dst=out,
                sycl_queue=exec_q,
                depends=[copy_ev],
            )
            ht_copy_ev.wait()
            ht_.wait()
            return out
        elif buf2_dt is None:
            # Only the first operand requires casting.
            if order == "K":
                buf1 = _empty_like_orderK(src1, buf1_dt)
            else:
                if order == "A":
                    order = "F" if src1.flags.f_contiguous else "C"
                buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=src1, dst=buf1, sycl_queue=exec_q
            )
            if out is None:
                out = self._allocate_out(
                    buf1, src2, order, res_dt, res_shape, res_usm_type, exec_q
                )

            buf1 = dpt.broadcast_to(buf1, res_shape)
            src2 = dpt.broadcast_to(src2, res_shape)
            ht_, _ = self.binary_fn_(
                src1=buf1,
                src2=src2,
                dst=out,
                sycl_queue=exec_q,
                depends=[copy_ev],
            )
            ht_copy_ev.wait()
            ht_.wait()
            return out

        # Both operands require casting.
        if order in ["K", "A"]:
            if src1.flags.f_contiguous and src2.flags.f_contiguous:
                order = "F"
            elif src1.flags.c_contiguous and src2.flags.c_contiguous:
                order = "C"
            else:
                order = "C" if order == "A" else "K"
        if order == "K":
            buf1 = _empty_like_orderK(src1, buf1_dt)
        else:
            buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
        ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src1, dst=buf1, sycl_queue=exec_q
        )
        if order == "K":
            buf2 = _empty_like_orderK(src2, buf2_dt)
        else:
            buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
        ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src2, dst=buf2, sycl_queue=exec_q
        )
        if out is None:
            out = self._allocate_out(
                buf1, buf2, order, res_dt, res_shape, res_usm_type, exec_q
            )

        buf1 = dpt.broadcast_to(buf1, res_shape)
        buf2 = dpt.broadcast_to(buf2, res_shape)
        ht_, _ = self.binary_fn_(
            src1=buf1,
            src2=buf2,
            dst=out,
            sycl_queue=exec_q,
            depends=[copy1_ev, copy2_ev],
        )
        dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
        return out

    def _inplace(self, lhs, val):
        """Update `lhs` in place with the function of ``(lhs, val)``.

        Args:
            lhs: `usm_ndarray` which is both the first operand and the
                destination of the result.
            val: Second operand.

        Returns:
            `lhs`.

        Raises:
            ValueError: if no in-place implementation was provided, if the
                operands can not be broadcast to the shape of `lhs`, or if
                `val` is of unsupported type.
            TypeError: if `lhs` is not a `usm_ndarray`, if shapes can not
                be inferred, or if the data types are not supported.
            ExecutionPlacementError: if the execution queue can not be
                inferred from the operands.
        """
        if self.binary_inplace_fn_ is None:
            raise ValueError(
                f"In-place operation not supported for ufunc '{self.name_}'"
            )
        if not isinstance(lhs, dpt.usm_ndarray):
            raise TypeError(
                f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
            )
        q1, lhs_usm_type = _get_queue_usm_type(lhs)
        q2, val_usm_type = _get_queue_usm_type(val)
        if q2 is None:
            exec_q = q1
            usm_type = lhs_usm_type
        else:
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
            if exec_q is None:
                raise ExecutionPlacementError(
                    "Execution placement can not be unambiguously inferred "
                    "from input arguments."
                )
            usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    lhs_usm_type,
                    val_usm_type,
                )
            )
        dpctl.utils.validate_usm_type(usm_type, allow_none=False)
        lhs_shape = _get_shape(lhs)
        val_shape = _get_shape(val)
        if not all(
            isinstance(s, (tuple, list))
            for s in (
                lhs_shape,
                val_shape,
            )
        ):
            raise TypeError(
                "Shape of arguments can not be inferred. "
                "Arguments are expected to be "
                "lists, tuples, or both"
            )
        try:
            res_shape = _broadcast_shape_impl(
                [
                    lhs_shape,
                    val_shape,
                ]
            )
        except ValueError:
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{lhs_shape} and {val_shape}"
            )
        # The result is written into `lhs`, so broadcasting must not
        # change its shape.
        if res_shape != lhs_shape:
            raise ValueError(
                f"output shape {lhs_shape} does not match "
                f"broadcast shape {res_shape}"
            )
        sycl_dev = exec_q.sycl_device
        lhs_dtype = lhs.dtype
        val_dtype = _get_dtype(val, sycl_dev)
        if not _validate_dtype(val_dtype):
            raise ValueError("Input operand of unsupported type")

        lhs_dtype, val_dtype = _resolve_weak_types(
            lhs_dtype, val_dtype, sycl_dev
        )

        buf_dt = _find_inplace_dtype(
            lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
        )

        if buf_dt is None:
            raise TypeError(
                f"In-place '{self.name_}' does not support input types "
                f"({lhs_dtype}, {val_dtype}), "
                "and the inputs could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        if isinstance(val, dpt.usm_ndarray):
            rhs = val
            overlap = ti._array_overlap(lhs, rhs)
        else:
            # A newly created array can not overlap with `lhs`.
            rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
            overlap = False

        if buf_dt == val_dtype and overlap is False:
            # `rhs` can be consumed directly.
            rhs = dpt.broadcast_to(rhs, res_shape)
            ht_, _ = self.binary_inplace_fn_(
                lhs=lhs, rhs=rhs, sycl_queue=exec_q
            )
            ht_.wait()

        else:
            # `rhs` must be cast and/or overlaps with `lhs`: operate on
            # a copy of it instead.
            buf = dpt.empty_like(rhs, dtype=buf_dt)
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=rhs, dst=buf, sycl_queue=exec_q
            )

            buf = dpt.broadcast_to(buf, res_shape)
            ht_, _ = self.binary_inplace_fn_(
                lhs=lhs,
                rhs=buf,
                sycl_queue=exec_q,
                depends=[copy_ev],
            )
            ht_copy_ev.wait()
            ht_.wait()

        return lhs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc