• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 8807160764

23 Apr 2024 09:01PM UTC coverage: 88.218% (+0.008%) from 88.21%
8807160764

Pull #1650

github

web-flow
Merge b559f0a94 into ef5a75133
Pull Request #1650: Fixes element-wise comparisons of mixed signed-unsigned integer inputs

3575 of 4060 branches covered (88.05%)

Branch coverage included in aggregate %.

51 of 53 new or added lines in 3 files covered. (96.23%)

2 existing lines in 1 file now uncovered.

10763 of 12193 relevant lines covered (88.27%)

8735.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.35
/dpctl/tensor/_elementwise_common.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2024 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import numbers
1✔
18

19
import numpy as np
1✔
20

21
import dpctl
1✔
22
import dpctl.memory as dpm
1✔
23
import dpctl.tensor as dpt
1✔
24
import dpctl.tensor._tensor_impl as ti
1✔
25
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
1✔
26
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
1✔
27
from dpctl.utils import ExecutionPlacementError
1✔
28

29
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
1✔
30
from ._type_utils import (
1✔
31
    WeakBooleanType,
32
    WeakComplexType,
33
    WeakFloatingType,
34
    WeakIntegralType,
35
    _acceptance_fn_default_binary,
36
    _acceptance_fn_default_unary,
37
    _all_data_types,
38
    _find_buf_dtype,
39
    _find_buf_dtype2,
40
    _resolve_weak_types,
41
    _to_device_supported_dtype,
42
)
43

44

45
class UnaryElementwiseFunc:
    """
    Class that implements unary element-wise functions.

    Args:
        name (str):
            Name of the unary function
        result_type_resolver_fn (callable):
            Function that takes dtype of the input and
            returns the dtype of the result if the
            implementation functions supports it, or
            returns `None` otherwise.
        unary_dp_impl_fn (callable):
            Data-parallel implementation function with signature
            `impl_fn(src: usm_ndarray, dst: usm_ndarray,
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
            where the `src` is the argument array, `dst` is the
            array to be populated with function values, effectively
            evaluating `dst = func(src)`.
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
            The first event corresponds to data-management host tasks,
            including lifetime management of argument Python objects to ensure
            that their associated USM allocation is not freed before offloaded
            computational tasks complete execution, while the second event
            corresponds to computational tasks associated with function
            evaluation.
        docs (str):
            Documentation string for the unary function.
        acceptance_fn (callable, optional):
            Function to influence type promotion behavior of this unary
            function. The function takes 4 arguments:
                arg_dtype - Data type of the first argument
                buf_dtype - Data type the argument would be cast to
                res_dtype - Data type of the output array with function values
                sycl_dev - The :class:`dpctl.SyclDevice` where the function
                    evaluation is carried out.
            The function is invoked when the argument of the unary function
            requires casting, e.g. the argument of `dpctl.tensor.log` is an
            array with integral data type.
            When the value passed is not callable, the module default
            `_acceptance_fn_default_unary` is used instead.
    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        unary_dp_impl_fn,
        docs,
        acceptance_fn=None,
    ):
        self.__name__ = "UnaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        # Cache for the `types` property; `None` means "not computed yet".
        self.types_ = None
        self.unary_fn_ = unary_dp_impl_fn
        self.__doc__ = docs
        if callable(acceptance_fn):
            self.acceptance_fn_ = acceptance_fn
        else:
            self.acceptance_fn_ = _acceptance_fn_default_unary

    def __str__(self):
        return f"<{self.__name__} '{self.name_}'>"

    def __repr__(self):
        return f"<{self.__name__} '{self.name_}'>"

    def get_implementation_function(self):
        """Returns the implementation function for
        this elementwise unary function.

        """
        return self.unary_fn_

    def get_type_result_resolver_function(self):
        """Returns the type resolver function for this
        elementwise unary function.
        """
        return self.result_type_resolver_fn_

    def get_type_promotion_path_acceptance_function(self):
        """Returns the acceptance function for this
        elementwise unary function.

        Acceptance function influences the type promotion
        behavior of this unary function.
        The function takes 4 arguments:
            arg_dtype - Data type of the first argument
            buf_dtype - Data type the argument would be cast to
            res_dtype - Data type of the output array with function values
            sycl_dev - The :class:`dpctl.SyclDevice` where the function
                evaluation is carried out.
        The function is invoked when the argument of the unary function
        requires casting, e.g. the argument of `dpctl.tensor.log` is an
        array with integral data type.
        """
        return self.acceptance_fn_

    @property
    def nin(self):
        """
        Returns the number of arguments treated as inputs.
        """
        return 1

    @property
    def nout(self):
        """
        Returns the number of arguments treated as outputs.
        """
        return 1

    @property
    def types(self):
        """Returns information about types supported by
        implementation function, using NumPy's character
        encoding for data types, e.g.

        :Example:
            .. code-block:: python

                dpctl.tensor.sin.types
                # Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
        """
        types = self.types_
        # Compare against None rather than truthiness: an empty list is a
        # valid cached answer and must not trigger a rescan on every access.
        if types is None:
            types = []
            for dt1 in _all_data_types(True, True):
                dt2 = self.result_type_resolver_fn_(dt1)
                if dt2:
                    types.append(f"{dt1.char}->{dt2.char}")
            self.types_ = types
        return types

    def __call__(self, x, /, *, out=None, order="K"):
        """Evaluates the function element-wise on array `x`.

        Args:
            x (usm_ndarray):
                Input array.
            out (usm_ndarray, optional):
                Array to write the result into. It must be writable, have
                the shape of `x`, the resolved result data type, and an
                allocation queue compatible with that of `x`.
            order (str, optional):
                Memory layout of a newly allocated result, one of
                "C", "F", "A" or "K". Any other value is treated as "K".

        Returns:
            usm_ndarray:
                `out` when it was provided, otherwise a newly
                allocated array with the function values.

        Raises:
            TypeError:
                If `x`, or a provided `out`, is not a `usm_ndarray`.
            ValueError:
                If the input data type is unsupported and could not be
                coerced, or if `out` is read-only, has a mismatching shape,
                or has a mismatching data type.
            ExecutionPlacementError:
                If allocation queues of `x` and `out` are not compatible.
        """
        if not isinstance(x, dpt.usm_ndarray):
            raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

        if order not in ["C", "F", "K", "A"]:
            order = "K"
        # `buf_dt` is None when `x` can be passed to the kernel as is,
        # otherwise it is the data type `x` must first be copied to.
        buf_dt, res_dt = _find_buf_dtype(
            x.dtype,
            self.result_type_resolver_fn_,
            x.sycl_device,
            acceptance_fn=self.acceptance_fn_,
        )
        if res_dt is None:
            raise ValueError(
                f"function '{self.name_}' does not support input type "
                f"({x.dtype}), "
                "and the input could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        # Remember the caller's array: `out` may be swapped for a temporary.
        orig_out = out
        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if not out.flags.writable:
                raise ValueError("provided `out` array is read-only")

            if out.shape != x.shape:
                raise ValueError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {x.shape}, got {out.shape}"
                )

            if res_dt != out.dtype:
                raise ValueError(
                    f"Output array of type {res_dt} is needed,"
                    f" got {out.dtype}"
                )

            if (
                buf_dt is None
                and ti._array_overlap(x, out)
                and not ti._same_logical_tensors(x, out)
            ):
                # Allocate a temporary buffer to avoid memory overlapping.
                # Note if `buf_dt` is not None, a temporary copy of `x` will be
                # created, so the array overlap check isn't needed.
                out = dpt.empty_like(out)

            if (
                dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
                is None
            ):
                raise ExecutionPlacementError(
                    "Input and output allocation queues are not compatible"
                )

        exec_q = x.sycl_queue
        if buf_dt is None:
            # No cast of the input is needed: run the kernel on `x` directly.
            if out is None:
                if order == "K":
                    out = _empty_like_orderK(x, res_dt)
                else:
                    if order == "A":
                        order = "F" if x.flags.f_contiguous else "C"
                    out = dpt.empty_like(x, dtype=res_dt, order=order)

            ht_unary_ev, unary_ev = self.unary_fn_(x, out, sycl_queue=exec_q)

            if not (orig_out is None or orig_out is out):
                # Copy the out data from temporary buffer to original memory
                ht_copy_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
                )
                ht_copy_ev.wait()
                out = orig_out

            ht_unary_ev.wait()
            return out

        # The input needs casting: copy `x` into a buffer of type `buf_dt`
        # and evaluate the kernel on that buffer.
        if order == "K":
            buf = _empty_like_orderK(x, buf_dt)
        else:
            if order == "A":
                order = "F" if x.flags.f_contiguous else "C"
            buf = dpt.empty_like(x, dtype=buf_dt, order=order)

        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=buf, sycl_queue=exec_q
        )
        if out is None:
            if order == "K":
                out = _empty_like_orderK(buf, res_dt)
            else:
                out = dpt.empty_like(buf, dtype=res_dt, order=order)

        # The kernel depends on the cast copy having completed.
        ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev])
        ht_copy_ev.wait()
        ht.wait()

        return out
1✔
282

283

284
def _get_queue_usm_type(o):
    """Return the pair ``(sycl_queue, usm_type)`` describing the USM
    allocation of object `o`, or ``(None, None)``.

    ``(None, None)`` is returned when `o` is neither a `usm_ndarray` nor an
    object with the `__sycl_usm_array_interface__` attribute, and also when
    interpreting that attribute raises an exception.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    if not hasattr(o, "__sycl_usm_array_interface__"):
        return None, None
    try:
        usm_mem = dpm.as_usm_memory(o)
        return usm_mem.sycl_queue, usm_mem.get_usm_type()
    except Exception:
        # Best effort: an unusable interface is treated as "no allocation".
        return None, None
1✔
295

296

297
def _get_dtype(o, dev):
    """Return the data type used to represent operand `o` on device `dev`.

    The checks are applied in this order:
        - `usm_ndarray` instances yield their own `dtype`;
        - objects with `__sycl_usm_array_interface__` yield the `dtype` of
          `dpt.asarray(o)`;
        - objects accepted by `_is_buffer`, and then any other object with
          a `dtype` attribute, yield the result of
          `_to_device_supported_dtype` for the host data type and `dev`;
        - Python `bool`, `int`, `float` and `complex` values yield an
          instance of the matching weak type wrapping the value;
        - anything else yields `np.object_`.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        return _to_device_supported_dtype(np.array(o).dtype, dev)
    if hasattr(o, "dtype"):
        return _to_device_supported_dtype(o.dtype, dev)
    # `bool` is listed before `int` because `bool` is a subclass of `int`.
    scalar_kinds = (
        (bool, WeakBooleanType),
        (int, WeakIntegralType),
        (float, WeakFloatingType),
        (complex, WeakComplexType),
    )
    for py_type, weak_type in scalar_kinds:
        if isinstance(o, py_type):
            return weak_type(o)
    return np.object_
1✔
318

319

320
def _validate_dtype(dt) -> bool:
    """Return `True` if `dt` is an acceptable operand data type.

    Acceptable values are instances of the weak scalar types and
    `dpt.dtype` instances equal to one of the boolean, integral,
    floating-point or complex data types listed below.
    """
    weak_types = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_types):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    return dt in [
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    ]
344

345

346
def _get_shape(o):
    """Return the shape of operand `o`.

    `usm_ndarray` instances report their `shape`, objects accepted by
    `_is_buffer` report the shape of their `memoryview`, numbers report
    an empty tuple, and any other object reports its `shape` attribute,
    or an empty tuple when it has none.
    """
    if isinstance(o, dpt.usm_ndarray):
        shape = o.shape
    elif _is_buffer(o):
        shape = memoryview(o).shape
    elif isinstance(o, numbers.Number):
        shape = tuple()
    else:
        shape = getattr(o, "shape", tuple())
    return shape
1✔
354

355

356
class BinaryElementwiseFunc:
1✔
357
    """
358
    Class that implements binary element-wise functions.
359

360
    Args:
361
        name (str):
362
            Name of the binary function
363
        result_type_resolver_fn (callable):
364
            Function that takes dtypes of the input and
365
            returns the dtype of the result if the
366
            implementation functions supports it, or
367
            returns `None` otherwise.
368
        binary_dp_impl_fn (callable):
369
            Data-parallel implementation function with signature
370
            `impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
371
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
372
            where the `src1` and `src2` are the argument arrays, `dst` is the
373
            array to be populated with function values,
374
            i.e. `dst=func(src1, src2)`.
375
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
376
            The first event corresponds to data-management host tasks,
377
            including lifetime management of argument Python objects to ensure
378
            that their associated USM allocation is not freed before offloaded
379
            computational tasks complete execution, while the second event
380
            corresponds to computational tasks associated with function
381
            evaluation.
382
        docs (str):
383
            Documentation string for the binary function.
384
        binary_inplace_fn (callable, optional):
385
            Data-parallel implementation function with signature
386
            `impl_fn(src: usm_ndarray, dst: usm_ndarray,
387
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
388
            where the `src` is the argument array, `dst` is the
389
            array to be populated with function values,
390
            i.e. `dst=func(dst, src)`.
391
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
392
            The first event corresponds to data-management host tasks,
393
            including async lifetime management of Python arguments,
394
            while the second event corresponds to computational tasks
395
            associated with function evaluation.
396
        acceptance_fn (callable, optional):
397
            Function to influence type promotion behavior of this binary
398
            function. The function takes 6 arguments:
399
                arg1_dtype - Data type of the first argument
400
                arg2_dtype - Data type of the second argument
401
                ret_buf1_dtype - Data type the first argument would be cast to
402
                ret_buf2_dtype - Data type the second argument would be cast to
403
                res_dtype - Data type of the output array with function values
404
                sycl_dev - The :class:`dpctl.SyclDevice` where the function
405
                    evaluation is carried out.
406
            The function is only called when both arguments of the binary
407
            function require casting, e.g. both arguments of
408
            `dpctl.tensor.logaddexp` are arrays with integral data type.
409
    """
410

411
    def __init__(
1✔
412
        self,
413
        name,
414
        result_type_resolver_fn,
415
        binary_dp_impl_fn,
416
        docs,
417
        binary_inplace_fn=None,
418
        acceptance_fn=None,
419
        weak_type_resolver=None,
420
    ):
421
        self.__name__ = "BinaryElementwiseFunc"
1✔
422
        self.name_ = name
1✔
423
        self.result_type_resolver_fn_ = result_type_resolver_fn
1✔
424
        self.types_ = None
1✔
425
        self.binary_fn_ = binary_dp_impl_fn
1✔
426
        self.binary_inplace_fn_ = binary_inplace_fn
1✔
427
        self.__doc__ = docs
1✔
428
        if callable(acceptance_fn):
1✔
429
            self.acceptance_fn_ = acceptance_fn
1✔
430
        else:
431
            self.acceptance_fn_ = _acceptance_fn_default_binary
1✔
432
        if callable(weak_type_resolver):
1✔
433
            self.weak_type_resolver_ = weak_type_resolver
1✔
434
        else:
435
            self.weak_type_resolver_ = _resolve_weak_types
1✔
436

437
    def __str__(self):
1✔
438
        return f"<{self.__name__} '{self.name_}'>"
1✔
439

440
    def __repr__(self):
1✔
441
        return f"<{self.__name__} '{self.name_}'>"
1✔
442

443
    def get_implementation_function(self):
1✔
444
        """Returns the out-of-place implementation
445
        function for this elementwise binary function.
446

447
        """
448
        return self.binary_fn_
1✔
449

450
    def get_implementation_inplace_function(self):
1✔
451
        """Returns the in-place implementation
452
        function for this elementwise binary function.
453

454
        """
455
        return self.binary_inplace_fn_
1✔
456

457
    def get_type_result_resolver_function(self):
1✔
458
        """Returns the type resolver function for this
459
        elementwise binary function.
460
        """
461
        return self.result_type_resolver_fn_
1✔
462

463
    def get_type_promotion_path_acceptance_function(self):
1✔
464
        """Returns the acceptance function for this
465
        elementwise binary function.
466

467
        Acceptance function influences the type promotion
468
        behavior of this binary function.
469
        The function takes 6 arguments:
470
            arg1_dtype - Data type of the first argument
471
            arg2_dtype - Data type of the second argument
472
            ret_buf1_dtype - Data type the first argument would be cast to
473
            ret_buf2_dtype - Data type the second argument would be cast to
474
            res_dtype - Data type of the output array with function values
475
            sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
476
                is carried out.
477

478
        The acceptance function is only invoked if both input arrays must be
479
        cast to intermediary data types, as would happen during call of
480
        `dpctl.tensor.hypot` with both arrays being of integral data type.
481
        """
482
        return self.acceptance_fn_
1✔
483

484
    def get_array_dtype_scalar_type_resolver_function(self):
1✔
485
        """Returns the function which determines how to treat
486
        Python scalar types for this elementwise binary function.
487

488
        Resolver influences what type the scalar will be
489
        treated as prior to type promotion behavior.
490
        The function takes 3 arguments:
491

492
        Args:
493
            o1_dtype (object, dtype):
494
                A class representing a Python scalar type or a ``dtype``
495
            o2_dtype (object, dtype):
496
                A class representing a Python scalar type or a ``dtype``
497
            sycl_dev (:class:`dpctl.SyclDevice`):
498
                Device on which function evaluation is carried out.
499

500
        One of ``o1_dtype`` and ``o2_dtype`` must be a ``dtype`` instance.
501
        """
NEW
502
        return self.weak_type_resolver_
×
503

504
    @property
1✔
505
    def nin(self):
1✔
506
        """
507
        Returns the number of arguments treated as inputs.
508
        """
509
        return 2
1✔
510

511
    @property
1✔
512
    def nout(self):
1✔
513
        """
514
        Returns the number of arguments treated as outputs.
515
        """
516
        return 1
1✔
517

518
    @property
1✔
519
    def types(self):
1✔
520
        """Returns information about types supported by
521
        implementation function, using NumPy's character
522
        encoding for data types, e.g.
523

524
        :Example:
525
            .. code-block:: python
526

527
                dpctl.tensor.divide.types
528
                # Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
529
                #    'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
530
        """
531
        types = self.types_
1✔
532
        if not types:
1!
533
            types = []
1✔
534
            _all_dtypes = _all_data_types(True, True)
1✔
535
            for dt1 in _all_dtypes:
1✔
536
                for dt2 in _all_dtypes:
1✔
537
                    dt3 = self.result_type_resolver_fn_(dt1, dt2)
1✔
538
                    if dt3:
1✔
539
                        types.append(f"{dt1.char}{dt2.char}->{dt3.char}")
1✔
540
            self.types_ = types
1✔
541
        return types
1✔
542

543
    def __call__(self, o1, o2, /, *, out=None, order="K"):
1✔
544
        if order not in ["K", "C", "F", "A"]:
1✔
545
            order = "K"
1✔
546
        q1, o1_usm_type = _get_queue_usm_type(o1)
1✔
547
        q2, o2_usm_type = _get_queue_usm_type(o2)
1✔
548
        if q1 is None and q2 is None:
1✔
549
            raise ExecutionPlacementError(
1✔
550
                "Execution placement can not be unambiguously inferred "
551
                "from input arguments. "
552
                "One of the arguments must represent USM allocation and "
553
                "expose `__sycl_usm_array_interface__` property"
554
            )
555
        if q1 is None:
1✔
556
            exec_q = q2
1✔
557
            res_usm_type = o2_usm_type
1✔
558
        elif q2 is None:
1✔
559
            exec_q = q1
1✔
560
            res_usm_type = o1_usm_type
1✔
561
        else:
562
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
1✔
563
            if exec_q is None:
1✔
564
                raise ExecutionPlacementError(
1✔
565
                    "Execution placement can not be unambiguously inferred "
566
                    "from input arguments."
567
                )
568
            res_usm_type = dpctl.utils.get_coerced_usm_type(
1✔
569
                (
570
                    o1_usm_type,
571
                    o2_usm_type,
572
                )
573
            )
574
        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
1✔
575
        o1_shape = _get_shape(o1)
1✔
576
        o2_shape = _get_shape(o2)
1✔
577
        if not all(
1!
578
            isinstance(s, (tuple, list))
579
            for s in (
580
                o1_shape,
581
                o2_shape,
582
            )
583
        ):
584
            raise TypeError(
×
585
                "Shape of arguments can not be inferred. "
586
                "Arguments are expected to be "
587
                "lists, tuples, or both"
588
            )
589
        try:
1✔
590
            res_shape = _broadcast_shape_impl(
1✔
591
                [
592
                    o1_shape,
593
                    o2_shape,
594
                ]
595
            )
596
        except ValueError:
1✔
597
            raise ValueError(
1✔
598
                "operands could not be broadcast together with shapes "
599
                f"{o1_shape} and {o2_shape}"
600
            )
601
        sycl_dev = exec_q.sycl_device
1✔
602
        o1_dtype = _get_dtype(o1, sycl_dev)
1✔
603
        o2_dtype = _get_dtype(o2, sycl_dev)
1✔
604
        if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
1✔
605
            raise ValueError("Operands have unsupported data types")
1✔
606

607
        o1_dtype, o2_dtype = self.weak_type_resolver_(
1✔
608
            o1_dtype, o2_dtype, sycl_dev
609
        )
610

611
        buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
1✔
612
            o1_dtype,
613
            o2_dtype,
614
            self.result_type_resolver_fn_,
615
            sycl_dev,
616
            acceptance_fn=self.acceptance_fn_,
617
        )
618

619
        if res_dt is None:
1✔
620
            raise ValueError(
1✔
621
                f"function '{self.name_}' does not support input types "
622
                f"({o1_dtype}, {o2_dtype}), "
623
                "and the inputs could not be safely coerced to any "
624
                "supported types according to the casting rule ''safe''."
625
            )
626

627
        orig_out = out
1✔
628
        if out is not None:
1✔
629
            if not isinstance(out, dpt.usm_ndarray):
1✔
630
                raise TypeError(
1✔
631
                    f"output array must be of usm_ndarray type, got {type(out)}"
632
                )
633

634
            if not out.flags.writable:
1✔
635
                raise ValueError("provided `out` array is read-only")
1✔
636

637
            if out.shape != res_shape:
1✔
638
                raise ValueError(
1✔
639
                    "The shape of input and output arrays are inconsistent. "
640
                    f"Expected output shape is {res_shape}, got {out.shape}"
641
                )
642

643
            if res_dt != out.dtype:
1✔
644
                raise ValueError(
1✔
645
                    f"Output array of type {res_dt} is needed,"
646
                    f"got {out.dtype}"
647
                )
648

649
            if (
1✔
650
                dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
651
                is None
652
            ):
653
                raise ExecutionPlacementError(
1✔
654
                    "Input and output allocation queues are not compatible"
655
                )
656

657
            if isinstance(o1, dpt.usm_ndarray):
1!
658
                if ti._array_overlap(o1, out) and buf1_dt is None:
1✔
659
                    if not ti._same_logical_tensors(o1, out):
1✔
660
                        out = dpt.empty_like(out)
1✔
661
                    elif self.binary_inplace_fn_ is not None:
1!
662
                        # if there is a dedicated in-place kernel
663
                        # it can be called here, otherwise continues
664
                        if isinstance(o2, dpt.usm_ndarray):
1✔
665
                            src2 = o2
1✔
666
                            if (
1✔
667
                                ti._array_overlap(o2, out)
668
                                and not ti._same_logical_tensors(o2, out)
669
                                and buf2_dt is None
670
                            ):
671
                                buf2_dt = o2_dtype
1✔
672
                        else:
673
                            src2 = dpt.asarray(
1✔
674
                                o2, dtype=o2_dtype, sycl_queue=exec_q
675
                            )
676
                        if buf2_dt is None:
1✔
677
                            if src2.shape != res_shape:
1✔
678
                                src2 = dpt.broadcast_to(src2, res_shape)
1✔
679
                            ht_, _ = self.binary_inplace_fn_(
1✔
680
                                lhs=o1, rhs=src2, sycl_queue=exec_q
681
                            )
682
                            ht_.wait()
1✔
683
                        else:
684
                            buf2 = dpt.empty_like(src2, dtype=buf2_dt)
1✔
685
                            (
1✔
686
                                ht_copy_ev,
687
                                copy_ev,
688
                            ) = ti._copy_usm_ndarray_into_usm_ndarray(
689
                                src=src2, dst=buf2, sycl_queue=exec_q
690
                            )
691

692
                            buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
693
                            ht_, _ = self.binary_inplace_fn_(
1✔
694
                                lhs=o1,
695
                                rhs=buf2,
696
                                sycl_queue=exec_q,
697
                                depends=[copy_ev],
698
                            )
699
                            ht_copy_ev.wait()
1✔
700
                            ht_.wait()
1✔
701

702
                        return out
1✔
703

704
            if isinstance(o2, dpt.usm_ndarray):
1✔
705
                if (
1!
706
                    ti._array_overlap(o2, out)
707
                    and not ti._same_logical_tensors(o2, out)
708
                    and buf2_dt is None
709
                ):
710
                    # should not reach if out is reallocated
711
                    # after being checked against o1
712
                    out = dpt.empty_like(out)
×
713

714
        if isinstance(o1, dpt.usm_ndarray):
1✔
715
            src1 = o1
1✔
716
        else:
717
            src1 = dpt.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q)
1✔
718
        if isinstance(o2, dpt.usm_ndarray):
1✔
719
            src2 = o2
1✔
720
        else:
721
            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
1✔
722

723
        if order == "A":
1✔
724
            order = (
1✔
725
                "F"
726
                if all(
727
                    arr.flags.f_contiguous
728
                    for arr in (
729
                        src1,
730
                        src2,
731
                    )
732
                )
733
                else "C"
734
            )
735

736
        if buf1_dt is None and buf2_dt is None:
1✔
737
            if out is None:
1✔
738
                if order == "K":
1✔
739
                    out = _empty_like_pair_orderK(
1✔
740
                        src1, src2, res_dt, res_shape, res_usm_type, exec_q
741
                    )
742
                else:
743
                    out = dpt.empty(
1✔
744
                        res_shape,
745
                        dtype=res_dt,
746
                        usm_type=res_usm_type,
747
                        sycl_queue=exec_q,
748
                        order=order,
749
                    )
750
            if src1.shape != res_shape:
1✔
751
                src1 = dpt.broadcast_to(src1, res_shape)
1✔
752
            if src2.shape != res_shape:
1✔
753
                src2 = dpt.broadcast_to(src2, res_shape)
1✔
754
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
755
                src1=src1, src2=src2, dst=out, sycl_queue=exec_q
756
            )
757
            if not (orig_out is None or orig_out is out):
1✔
758
                # Copy the out data from temporary buffer to original memory
759
                ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
760
                    src=out,
761
                    dst=orig_out,
762
                    sycl_queue=exec_q,
763
                    depends=[binary_ev],
764
                )
765
                ht_copy_out_ev.wait()
1✔
766
                out = orig_out
1✔
767
            ht_binary_ev.wait()
1✔
768
            return out
1✔
769
        elif buf1_dt is None:
1✔
770
            if order == "K":
1!
771
                buf2 = _empty_like_orderK(src2, buf2_dt)
1✔
772
            else:
773
                buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
×
774
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
775
                src=src2, dst=buf2, sycl_queue=exec_q
776
            )
777
            if out is None:
1✔
778
                if order == "K":
1!
779
                    out = _empty_like_pair_orderK(
1✔
780
                        src1, buf2, res_dt, res_shape, res_usm_type, exec_q
781
                    )
782
                else:
783
                    out = dpt.empty(
×
784
                        res_shape,
785
                        dtype=res_dt,
786
                        usm_type=res_usm_type,
787
                        sycl_queue=exec_q,
788
                        order=order,
789
                    )
790

791
            if src1.shape != res_shape:
1✔
792
                src1 = dpt.broadcast_to(src1, res_shape)
1✔
793
            buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
794
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
795
                src1=src1,
796
                src2=buf2,
797
                dst=out,
798
                sycl_queue=exec_q,
799
                depends=[copy_ev],
800
            )
801
            if not (orig_out is None or orig_out is out):
1!
802
                # Copy the out data from temporary buffer to original memory
803
                ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
×
804
                    src=out,
805
                    dst=orig_out,
806
                    sycl_queue=exec_q,
807
                    depends=[binary_ev],
808
                )
809
                ht_copy_out_ev.wait()
×
810
                out = orig_out
×
811
            ht_copy_ev.wait()
1✔
812
            ht_binary_ev.wait()
1✔
813
            return out
1✔
814
        elif buf2_dt is None:
1✔
815
            if order == "K":
1!
816
                buf1 = _empty_like_orderK(src1, buf1_dt)
1✔
817
            else:
818
                buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
×
819
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
820
                src=src1, dst=buf1, sycl_queue=exec_q
821
            )
822
            if out is None:
1✔
823
                if order == "K":
1!
824
                    out = _empty_like_pair_orderK(
1✔
825
                        buf1, src2, res_dt, res_shape, res_usm_type, exec_q
826
                    )
827
                else:
828
                    out = dpt.empty(
×
829
                        res_shape,
830
                        dtype=res_dt,
831
                        usm_type=res_usm_type,
832
                        sycl_queue=exec_q,
833
                        order=order,
834
                    )
835

836
            buf1 = dpt.broadcast_to(buf1, res_shape)
1✔
837
            if src2.shape != res_shape:
1✔
838
                src2 = dpt.broadcast_to(src2, res_shape)
1✔
839
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
840
                src1=buf1,
841
                src2=src2,
842
                dst=out,
843
                sycl_queue=exec_q,
844
                depends=[copy_ev],
845
            )
846
            if not (orig_out is None or orig_out is out):
1!
847
                # Copy the out data from temporary buffer to original memory
848
                ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
×
849
                    src=out,
850
                    dst=orig_out,
851
                    sycl_queue=exec_q,
852
                    depends=[binary_ev],
853
                )
854
                ht_copy_out_ev.wait()
×
855
                out = orig_out
×
856
            ht_copy_ev.wait()
1✔
857
            ht_binary_ev.wait()
1✔
858
            return out
1✔
859

860
        if order == "K":
1✔
861
            if src1.flags.f_contiguous and src2.flags.f_contiguous:
1✔
862
                order = "F"
1✔
863
            elif src1.flags.c_contiguous and src2.flags.c_contiguous:
1✔
864
                order = "C"
1✔
865
        if order == "K":
1✔
866
            buf1 = _empty_like_orderK(src1, buf1_dt)
1✔
867
        else:
868
            buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
1✔
869
        ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
870
            src=src1, dst=buf1, sycl_queue=exec_q
871
        )
872
        if order == "K":
1✔
873
            buf2 = _empty_like_orderK(src2, buf2_dt)
1✔
874
        else:
875
            buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
1✔
876
        ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
877
            src=src2, dst=buf2, sycl_queue=exec_q
878
        )
879
        if out is None:
1✔
880
            if order == "K":
1✔
881
                out = _empty_like_pair_orderK(
1✔
882
                    buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
883
                )
884
            else:
885
                out = dpt.empty(
1✔
886
                    res_shape,
887
                    dtype=res_dt,
888
                    usm_type=res_usm_type,
889
                    sycl_queue=exec_q,
890
                    order=order,
891
                )
892

893
        buf1 = dpt.broadcast_to(buf1, res_shape)
1✔
894
        buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
895
        ht_, _ = self.binary_fn_(
1✔
896
            src1=buf1,
897
            src2=buf2,
898
            dst=out,
899
            sycl_queue=exec_q,
900
            depends=[copy1_ev, copy2_ev],
901
        )
902
        dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
1✔
903
        return out
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc