• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 5906987980

18 Aug 2023 09:36PM UTC coverage: 85.584% (+0.09%) from 85.497%
5906987980

push

github

web-flow
Fully enable ``usm_ndarray`` in-place arithmetic operators (#1352)

* Binary elementwise functions can now act on any input in-place
- A temporary will be allocated as necessary (i.e., when arrays overlap, are not going to be cast, and are not the same logical arrays)
- Uses dedicated in-place kernels where they are implemented
- Now called directly by Python operators
- Removes _inplace method of BinaryElementwiseFunc class
- Removes _find_inplace_dtype function

* Tests for new out parameter behavior for add

* Broadcasting made conditional in binary functions where memory overlap is possible
- Broadcasting can change the values of strides without changing array shape

* Changed exception types raised

Use ExecutionPlacementError for CFD violations.
Use ValueError if types of input are as expected, but values are
not as expected.

* Adding tests to improve coverage

Removed tests expecting error raised in case of overlapping inputs.
Added tests guided by coverage report.

* Removed provably unreachable branches in _resolve_weak_types

Since o1_dtype_kind_num > o2_dtype_kind_num, o1 cannot be a
weak boolean type, because that type has the lowest kind number in the
hierarchy.

* All in-place operators now use call operator of BinaryElementwiseFunc

* Removed some redundant and obsolete tests
- Removed from test_floor_ceil_trunc, test_hyperbolic, test_trigonometric, and test_logaddexp
- These tests would fail on GPU but never run on CPU, and therefore were not impacting the coverage
- These tests focused on aspects of the BinaryElementwiseFunc class rather than the behavior of the operator

---------

Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>

2323 of 2748 branches covered (84.53%)

Branch coverage included in aggregate %.

60 of 70 new or added lines in 2 files covered. (85.71%)

1 existing line in 1 file now uncovered.

8458 of 9849 relevant lines covered (85.88%)

7969.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.57
/dpctl/tensor/_elementwise_common.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2023 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import numbers
1✔
18

19
import numpy as np
1✔
20

21
import dpctl
1✔
22
import dpctl.memory as dpm
1✔
23
import dpctl.tensor as dpt
1✔
24
import dpctl.tensor._tensor_impl as ti
1✔
25
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
1✔
26
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
1✔
27
from dpctl.utils import ExecutionPlacementError
1✔
28

29
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
1✔
30
from ._type_utils import (
1✔
31
    _acceptance_fn_default,
32
    _find_buf_dtype,
33
    _find_buf_dtype2,
34
    _to_device_supported_dtype,
35
)
36

37

38
class UnaryElementwiseFunc:
    """
    Callable object wrapping the implementation of a unary
    element-wise function.

    Args:
        name: Name of the function, used in messages and in ``repr``.
        result_type_resolver_fn: Callable handed to ``_find_buf_dtype``
            to determine the result data type for an input data type.
        unary_dp_impl_fn: Callable invoked as
            ``fn(src, dst, sycl_queue=..., depends=...)`` that submits the
            computation and returns a pair of events.
        docs: Documentation string assigned to the instance ``__doc__``.
    """

    def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
        self.__name__ = "UnaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        self.unary_fn_ = unary_dp_impl_fn
        self.__doc__ = docs

    def __str__(self):
        return "<{} '{}'>".format(self.__name__, self.name_)

    def __repr__(self):
        return "<{} '{}'>".format(self.__name__, self.name_)

    def __call__(self, x, out=None, order="K"):
        """
        Apply the function to every element of `x`.

        Args:
            x: Input ``dpctl.tensor.usm_ndarray``.
            out: Optional ``usm_ndarray`` to write the result into. It must
                have the shape of `x`, the resolved result data type and a
                queue compatible with the queue of `x`.
            order: One of "C", "F", "K", "A"; any other value is treated
                as "K". Only used for arrays allocated by this call.

        Returns:
            `out` if it was given, otherwise a newly allocated array.

        Raises:
            TypeError: `x` or `out` is not a ``usm_ndarray``, the input data
                type is not supported, or `out` has the wrong data type.
            ValueError: `out` has a shape different from the shape of `x`.
            ExecutionPlacementError: the queues of `x` and `out` are not
                compatible.
        """
        if not isinstance(x, dpt.usm_ndarray):
            raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

        if order not in ("C", "F", "K", "A"):
            order = "K"

        # cast_dt is None when `x` can be consumed without a type cast
        cast_dt, result_dt = _find_buf_dtype(
            x.dtype, self.result_type_resolver_fn_, x.sycl_device
        )
        if result_dt is None:
            raise TypeError(
                f"function '{self.name_}' does not support input type "
                f"({x.dtype}), "
                "and the input could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        user_out = out
        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if out.shape != x.shape:
                raise ValueError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {x.shape}, got {out.shape}"
                )

            if result_dt != out.dtype:
                raise TypeError(
                    f"Output array of type {result_dt} is needed,"
                    f" got {out.dtype}"
                )

            # When `x` is going to be cast, the kernel reads a temporary
            # copy of `x`, so overlap with `out` only matters without a cast.
            needs_scratch_out = (
                cast_dt is None
                and ti._array_overlap(x, out)
                and not ti._same_logical_tensors(x, out)
            )
            if needs_scratch_out:
                # Compute into a temporary; it is copied into the
                # caller-supplied array once the kernel has run.
                out = dpt.empty_like(out)

            common_q = dpctl.utils.get_execution_queue(
                (x.sycl_queue, out.sycl_queue)
            )
            if common_q is None:
                raise ExecutionPlacementError(
                    "Input and output allocation queues are not compatible"
                )

        exec_q = x.sycl_queue

        def _alloc_like(arr, dtype, mem_order):
            # Allocate an uninitialized array shaped like `arr`
            if mem_order == "K":
                return _empty_like_orderK(arr, dtype)
            return dpt.empty_like(arr, dtype=dtype, order=mem_order)

        if cast_dt is None:
            # No cast needed: the kernel reads `x` directly
            if out is None:
                if order == "A":
                    order = "F" if x.flags.f_contiguous else "C"
                out = _alloc_like(x, result_dt, order)

            ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)

            if user_out is not None and user_out is not out:
                # Copy the out data from temporary buffer to original memory
                ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out, dst=user_out, sycl_queue=exec_q, depends=[unary_ev]
                )
                ht_copy_ev.wait()
                out = user_out

            ht_unary_ev.wait()
            return out

        # Cast needed: copy `x` into a temporary of type `cast_dt` first
        if order == "A":
            order = "F" if x.flags.f_contiguous else "C"
        staged_x = _alloc_like(x, cast_dt, order)

        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=staged_x, sycl_queue=exec_q
        )
        if out is None:
            out = _alloc_like(staged_x, result_dt, order)

        ht, _ = self.unary_fn_(
            staged_x, out, sycl_queue=exec_q, depends=[copy_ev]
        )
        ht_copy_ev.wait()
        ht.wait()

        return out
154

155

156
def _get_queue_usm_type(o):
    """
    Return the pair ``(sycl_queue, usm_type)`` associated with `o`.

    ``(None, None)`` is returned when `o` is neither a ``usm_ndarray`` nor
    an object exposing ``__sycl_usm_array_interface__``, or when
    ``dpm.as_usm_memory`` fails for such an object.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    if not hasattr(o, "__sycl_usm_array_interface__"):
        return None, None
    try:
        usm_mem = dpm.as_usm_memory(o)
        return usm_mem.sycl_queue, usm_mem.get_usm_type()
    except Exception:
        # Best effort: an unusable interface is treated as "no USM data"
        return None, None
167

168

169
class WeakBooleanType:
    """Python type representing type of Python boolean objects"""

    def __init__(self, o):
        # Keep a reference to the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the object this instance was constructed with."""
        wrapped = self.o_
        return wrapped
177

178

179
class WeakIntegralType:
    """Python type representing type of Python integral objects"""

    def __init__(self, o):
        # Keep a reference to the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the object this instance was constructed with."""
        wrapped = self.o_
        return wrapped
187

188

189
class WeakFloatingType:
    """Python type representing type of Python floating point objects"""

    def __init__(self, o):
        # Keep a reference to the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the object this instance was constructed with."""
        wrapped = self.o_
        return wrapped
197

198

199
class WeakComplexType:
    """Python type representing type of Python complex floating point objects"""

    def __init__(self, o):
        # Keep a reference to the wrapped Python object
        self.o_ = o

    def get(self):
        """Return the object this instance was constructed with."""
        wrapped = self.o_
        return wrapped
207

208

209
def _get_dtype(o, dev):
    """
    Determine the data type describing operand `o`.

    Arrays and array-like objects yield a concrete data type; host data
    types are mapped with ``_to_device_supported_dtype`` for device `dev`.
    Python scalars yield an instance of one of the ``Weak*Type`` classes
    wrapping the scalar. Anything else yields ``np.object_``.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        return _to_device_supported_dtype(np.array(o).dtype, dev)
    if hasattr(o, "dtype"):
        return _to_device_supported_dtype(o.dtype, dev)

    # Order matters: bool is a subclass of int, so it is tested first
    scalar_wrappers = (
        (bool, WeakBooleanType),
        (int, WeakIntegralType),
        (float, WeakFloatingType),
        (complex, WeakComplexType),
    )
    for py_type, weak_cls in scalar_wrappers:
        if isinstance(o, py_type):
            return weak_cls(o)
    return np.object_
230

231

232
def _validate_dtype(dt) -> bool:
    """
    Tell whether `dt` is acceptable as an operand data type.

    Accepted values are instances of the ``Weak*Type`` classes and
    ``dpt.dtype`` instances equal to one of the boolean, integer,
    floating point or complex data types listed below.
    """
    weak_kinds = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_kinds):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    supported = (
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    )
    return dt in supported
256

257

258
def _weak_type_num_kind(o):
    """
    Return the kind number of a weak type instance `o`.

    The numbers are 0 for ``WeakBooleanType``, 1 for ``WeakIntegralType``,
    2 for ``WeakFloatingType`` and 3 for ``WeakComplexType``.

    Raises:
        TypeError: `o` is not an instance of any of these classes.
    """
    kind_numbers = (
        (WeakBooleanType, 0),
        (WeakIntegralType, 1),
        (WeakFloatingType, 2),
        (WeakComplexType, 3),
    )
    for weak_cls, kind_num in kind_numbers:
        if isinstance(o, weak_cls):
            return kind_num
    raise TypeError(
        f"Unexpected type {o} while expecting "
        "`WeakBooleanType`, `WeakIntegralType`, "
        "`WeakFloatingType`, or `WeakComplexType`."
    )
273

274

275
def _strong_dtype_num_kind(o):
    """
    Return the kind number of the ``dpt.dtype`` instance `o`.

    The number is derived from ``o.kind``: 0 for "b", 1 for "i" and "u",
    2 for "f" and 3 for "c".

    Raises:
        TypeError: `o` is not a ``dpt.dtype`` instance.
        ValueError: ``o.kind`` is not one of the kinds listed above.
    """
    if not isinstance(o, dpt.dtype):
        raise TypeError(
            f"Expected an instance of dpctl.tensor.dtype, got {type(o)}"
        )
    _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
    k = o.kind
    if k in _map:
        return _map[k]
    raise ValueError(f"Unrecognized kind {k} for dtype {o}")
283

284

285
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
    """
    Resolves weak data type per NEP-0050.

    At most one of `o1_dtype` and `o2_dtype` may be an instance of a
    ``Weak*Type`` class. The weak operand is replaced by a concrete data
    type chosen from its kind and the data type of the other operand; the
    other operand's data type is returned unchanged. When neither operand
    is weak, both are returned as given.

    Raises:
        ValueError: both operands are weak.
    """
    weak_kinds = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )

    def _concrete_for(weak_dt, weak_num, strong_dt, strong_num):
        # Weak operand of the same or a lower kind adopts the strong dtype
        if not weak_num > strong_num:
            return strong_dt
        if isinstance(weak_dt, WeakIntegralType):
            return dpt.int64
        if isinstance(weak_dt, WeakComplexType):
            if strong_dt is dpt.float16 or strong_dt is dpt.float32:
                return dpt.complex64
            return _to_device_supported_dtype(dpt.complex128, dev)
        return _to_device_supported_dtype(dpt.float64, dev)

    if isinstance(o1_dtype, weak_kinds):
        if isinstance(o2_dtype, weak_kinds):
            raise ValueError
        o1_kind_num = _weak_type_num_kind(o1_dtype)
        o2_kind_num = _strong_dtype_num_kind(o2_dtype)
        resolved = _concrete_for(o1_dtype, o1_kind_num, o2_dtype, o2_kind_num)
        return resolved, o2_dtype

    if isinstance(o2_dtype, weak_kinds):
        o1_kind_num = _strong_dtype_num_kind(o1_dtype)
        o2_kind_num = _weak_type_num_kind(o2_dtype)
        resolved = _concrete_for(o2_dtype, o2_kind_num, o1_dtype, o1_kind_num)
        return o1_dtype, resolved

    return o1_dtype, o2_dtype
337

338

339
def _get_shape(o):
    """
    Return the shape of operand `o`.

    ``usm_ndarray`` instances report their own shape, objects accepted by
    ``_is_buffer`` report the shape of their ``memoryview``, and numbers
    report the empty tuple. Any other object reports its ``shape``
    attribute, or the empty tuple when it has none.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.shape
    if _is_buffer(o):
        buffer_view = memoryview(o)
        return buffer_view.shape
    if isinstance(o, numbers.Number):
        return ()
    return getattr(o, "shape", ())
347

348

349
class BinaryElementwiseFunc:
    """
    Class that implements binary element-wise functions.

    Args:
        name: Name of the function, used in messages and in ``repr``.
        result_type_resolver_fn: Callable handed to ``_find_buf_dtype2``
            to determine the result data type for a pair of input types.
        binary_dp_impl_fn: Callable invoked as
            ``fn(src1=, src2=, dst=, sycl_queue=, depends=)`` that submits
            the computation and returns a pair of events.
        docs: Documentation string assigned to the instance ``__doc__``.
        binary_inplace_fn: Optional callable invoked as
            ``fn(lhs=, rhs=, sycl_queue=, depends=)`` that updates ``lhs``
            in place. Used when `out` is the same logical array as the
            first operand.
        acceptance_fn: Optional callable forwarded to ``_find_buf_dtype2``.
            ``_acceptance_fn_default`` is used when it is not callable.
    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        binary_dp_impl_fn,
        docs,
        binary_inplace_fn=None,
        acceptance_fn=None,
    ):
        self.__name__ = "BinaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        self.binary_fn_ = binary_dp_impl_fn
        self.binary_inplace_fn_ = binary_inplace_fn
        self.__doc__ = docs
        if callable(acceptance_fn):
            self.acceptance_fn_ = acceptance_fn
        else:
            self.acceptance_fn_ = _acceptance_fn_default

    def __str__(self):
        return f"<{self.__name__} '{self.name_}'>"

    def __repr__(self):
        return f"<{self.__name__} '{self.name_}'>"

    def __call__(self, o1, o2, out=None, order="K"):
        """
        Apply the function element-wise to `o1` and `o2`.

        Args:
            o1: First operand.
            o2: Second operand. At least one of the operands must be a
                ``usm_ndarray`` or expose ``__sycl_usm_array_interface__``;
                the other may be a Python scalar or array-like object.
            out: Optional ``usm_ndarray`` to write the result into. It must
                have the broadcast shape of the operands, the resolved
                result data type, and a compatible queue. It may overlap
                the operands; a temporary is used where needed.
            order: One of "K", "C", "F", "A"; any other value is treated
                as "K". Only used for arrays allocated by this call.

        Returns:
            `out` if it was given, otherwise a newly allocated array.

        Raises:
            ExecutionPlacementError: the execution queue can not be
                inferred from the operands, or the queue of `out` is not
                compatible with it.
            TypeError: operand shapes can not be inferred, the operand data
                types are not supported by the function, or `out` is not a
                ``usm_ndarray`` of the required data type.
            ValueError: operands can not be broadcast together, an operand
                has an unsupported type, or `out` has the wrong shape.
        """
        if order not in ["K", "C", "F", "A"]:
            order = "K"

        # Determine execution queue and USM type of the result
        q1, o1_usm_type = _get_queue_usm_type(o1)
        q2, o2_usm_type = _get_queue_usm_type(o2)
        if q1 is None and q2 is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments. "
                "One of the arguments must represent USM allocation and "
                "expose `__sycl_usm_array_interface__` property"
            )
        if q1 is None:
            exec_q = q2
            res_usm_type = o2_usm_type
        elif q2 is None:
            exec_q = q1
            res_usm_type = o1_usm_type
        else:
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
            if exec_q is None:
                raise ExecutionPlacementError(
                    "Execution placement can not be unambiguously inferred "
                    "from input arguments."
                )
            res_usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    o1_usm_type,
                    o2_usm_type,
                )
            )
        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)

        # Determine the shape of the result
        o1_shape = _get_shape(o1)
        o2_shape = _get_shape(o2)
        if not all(
            isinstance(s, (tuple, list))
            for s in (
                o1_shape,
                o2_shape,
            )
        ):
            raise TypeError(
                "Shape of arguments can not be inferred. "
                "Arguments are expected to be "
                "lists, tuples, or both"
            )
        try:
            res_shape = _broadcast_shape_impl(
                [
                    o1_shape,
                    o2_shape,
                ]
            )
        except ValueError:
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{o1_shape} and {o2_shape}"
            )

        # Determine data types of operands, of temporaries and of the result
        sycl_dev = exec_q.sycl_device
        o1_dtype = _get_dtype(o1, sycl_dev)
        o2_dtype = _get_dtype(o2, sycl_dev)
        if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
            raise ValueError("Operands of unsupported types")

        o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)

        # bufN_dt is None when operand N can be used without a type cast
        buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
            o1_dtype,
            o2_dtype,
            self.result_type_resolver_fn_,
            sycl_dev,
            acceptance_fn=self.acceptance_fn_,
        )

        if res_dt is None:
            raise TypeError(
                f"function '{self.name_}' does not support input types "
                f"({o1_dtype}, {o2_dtype}), "
                "and the inputs could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        orig_out = out
        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if out.shape != res_shape:
                # Report the shape the check is made against
                raise ValueError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {res_shape}, got {out.shape}"
                )

            if res_dt != out.dtype:
                raise TypeError(
                    f"Output array of type {res_dt} is needed,"
                    f" got {out.dtype}"
                )

            if (
                dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
                is None
            ):
                raise ExecutionPlacementError(
                    "Input and output allocation queues are not compatible"
                )

            if isinstance(o1, dpt.usm_ndarray):
                if ti._array_overlap(o1, out) and buf1_dt is None:
                    if not ti._same_logical_tensors(o1, out):
                        # Partial overlap: compute into a temporary, which
                        # is copied into `orig_out` after the kernel ran
                        out = dpt.empty_like(out)
                    elif self.binary_inplace_fn_ is not None:
                        # if there is a dedicated in-place kernel
                        # it can be called here, otherwise continues
                        if isinstance(o2, dpt.usm_ndarray):
                            src2 = o2
                            if (
                                ti._array_overlap(o2, out)
                                and not ti._same_logical_tensors(o2, out)
                                and buf2_dt is None
                            ):
                                # Force a copy of `o2`, since the in-place
                                # kernel would overwrite memory it reads
                                buf2_dt = o2_dtype
                        else:
                            src2 = dpt.asarray(
                                o2, dtype=o2_dtype, sycl_queue=exec_q
                            )
                        if buf2_dt is None:
                            if src2.shape != res_shape:
                                src2 = dpt.broadcast_to(src2, res_shape)
                            ht_, _ = self.binary_inplace_fn_(
                                lhs=o1, rhs=src2, sycl_queue=exec_q
                            )
                            ht_.wait()
                        else:
                            buf2 = dpt.empty_like(src2, dtype=buf2_dt)
                            (
                                ht_copy_ev,
                                copy_ev,
                            ) = ti._copy_usm_ndarray_into_usm_ndarray(
                                src=src2, dst=buf2, sycl_queue=exec_q
                            )

                            buf2 = dpt.broadcast_to(buf2, res_shape)
                            ht_, _ = self.binary_inplace_fn_(
                                lhs=o1,
                                rhs=buf2,
                                sycl_queue=exec_q,
                                depends=[copy_ev],
                            )
                            ht_copy_ev.wait()
                            ht_.wait()

                        return out

            if isinstance(o2, dpt.usm_ndarray):
                if (
                    ti._array_overlap(o2, out)
                    and not ti._same_logical_tensors(o2, out)
                    and buf2_dt is None
                ):
                    # should not reach if out is reallocated
                    # after being checked against o1
                    out = dpt.empty_like(out)

        # Convert operands that are not usm_ndarray into arrays
        if isinstance(o1, dpt.usm_ndarray):
            src1 = o1
        else:
            src1 = dpt.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q)
        if isinstance(o2, dpt.usm_ndarray):
            src2 = o2
        else:
            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

        if buf1_dt is None and buf2_dt is None:
            # Neither operand needs a cast
            if out is None:
                if order == "K":
                    out = _empty_like_pair_orderK(
                        src1, src2, res_dt, res_shape, res_usm_type, exec_q
                    )
                else:
                    if order == "A":
                        order = (
                            "F"
                            if all(
                                arr.flags.f_contiguous
                                for arr in (
                                    src1,
                                    src2,
                                )
                            )
                            else "C"
                        )
                    out = dpt.empty(
                        res_shape,
                        dtype=res_dt,
                        usm_type=res_usm_type,
                        sycl_queue=exec_q,
                        order=order,
                    )
            # Broadcast only when the shape differs from the result shape
            if src1.shape != res_shape:
                src1 = dpt.broadcast_to(src1, res_shape)
            if src2.shape != res_shape:
                src2 = dpt.broadcast_to(src2, res_shape)
            ht_binary_ev, binary_ev = self.binary_fn_(
                src1=src1, src2=src2, dst=out, sycl_queue=exec_q
            )
            if not (orig_out is None or orig_out is out):
                # Copy the out data from temporary buffer to original memory
                ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out,
                    dst=orig_out,
                    sycl_queue=exec_q,
                    depends=[binary_ev],
                )
                ht_copy_out_ev.wait()
                out = orig_out
            ht_binary_ev.wait()
            return out
        elif buf1_dt is None:
            # Only the second operand needs a cast
            if order == "K":
                buf2 = _empty_like_orderK(src2, buf2_dt)
            else:
                if order == "A":
                    order = "F" if src1.flags.f_contiguous else "C"
                buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=src2, dst=buf2, sycl_queue=exec_q
            )
            # The data type of a supplied `out` was validated above
            if out is None:
                if order == "K":
                    out = _empty_like_pair_orderK(
                        src1, buf2, res_dt, res_shape, res_usm_type, exec_q
                    )
                else:
                    out = dpt.empty(
                        res_shape,
                        dtype=res_dt,
                        usm_type=res_usm_type,
                        sycl_queue=exec_q,
                        order=order,
                    )
            if src1.shape != res_shape:
                src1 = dpt.broadcast_to(src1, res_shape)
            buf2 = dpt.broadcast_to(buf2, res_shape)
            ht_binary_ev, binary_ev = self.binary_fn_(
                src1=src1,
                src2=buf2,
                dst=out,
                sycl_queue=exec_q,
                depends=[copy_ev],
            )
            if not (orig_out is None or orig_out is out):
                # Copy the out data from temporary buffer to original memory
                ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out,
                    dst=orig_out,
                    sycl_queue=exec_q,
                    depends=[binary_ev],
                )
                ht_copy_out_ev.wait()
                out = orig_out
            ht_copy_ev.wait()
            ht_binary_ev.wait()
            return out
        elif buf2_dt is None:
            # Only the first operand needs a cast
            if order == "K":
                buf1 = _empty_like_orderK(src1, buf1_dt)
            else:
                if order == "A":
                    order = "F" if src1.flags.f_contiguous else "C"
                buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=src1, dst=buf1, sycl_queue=exec_q
            )
            if out is None:
                if order == "K":
                    out = _empty_like_pair_orderK(
                        buf1, src2, res_dt, res_shape, res_usm_type, exec_q
                    )
                else:
                    out = dpt.empty(
                        res_shape,
                        dtype=res_dt,
                        usm_type=res_usm_type,
                        sycl_queue=exec_q,
                        order=order,
                    )

            buf1 = dpt.broadcast_to(buf1, res_shape)
            if src2.shape != res_shape:
                src2 = dpt.broadcast_to(src2, res_shape)
            ht_binary_ev, binary_ev = self.binary_fn_(
                src1=buf1,
                src2=src2,
                dst=out,
                sycl_queue=exec_q,
                depends=[copy_ev],
            )
            if not (orig_out is None or orig_out is out):
                # Copy the out data from temporary buffer to original memory
                ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out,
                    dst=orig_out,
                    sycl_queue=exec_q,
                    depends=[binary_ev],
                )
                ht_copy_out_ev.wait()
                out = orig_out
            ht_copy_ev.wait()
            ht_binary_ev.wait()
            return out

        # Both operands need a cast. `out` is never a temporary here, since
        # temporaries are only allocated when an operand is used uncast.
        if order in ["K", "A"]:
            if src1.flags.f_contiguous and src2.flags.f_contiguous:
                order = "F"
            elif src1.flags.c_contiguous and src2.flags.c_contiguous:
                order = "C"
            else:
                order = "C" if order == "A" else "K"
        if order == "K":
            buf1 = _empty_like_orderK(src1, buf1_dt)
        else:
            buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
        ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src1, dst=buf1, sycl_queue=exec_q
        )
        if order == "K":
            buf2 = _empty_like_orderK(src2, buf2_dt)
        else:
            buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
        ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src2, dst=buf2, sycl_queue=exec_q
        )
        if out is None:
            if order == "K":
                out = _empty_like_pair_orderK(
                    buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
                )
            else:
                out = dpt.empty(
                    res_shape,
                    dtype=res_dt,
                    usm_type=res_usm_type,
                    sycl_queue=exec_q,
                    order=order,
                )

        buf1 = dpt.broadcast_to(buf1, res_shape)
        buf2 = dpt.broadcast_to(buf2, res_shape)
        ht_, _ = self.binary_fn_(
            src1=buf1,
            src2=buf2,
            dst=out,
            sycl_queue=exec_q,
            depends=[copy1_ev, copy2_ev],
        )
        dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
        return out
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc