• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 9391233749

05 Jun 2024 09:07PM UTC coverage: 88.057% (+0.1%) from 87.911%
9391233749

Pull #1705

github

web-flow
Merge f4e3b6f1b into c6f5f790b
Pull Request #1705: Change memory object USM allocation ownership, and make execution asynchronous

3276 of 3765 branches covered (87.01%)

Branch coverage included in aggregate %.

572 of 634 new or added lines in 23 files covered. (90.22%)

2 existing lines in 2 files now uncovered.

11212 of 12688 relevant lines covered (88.37%)

7552.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.49
/dpctl/tensor/_elementwise_common.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2024 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import numbers
1✔
18

19
import numpy as np
1✔
20

21
import dpctl
1✔
22
import dpctl.memory as dpm
1✔
23
import dpctl.tensor as dpt
1✔
24
import dpctl.tensor._tensor_impl as ti
1✔
25
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
1✔
26
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
1✔
27
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
1✔
28

29
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
1✔
30
from ._type_utils import (
1✔
31
    WeakBooleanType,
32
    WeakComplexType,
33
    WeakFloatingType,
34
    WeakIntegralType,
35
    _acceptance_fn_default_binary,
36
    _acceptance_fn_default_unary,
37
    _all_data_types,
38
    _find_buf_dtype,
39
    _find_buf_dtype2,
40
    _resolve_weak_types,
41
    _to_device_supported_dtype,
42
)
43

44

45
class UnaryElementwiseFunc:
    """
    Class that implements unary element-wise functions.

    Args:
        name (str):
            Name of the unary function
        result_type_resolver_fn (callable):
            Function that takes dtype of the input and
            returns the dtype of the result if the
            implementation functions supports it, or
            returns `None` otherwise.
        unary_dp_impl_fn (callable):
            Data-parallel implementation function with signature
            `impl_fn(src: usm_ndarray, dst: usm_ndarray,
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
            where the `src` is the argument array, `dst` is the
            array to be populated with function values, effectively
            evaluating `dst = func(src)`.
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
            The first event corresponds to data-management host tasks,
            including lifetime management of argument Python objects to ensure
            that their associated USM allocation is not freed before offloaded
            computational tasks complete execution, while the second event
            corresponds to computational tasks associated with function
            evaluation.
        docs (str):
            Documentation string for the unary function.
        acceptance_fn (callable, optional):
            Function to influence type promotion behavior of this unary
            function. The function takes 4 arguments:
                arg_dtype - Data type of the first argument
                buf_dtype - Data type the argument would be cast to
                res_dtype - Data type of the output array with function values
                sycl_dev - The :class:`dpctl.SyclDevice` where the function
                    evaluation is carried out.
            The function is invoked when the argument of the unary function
            requires casting, e.g. the argument of `dpctl.tensor.log` is an
            array with integral data type.
            When the value is not callable (e.g. `None`), the default
            acceptance function for unary functions is used.
    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        unary_dp_impl_fn,
        docs,
        acceptance_fn=None,
    ):
        self.__name__ = "UnaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        # Lazily populated by the `types` property.
        self.types_ = None
        self.unary_fn_ = unary_dp_impl_fn
        self.__doc__ = docs
        if callable(acceptance_fn):
            self.acceptance_fn_ = acceptance_fn
        else:
            self.acceptance_fn_ = _acceptance_fn_default_unary

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f"<{self.__name__} '{self.name_}'>"

    def get_implementation_function(self):
        """Returns the implementation function for
        this elementwise unary function.

        """
        return self.unary_fn_

    def get_type_result_resolver_function(self):
        """Returns the type resolver function for this
        elementwise unary function.
        """
        return self.result_type_resolver_fn_

    def get_type_promotion_path_acceptance_function(self):
        """Returns the acceptance function for this
        elementwise unary function.

        Acceptance function influences the type promotion
        behavior of this unary function.
        The function takes 4 arguments:
            arg_dtype - Data type of the first argument
            buf_dtype - Data type the argument would be cast to
            res_dtype - Data type of the output array with function values
            sycl_dev - The :class:`dpctl.SyclDevice` where the function
                evaluation is carried out.
        The function is invoked when the argument of the unary function
        requires casting, e.g. the argument of `dpctl.tensor.log` is an
        array with integral data type.
        """
        return self.acceptance_fn_

    @property
    def nin(self):
        """
        Returns the number of arguments treated as inputs.
        """
        return 1

    @property
    def nout(self):
        """
        Returns the number of arguments treated as outputs.
        """
        return 1

    @property
    def types(self):
        """Returns information about types supported by
        implementation function, using NumPy's character
        encoding for data types, e.g.

        :Example:
            .. code-block:: python

                dpctl.tensor.sin.types
                # Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
        """
        types = self.types_
        # Compare against None rather than relying on truthiness, so that
        # an empty cached list is not recomputed on every access.
        if types is None:
            types = []
            for dt1 in _all_data_types(True, True):
                dt2 = self.result_type_resolver_fn_(dt1)
                if dt2:
                    types.append(f"{dt1.char}->{dt2.char}")
            self.types_ = types
        return types

    def __call__(self, x, /, *, out=None, order="K"):
        """Evaluates the unary function for every element of `x`.

        Args:
            x (usm_ndarray):
                Input array.
            out (usm_ndarray, optional):
                Array to populate with function values. It must be
                writable, have the same shape as `x`, have the data type
                of the result, and be allocated on a queue compatible
                with the queue of `x`.
            order (str, optional):
                Memory layout used for newly allocated arrays, one of
                "C", "F", "A", or "K". Any other value is treated as "K".

        Returns:
            usm_ndarray:
                `out` when it was provided, otherwise a newly
                allocated array holding function values.

        Raises:
            TypeError:
                If `x`, or a provided `out`, is not a `usm_ndarray`.
            ValueError:
                If the data type of `x` is not supported, or `out` is
                read-only, or has unexpected shape or data type.
            ExecutionPlacementError:
                If allocation queues of `x` and `out` are not compatible.
        """
        if not isinstance(x, dpt.usm_ndarray):
            raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

        if order not in ["C", "F", "K", "A"]:
            order = "K"
        # `buf_dt` is None when `x` can be passed to the implementation
        # function as is, otherwise it is the data type `x` is cast to.
        buf_dt, res_dt = _find_buf_dtype(
            x.dtype,
            self.result_type_resolver_fn_,
            x.sycl_device,
            acceptance_fn=self.acceptance_fn_,
        )
        if res_dt is None:
            raise ValueError(
                f"function '{self.name_}' does not support input type "
                f"({x.dtype}), "
                "and the input could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        # Remember the array supplied by the caller, since `out` may be
        # rebound to a temporary below.
        orig_out = out
        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if not out.flags.writable:
                raise ValueError("provided `out` array is read-only")

            if out.shape != x.shape:
                raise ValueError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {x.shape}, got {out.shape}"
                )

            if res_dt != out.dtype:
                raise ValueError(
                    f"Output array of type {res_dt} is needed,"
                    f" got {out.dtype}"
                )

            if (
                buf_dt is None
                and ti._array_overlap(x, out)
                and not ti._same_logical_tensors(x, out)
            ):
                # Allocate a temporary buffer to avoid memory overlapping.
                # Note if `buf_dt` is not None, a temporary copy of `x` will be
                # created, so the array overlap check isn't needed.
                out = dpt.empty_like(out)

            if (
                dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
                is None
            ):
                raise ExecutionPlacementError(
                    "Input and output allocation queues are not compatible"
                )

        exec_q = x.sycl_queue
        _manager = SequentialOrderManager[exec_q]
        if buf_dt is None:
            # No casting of the input is needed.
            if out is None:
                if order == "K":
                    out = _empty_like_orderK(x, res_dt)
                else:
                    if order == "A":
                        order = "F" if x.flags.f_contiguous else "C"
                    out = dpt.empty_like(x, dtype=res_dt, order=order)

            dep_evs = _manager.submitted_events
            ht_unary_ev, unary_ev = self.unary_fn_(
                x, out, sycl_queue=exec_q, depends=dep_evs
            )
            _manager.add_event_pair(ht_unary_ev, unary_ev)

            if not (orig_out is None or orig_out is out):
                # Copy the out data from temporary buffer to original memory
                ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
                )
                _manager.add_event_pair(ht_copy_ev, cpy_ev)
                out = orig_out

            return out

        # The input must be cast: copy `x` into a buffer of type `buf_dt`
        # and evaluate the function on that buffer.
        if order == "K":
            buf = _empty_like_orderK(x, buf_dt)
        else:
            if order == "A":
                order = "F" if x.flags.f_contiguous else "C"
            buf = dpt.empty_like(x, dtype=buf_dt, order=order)

        dep_evs = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        if out is None:
            if order == "K":
                out = _empty_like_orderK(buf, res_dt)
            else:
                out = dpt.empty_like(buf, dtype=res_dt, order=order)

        # The function evaluation must wait for the cast copy to complete.
        ht, uf_ev = self.unary_fn_(
            buf, out, sycl_queue=exec_q, depends=[copy_ev]
        )
        _manager.add_event_pair(ht, uf_ev)

        return out
1✔
289

290

291
def _get_queue_usm_type(o):
    """Return the pair ``(sycl_queue, usm_type)`` associated with `o`.

    For a ``usm_ndarray`` the values are read from the array itself. For
    any other object exposing ``__sycl_usm_array_interface__`` they are
    obtained from the USM memory object created from `o`.
    ``(None, None)`` is returned when `o` is neither of these, or when
    creation of the USM memory object raises an exception.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    if not hasattr(o, "__sycl_usm_array_interface__"):
        return None, None
    try:
        usm_mem = dpm.as_usm_memory(o)
        return usm_mem.sycl_queue, usm_mem.get_usm_type()
    except Exception:
        # Best effort: an unusable interface is treated as "no queue".
        return None, None
1✔
302

303

304
def _get_dtype(o, dev):
    """Return the data type used to represent operand `o` on device `dev`.

    The checks are made in this order:

    * ``usm_ndarray`` - its own dtype;
    * object with ``__sycl_usm_array_interface__`` - dtype of the array
      created from it by ``dpt.asarray``;
    * object with buffer protocol, or object with ``dtype`` attribute -
      the host dtype mapped by ``_to_device_supported_dtype`` for `dev`;
    * Python ``bool``, ``int``, ``float``, ``complex`` - the matching
      weak type wrapping the value;
    * anything else - ``np.object_``.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        return _to_device_supported_dtype(np.array(o).dtype, dev)
    if hasattr(o, "dtype"):
        return _to_device_supported_dtype(o.dtype, dev)
    # `bool` is listed before `int` because `bool` is a subclass of `int`.
    scalar_to_weak_type = (
        (bool, WeakBooleanType),
        (int, WeakIntegralType),
        (float, WeakFloatingType),
        (complex, WeakComplexType),
    )
    for scalar_type, weak_type in scalar_to_weak_type:
        if isinstance(o, scalar_type):
            return weak_type(o)
    return np.object_
1✔
325

326

327
def _validate_dtype(dt) -> bool:
    """Tell whether `dt` is acceptable as an operand data type.

    Returns ``True`` for instances of the weak scalar types, and for
    ``dpt.dtype`` instances equal to one of the boolean, integral,
    floating point, or complex data types listed below. Returns
    ``False`` otherwise.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_types):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    supported_dtypes = (
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    )
    return dt in supported_dtypes
351

352

353
def _get_shape(o):
    """Return the shape of operand `o`.

    The shape comes from the array for ``usm_ndarray``, from the
    ``memoryview`` of `o` for objects with buffer protocol, and is an
    empty tuple for numbers. For any other object the ``shape``
    attribute is used, with an empty tuple as the fallback.
    """
    if isinstance(o, dpt.usm_ndarray):
        shape = o.shape
    elif _is_buffer(o):
        shape = memoryview(o).shape
    elif isinstance(o, numbers.Number):
        shape = ()
    else:
        shape = getattr(o, "shape", ())
    return shape
1✔
361

362

363
class BinaryElementwiseFunc:
1✔
364
    """
365
    Class that implements binary element-wise functions.
366

367
    Args:
368
        name (str):
369
            Name of the binary function
370
        result_type_resolver_fn (callable):
371
            Function that takes dtypes of the input and
372
            returns the dtype of the result if the
373
            implementation functions supports it, or
374
            returns `None` otherwise.
375
        binary_dp_impl_fn (callable):
376
            Data-parallel implementation function with signature
377
            `impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
378
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
379
            where the `src1` and `src2` are the argument arrays, `dst` is the
380
            array to be populated with function values,
381
            i.e. `dst=func(src1, src2)`.
382
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
383
            The first event corresponds to data-management host tasks,
384
            including lifetime management of argument Python objects to ensure
385
            that their associated USM allocation is not freed before offloaded
386
            computational tasks complete execution, while the second event
387
            corresponds to computational tasks associated with function
388
            evaluation.
389
        docs (str):
390
            Documentation string for the binary function.
391
        binary_inplace_fn (callable, optional):
392
            Data-parallel implementation function with signature
393
            `impl_fn(src: usm_ndarray, dst: usm_ndarray,
394
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
395
            where the `src` is the argument array, `dst` is the
396
            array to be populated with function values,
397
            i.e. `dst=func(dst, src)`.
398
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
399
            The first event corresponds to data-management host tasks,
400
            including async lifetime management of Python arguments,
401
            while the second event corresponds to computational tasks
402
            associated with function evaluation.
403
        acceptance_fn (callable, optional):
404
            Function to influence type promotion behavior of this binary
405
            function. The function takes 6 arguments:
406
                arg1_dtype - Data type of the first argument
407
                arg2_dtype - Data type of the second argument
408
                ret_buf1_dtype - Data type the first argument would be cast to
409
                ret_buf2_dtype - Data type the second argument would be cast to
410
                res_dtype - Data type of the output array with function values
411
                sycl_dev - The :class:`dpctl.SyclDevice` where the function
412
                    evaluation is carried out.
413
            The function is only called when both arguments of the binary
414
            function require casting, e.g. both arguments of
415
            `dpctl.tensor.logaddexp` are arrays with integral data type.
416
    """
417

418
    def __init__(
1✔
419
        self,
420
        name,
421
        result_type_resolver_fn,
422
        binary_dp_impl_fn,
423
        docs,
424
        binary_inplace_fn=None,
425
        acceptance_fn=None,
426
        weak_type_resolver=None,
427
    ):
428
        self.__name__ = "BinaryElementwiseFunc"
1✔
429
        self.name_ = name
1✔
430
        self.result_type_resolver_fn_ = result_type_resolver_fn
1✔
431
        self.types_ = None
1✔
432
        self.binary_fn_ = binary_dp_impl_fn
1✔
433
        self.binary_inplace_fn_ = binary_inplace_fn
1✔
434
        self.__doc__ = docs
1✔
435
        if callable(acceptance_fn):
1✔
436
            self.acceptance_fn_ = acceptance_fn
1✔
437
        else:
438
            self.acceptance_fn_ = _acceptance_fn_default_binary
1✔
439
        if callable(weak_type_resolver):
1✔
440
            self.weak_type_resolver_ = weak_type_resolver
1✔
441
        else:
442
            self.weak_type_resolver_ = _resolve_weak_types
1✔
443

444
    def __str__(self):
1✔
445
        return f"<{self.__name__} '{self.name_}'>"
1✔
446

447
    def __repr__(self):
1✔
448
        return f"<{self.__name__} '{self.name_}'>"
1✔
449

450
    def get_implementation_function(self):
1✔
451
        """Returns the out-of-place implementation
452
        function for this elementwise binary function.
453

454
        """
455
        return self.binary_fn_
1✔
456

457
    def get_implementation_inplace_function(self):
1✔
458
        """Returns the in-place implementation
459
        function for this elementwise binary function.
460

461
        """
462
        return self.binary_inplace_fn_
1✔
463

464
    def get_type_result_resolver_function(self):
1✔
465
        """Returns the type resolver function for this
466
        elementwise binary function.
467
        """
468
        return self.result_type_resolver_fn_
1✔
469

470
    def get_type_promotion_path_acceptance_function(self):
1✔
471
        """Returns the acceptance function for this
472
        elementwise binary function.
473

474
        Acceptance function influences the type promotion
475
        behavior of this binary function.
476
        The function takes 6 arguments:
477
            arg1_dtype - Data type of the first argument
478
            arg2_dtype - Data type of the second argument
479
            ret_buf1_dtype - Data type the first argument would be cast to
480
            ret_buf2_dtype - Data type the second argument would be cast to
481
            res_dtype - Data type of the output array with function values
482
            sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
483
                is carried out.
484

485
        The acceptance function is only invoked if both input arrays must be
486
        cast to intermediary data types, as would happen during call of
487
        `dpctl.tensor.hypot` with both arrays being of integral data type.
488
        """
489
        return self.acceptance_fn_
1✔
490

491
    def get_array_dtype_scalar_type_resolver_function(self):
1✔
492
        """Returns the function which determines how to treat
493
        Python scalar types for this elementwise binary function.
494

495
        Resolver influences what type the scalar will be
496
        treated as prior to type promotion behavior.
497
        The function takes 3 arguments:
498

499
        Args:
500
            o1_dtype (object, dtype):
501
                A class representing a Python scalar type or a ``dtype``
502
            o2_dtype (object, dtype):
503
                A class representing a Python scalar type or a ``dtype``
504
            sycl_dev (:class:`dpctl.SyclDevice`):
505
                Device on which function evaluation is carried out.
506

507
        One of ``o1_dtype`` and ``o2_dtype`` must be a ``dtype`` instance.
508
        """
509
        return self.weak_type_resolver_
×
510

511
    @property
1✔
512
    def nin(self):
1✔
513
        """
514
        Returns the number of arguments treated as inputs.
515
        """
516
        return 2
1✔
517

518
    @property
1✔
519
    def nout(self):
1✔
520
        """
521
        Returns the number of arguments treated as outputs.
522
        """
523
        return 1
1✔
524

525
    @property
1✔
526
    def types(self):
1✔
527
        """Returns information about types supported by
528
        implementation function, using NumPy's character
529
        encoding for data types, e.g.
530

531
        :Example:
532
            .. code-block:: python
533

534
                dpctl.tensor.divide.types
535
                # Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
536
                #    'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
537
        """
538
        types = self.types_
1✔
539
        if not types:
1!
540
            types = []
1✔
541
            _all_dtypes = _all_data_types(True, True)
1✔
542
            for dt1 in _all_dtypes:
1✔
543
                for dt2 in _all_dtypes:
1✔
544
                    dt3 = self.result_type_resolver_fn_(dt1, dt2)
1✔
545
                    if dt3:
1✔
546
                        types.append(f"{dt1.char}{dt2.char}->{dt3.char}")
1✔
547
            self.types_ = types
1✔
548
        return types
1✔
549

550
    def __call__(self, o1, o2, /, *, out=None, order="K"):
1✔
551
        if order not in ["K", "C", "F", "A"]:
1✔
552
            order = "K"
1✔
553
        q1, o1_usm_type = _get_queue_usm_type(o1)
1✔
554
        q2, o2_usm_type = _get_queue_usm_type(o2)
1✔
555
        if q1 is None and q2 is None:
1✔
556
            raise ExecutionPlacementError(
1✔
557
                "Execution placement can not be unambiguously inferred "
558
                "from input arguments. "
559
                "One of the arguments must represent USM allocation and "
560
                "expose `__sycl_usm_array_interface__` property"
561
            )
562
        if q1 is None:
1✔
563
            exec_q = q2
1✔
564
            res_usm_type = o2_usm_type
1✔
565
        elif q2 is None:
1✔
566
            exec_q = q1
1✔
567
            res_usm_type = o1_usm_type
1✔
568
        else:
569
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
1✔
570
            if exec_q is None:
1✔
571
                raise ExecutionPlacementError(
1✔
572
                    "Execution placement can not be unambiguously inferred "
573
                    "from input arguments."
574
                )
575
            res_usm_type = dpctl.utils.get_coerced_usm_type(
1✔
576
                (
577
                    o1_usm_type,
578
                    o2_usm_type,
579
                )
580
            )
581
        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
1✔
582
        o1_shape = _get_shape(o1)
1✔
583
        o2_shape = _get_shape(o2)
1✔
584
        if not all(
1!
585
            isinstance(s, (tuple, list))
586
            for s in (
587
                o1_shape,
588
                o2_shape,
589
            )
590
        ):
591
            raise TypeError(
×
592
                "Shape of arguments can not be inferred. "
593
                "Arguments are expected to be "
594
                "lists, tuples, or both"
595
            )
596
        try:
1✔
597
            res_shape = _broadcast_shape_impl(
1✔
598
                [
599
                    o1_shape,
600
                    o2_shape,
601
                ]
602
            )
603
        except ValueError:
1✔
604
            raise ValueError(
1✔
605
                "operands could not be broadcast together with shapes "
606
                f"{o1_shape} and {o2_shape}"
607
            )
608
        sycl_dev = exec_q.sycl_device
1✔
609
        o1_dtype = _get_dtype(o1, sycl_dev)
1✔
610
        o2_dtype = _get_dtype(o2, sycl_dev)
1✔
611
        if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
1✔
612
            raise ValueError("Operands have unsupported data types")
1✔
613

614
        o1_dtype, o2_dtype = self.weak_type_resolver_(
1✔
615
            o1_dtype, o2_dtype, sycl_dev
616
        )
617

618
        buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
1✔
619
            o1_dtype,
620
            o2_dtype,
621
            self.result_type_resolver_fn_,
622
            sycl_dev,
623
            acceptance_fn=self.acceptance_fn_,
624
        )
625

626
        if res_dt is None:
1✔
627
            raise ValueError(
1✔
628
                f"function '{self.name_}' does not support input types "
629
                f"({o1_dtype}, {o2_dtype}), "
630
                "and the inputs could not be safely coerced to any "
631
                "supported types according to the casting rule ''safe''."
632
            )
633

634
        orig_out = out
1✔
635
        _manager = SequentialOrderManager[exec_q]
1✔
636
        if out is not None:
1✔
637
            if not isinstance(out, dpt.usm_ndarray):
1✔
638
                raise TypeError(
1✔
639
                    f"output array must be of usm_ndarray type, got {type(out)}"
640
                )
641

642
            if not out.flags.writable:
1✔
643
                raise ValueError("provided `out` array is read-only")
1✔
644

645
            if out.shape != res_shape:
1✔
646
                raise ValueError(
1✔
647
                    "The shape of input and output arrays are inconsistent. "
648
                    f"Expected output shape is {res_shape}, got {out.shape}"
649
                )
650

651
            if res_dt != out.dtype:
1✔
652
                raise ValueError(
1✔
653
                    f"Output array of type {res_dt} is needed,"
654
                    f"got {out.dtype}"
655
                )
656

657
            if (
1✔
658
                dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
659
                is None
660
            ):
661
                raise ExecutionPlacementError(
1✔
662
                    "Input and output allocation queues are not compatible"
663
                )
664

665
            if isinstance(o1, dpt.usm_ndarray):
1!
666
                if ti._array_overlap(o1, out) and buf1_dt is None:
1✔
667
                    if not ti._same_logical_tensors(o1, out):
1✔
668
                        out = dpt.empty_like(out)
1✔
669
                    elif self.binary_inplace_fn_ is not None:
1!
670
                        # if there is a dedicated in-place kernel
671
                        # it can be called here, otherwise continues
672
                        if isinstance(o2, dpt.usm_ndarray):
1✔
673
                            src2 = o2
1✔
674
                            if (
1✔
675
                                ti._array_overlap(o2, out)
676
                                and not ti._same_logical_tensors(o2, out)
677
                                and buf2_dt is None
678
                            ):
679
                                buf2_dt = o2_dtype
1✔
680
                        else:
681
                            src2 = dpt.asarray(
1✔
682
                                o2, dtype=o2_dtype, sycl_queue=exec_q
683
                            )
684
                        if buf2_dt is None:
1✔
685
                            if src2.shape != res_shape:
1✔
686
                                src2 = dpt.broadcast_to(src2, res_shape)
1✔
687
                            dep_evs = _manager.submitted_events
1✔
688
                            ht_, comp_ev = self.binary_inplace_fn_(
1✔
689
                                lhs=o1,
690
                                rhs=src2,
691
                                sycl_queue=exec_q,
692
                                depends=dep_evs,
693
                            )
694
                            _manager.add_event_pair(ht_, comp_ev)
1✔
695
                        else:
696
                            buf2 = dpt.empty_like(src2, dtype=buf2_dt)
1✔
697
                            dep_evs = _manager.submitted_events
1✔
698
                            (
1✔
699
                                ht_copy_ev,
700
                                copy_ev,
701
                            ) = ti._copy_usm_ndarray_into_usm_ndarray(
702
                                src=src2,
703
                                dst=buf2,
704
                                sycl_queue=exec_q,
705
                                depends=dep_evs,
706
                            )
707
                            _manager.add_event_pair(ht_copy_ev, copy_ev)
1✔
708

709
                            buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
710
                            ht_, bf_ev = self.binary_inplace_fn_(
1✔
711
                                lhs=o1,
712
                                rhs=buf2,
713
                                sycl_queue=exec_q,
714
                                depends=[copy_ev],
715
                            )
716
                            _manager.add_event_pair(ht_, bf_ev)
1✔
717

718
                        return out
1✔
719

720
            if isinstance(o2, dpt.usm_ndarray):
1✔
721
                if (
1!
722
                    ti._array_overlap(o2, out)
723
                    and not ti._same_logical_tensors(o2, out)
724
                    and buf2_dt is None
725
                ):
726
                    # should not reach if out is reallocated
727
                    # after being checked against o1
728
                    out = dpt.empty_like(out)
×
729

730
        if isinstance(o1, dpt.usm_ndarray):
1✔
731
            src1 = o1
1✔
732
        else:
733
            src1 = dpt.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q)
1✔
734
        if isinstance(o2, dpt.usm_ndarray):
1✔
735
            src2 = o2
1✔
736
        else:
737
            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
1✔
738

739
        if order == "A":
1✔
740
            order = (
1✔
741
                "F"
742
                if all(
743
                    arr.flags.f_contiguous
744
                    for arr in (
745
                        src1,
746
                        src2,
747
                    )
748
                )
749
                else "C"
750
            )
751

752
        if buf1_dt is None and buf2_dt is None:
1✔
753
            if out is None:
1✔
754
                if order == "K":
1✔
755
                    out = _empty_like_pair_orderK(
1✔
756
                        src1, src2, res_dt, res_shape, res_usm_type, exec_q
757
                    )
758
                else:
759
                    out = dpt.empty(
1✔
760
                        res_shape,
761
                        dtype=res_dt,
762
                        usm_type=res_usm_type,
763
                        sycl_queue=exec_q,
764
                        order=order,
765
                    )
766
            if src1.shape != res_shape:
1✔
767
                src1 = dpt.broadcast_to(src1, res_shape)
1✔
768
            if src2.shape != res_shape:
1✔
769
                src2 = dpt.broadcast_to(src2, res_shape)
1✔
770
            deps_ev = _manager.submitted_events
1✔
771
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
772
                src1=src1,
773
                src2=src2,
774
                dst=out,
775
                sycl_queue=exec_q,
776
                depends=deps_ev,
777
            )
778
            _manager.add_event_pair(ht_binary_ev, binary_ev)
1✔
779
            if not (orig_out is None or orig_out is out):
1✔
780
                # Copy the out data from temporary buffer to original memory
781
                ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
782
                    src=out,
783
                    dst=orig_out,
784
                    sycl_queue=exec_q,
785
                    depends=[binary_ev],
786
                )
787
                _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
1✔
788
                out = orig_out
1✔
789
            return out
1✔
790
        elif buf1_dt is None:
1✔
791
            if order == "K":
1!
792
                buf2 = _empty_like_orderK(src2, buf2_dt)
1✔
793
            else:
794
                buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
×
795
            dep_evs = _manager.submitted_events
1✔
796
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
797
                src=src2, dst=buf2, sycl_queue=exec_q, depends=dep_evs
798
            )
799
            _manager.add_event_pair(ht_copy_ev, copy_ev)
1✔
800
            if out is None:
1✔
801
                if order == "K":
1!
802
                    out = _empty_like_pair_orderK(
1✔
803
                        src1, buf2, res_dt, res_shape, res_usm_type, exec_q
804
                    )
805
                else:
806
                    out = dpt.empty(
×
807
                        res_shape,
808
                        dtype=res_dt,
809
                        usm_type=res_usm_type,
810
                        sycl_queue=exec_q,
811
                        order=order,
812
                    )
813

814
            if src1.shape != res_shape:
1✔
815
                src1 = dpt.broadcast_to(src1, res_shape)
1✔
816
            buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
817
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
818
                src1=src1,
819
                src2=buf2,
820
                dst=out,
821
                sycl_queue=exec_q,
822
                depends=[copy_ev],
823
            )
824
            _manager.add_event_pair(ht_binary_ev, binary_ev)
1✔
825
            if not (orig_out is None or orig_out is out):
1!
826
                # Copy the out data from temporary buffer to original memory
NEW
827
                ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
×
828
                    src=out,
829
                    dst=orig_out,
830
                    sycl_queue=exec_q,
831
                    depends=[binary_ev],
832
                )
NEW
833
                _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
×
834
                out = orig_out
×
835
            return out
1✔
836
        elif buf2_dt is None:
1✔
837
            if order == "K":
1!
838
                buf1 = _empty_like_orderK(src1, buf1_dt)
1✔
839
            else:
840
                buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
×
841
            dep_evs = _manager.submitted_events
1✔
842
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
843
                src=src1, dst=buf1, sycl_queue=exec_q, depends=dep_evs
844
            )
845
            _manager.add_event_pair(ht_copy_ev, copy_ev)
1✔
846
            if out is None:
1✔
847
                if order == "K":
1!
848
                    out = _empty_like_pair_orderK(
1✔
849
                        buf1, src2, res_dt, res_shape, res_usm_type, exec_q
850
                    )
851
                else:
852
                    out = dpt.empty(
×
853
                        res_shape,
854
                        dtype=res_dt,
855
                        usm_type=res_usm_type,
856
                        sycl_queue=exec_q,
857
                        order=order,
858
                    )
859

860
            buf1 = dpt.broadcast_to(buf1, res_shape)
1✔
861
            if src2.shape != res_shape:
1✔
862
                src2 = dpt.broadcast_to(src2, res_shape)
1✔
863
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
864
                src1=buf1,
865
                src2=src2,
866
                dst=out,
867
                sycl_queue=exec_q,
868
                depends=[copy_ev],
869
            )
870
            _manager.add_event_pair(ht_binary_ev, binary_ev)
1✔
871
            if not (orig_out is None or orig_out is out):
1!
872
                # Copy the out data from temporary buffer to original memory
NEW
873
                ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
×
874
                    src=out,
875
                    dst=orig_out,
876
                    sycl_queue=exec_q,
877
                    depends=[binary_ev],
878
                )
NEW
879
                _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
×
880
                out = orig_out
×
881
            return out
1✔
882

883
        if order == "K":
1✔
884
            if src1.flags.c_contiguous and src2.flags.c_contiguous:
1✔
885
                order = "C"
1✔
886
            elif src1.flags.f_contiguous and src2.flags.f_contiguous:
1✔
887
                order = "F"
1✔
888
        if order == "K":
1✔
889
            buf1 = _empty_like_orderK(src1, buf1_dt)
1✔
890
        else:
891
            buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
1✔
892
        dep_evs = _manager.submitted_events
1✔
893
        ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
894
            src=src1, dst=buf1, sycl_queue=exec_q, depends=dep_evs
895
        )
896
        _manager.add_event_pair(ht_copy1_ev, copy1_ev)
1✔
897
        if order == "K":
1✔
898
            buf2 = _empty_like_orderK(src2, buf2_dt)
1✔
899
        else:
900
            buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
1✔
901
        ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
902
            src=src2, dst=buf2, sycl_queue=exec_q, depends=dep_evs
903
        )
904
        _manager.add_event_pair(ht_copy2_ev, copy2_ev)
1✔
905
        if out is None:
1✔
906
            if order == "K":
1✔
907
                out = _empty_like_pair_orderK(
1✔
908
                    buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
909
                )
910
            else:
911
                out = dpt.empty(
1✔
912
                    res_shape,
913
                    dtype=res_dt,
914
                    usm_type=res_usm_type,
915
                    sycl_queue=exec_q,
916
                    order=order,
917
                )
918

919
        buf1 = dpt.broadcast_to(buf1, res_shape)
1✔
920
        buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
921
        ht_, bf_ev = self.binary_fn_(
1✔
922
            src1=buf1,
923
            src2=buf2,
924
            dst=out,
925
            sycl_queue=exec_q,
926
            depends=[copy1_ev, copy2_ev],
927
        )
928
        _manager.add_event_pair(ht_, bf_ev)
1✔
929
        return out
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc