• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 10820219709

11 Sep 2024 09:51PM UTC coverage: 87.907% (+0.01%) from 87.893%
10820219709

push

github

web-flow
Merge pull request #1827 from IntelPython/in-place-element-wise-func-casting

Permit `"same_kind"` casting for `usm_ndarray` element-wise in-place operators

3409 of 3922 branches covered (86.92%)

Branch coverage included in aggregate %.

78 of 82 new or added lines in 3 files covered. (95.12%)

2 existing lines in 1 file now uncovered.

11711 of 13278 relevant lines covered (88.2%)

7086.32 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.81
/dpctl/tensor/_elementwise_common.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2024 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17
import numbers
1✔
18

19
import numpy as np
1✔
20

21
import dpctl
1✔
22
import dpctl.memory as dpm
1✔
23
import dpctl.tensor as dpt
1✔
24
import dpctl.tensor._tensor_impl as ti
1✔
25
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
1✔
26
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
1✔
27
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
1✔
28

29
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
1✔
30
from ._type_utils import (
1✔
31
    WeakBooleanType,
32
    WeakComplexType,
33
    WeakFloatingType,
34
    WeakIntegralType,
35
    _acceptance_fn_default_binary,
36
    _acceptance_fn_default_unary,
37
    _all_data_types,
38
    _find_buf_dtype,
39
    _find_buf_dtype2,
40
    _find_buf_dtype_in_place_op,
41
    _resolve_weak_types,
42
    _to_device_supported_dtype,
43
)
44

45

46
class UnaryElementwiseFunc:
    """
    Class that implements unary element-wise functions.

    Args:
        name (str):
            Name of the unary function
        result_type_resolver_fn (callable):
            Function that takes dtype of the input and
            returns the dtype of the result if the
            implementation function supports it, or
            returns `None` otherwise.
        unary_dp_impl_fn (callable):
            Data-parallel implementation function with signature
            `impl_fn(src: usm_ndarray, dst: usm_ndarray,
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
            where the `src` is the argument array, `dst` is the
            array to be populated with function values, effectively
            evaluating `dst = func(src)`.
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
            The first event corresponds to data-management host tasks,
            including lifetime management of argument Python objects to ensure
            that their associated USM allocation is not freed before offloaded
            computational tasks complete execution, while the second event
            corresponds to computational tasks associated with function
            evaluation.
        docs (str):
            Documentation string for the unary function.
        acceptance_fn (callable, optional):
            Function to influence type promotion behavior of this unary
            function. The function takes 4 arguments:
                arg_dtype - Data type of the first argument
                buf_dtype - Data type the argument would be cast to
                res_dtype - Data type of the output array with function values
                sycl_dev - The :class:`dpctl.SyclDevice` where the function
                    evaluation is carried out.
            The function is invoked when the argument of the unary function
            requires casting, e.g. the argument of `dpctl.tensor.log` is an
            array with integral data type.
            When not callable, ``_acceptance_fn_default_unary`` is used.
    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        unary_dp_impl_fn,
        docs,
        acceptance_fn=None,
    ):
        self.__name__ = "UnaryElementwiseFunc"
        self.name_ = name
        self.result_type_resolver_fn_ = result_type_resolver_fn
        # Lazily populated by the `types` property.
        self.types_ = None
        self.unary_fn_ = unary_dp_impl_fn
        self.__doc__ = docs
        if callable(acceptance_fn):
            self.acceptance_fn_ = acceptance_fn
        else:
            self.acceptance_fn_ = _acceptance_fn_default_unary

    def __str__(self):
        return f"<{self.__name__} '{self.name_}'>"

    def __repr__(self):
        return f"<{self.__name__} '{self.name_}'>"

    def get_implementation_function(self):
        """Returns the implementation function for
        this elementwise unary function.

        """
        return self.unary_fn_

    def get_type_result_resolver_function(self):
        """Returns the type resolver function for this
        elementwise unary function.
        """
        return self.result_type_resolver_fn_

    def get_type_promotion_path_acceptance_function(self):
        """Returns the acceptance function for this
        elementwise unary function.

        Acceptance function influences the type promotion
        behavior of this unary function.
        The function takes 4 arguments:
            arg_dtype - Data type of the first argument
            buf_dtype - Data type the argument would be cast to
            res_dtype - Data type of the output array with function values
            sycl_dev - The :class:`dpctl.SyclDevice` where the function
                evaluation is carried out.
        The function is invoked when the argument of the unary function
        requires casting, e.g. the argument of `dpctl.tensor.log` is an
        array with integral data type.
        """
        return self.acceptance_fn_

    @property
    def nin(self):
        """
        Returns the number of arguments treated as inputs.
        """
        return 1

    @property
    def nout(self):
        """
        Returns the number of arguments treated as outputs.
        """
        return 1

    @property
    def types(self):
        """Returns information about types supported by
        implementation function, using NumPy's character
        encoding for data types, e.g.

        :Example:
            .. code-block:: python

                dpctl.tensor.sin.types
                # Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']

        The list is computed on first access and cached afterwards.
        """
        # Test against None (not truthiness) so that an empty result is
        # cached as well instead of being recomputed on every access.
        if self.types_ is None:
            types = []
            for dt1 in _all_data_types(True, True):
                dt2 = self.result_type_resolver_fn_(dt1)
                if dt2:
                    types.append(f"{dt1.char}->{dt2.char}")
            self.types_ = types
        return self.types_

    def __call__(self, x, /, *, out=None, order="K"):
        # Evaluate the function on `x`. When `out` is given, results are
        # written into it; otherwise a new array is allocated whose layout
        # follows `order` ("C", "F", "A" or "K"; anything else means "K").
        if not isinstance(x, dpt.usm_ndarray):
            raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

        if order not in ["C", "F", "K", "A"]:
            order = "K"
        # `buf_dt` is the dtype `x` must be cast to before the kernel runs
        # (None when `x` can be used as is); `res_dt` is the result dtype
        # (None when no supported type signature was found).
        buf_dt, res_dt = _find_buf_dtype(
            x.dtype,
            self.result_type_resolver_fn_,
            x.sycl_device,
            acceptance_fn=self.acceptance_fn_,
        )
        if res_dt is None:
            raise ValueError(
                f"function '{self.name_}' does not support input type "
                f"({x.dtype}), "
                "and the input could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )

        orig_out = out
        if out is not None:
            if not isinstance(out, dpt.usm_ndarray):
                raise TypeError(
                    f"output array must be of usm_ndarray type, got {type(out)}"
                )

            if not out.flags.writable:
                raise ValueError("provided `out` array is read-only")

            if out.shape != x.shape:
                raise ValueError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {x.shape}, got {out.shape}"
                )

            if res_dt != out.dtype:
                raise ValueError(
                    f"Output array of type {res_dt} is needed, "
                    f"got {out.dtype}"
                )

            # Validate placement before any temporary is allocated (as
            # BinaryElementwiseFunc.__call__ does), so an incompatible `out`
            # is rejected without a wasted device allocation.
            if (
                dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
                is None
            ):
                raise ExecutionPlacementError(
                    "Input and output allocation queues are not compatible"
                )

            if (
                buf_dt is None
                and ti._array_overlap(x, out)
                and not ti._same_logical_tensors(x, out)
            ):
                # Allocate a temporary buffer to avoid memory overlapping.
                # Note if `buf_dt` is not None, a temporary copy of `x` will be
                # created, so the array overlap check isn't needed.
                out = dpt.empty_like(out)

        exec_q = x.sycl_queue
        _manager = SequentialOrderManager[exec_q]
        if buf_dt is None:
            # No cast required: run the kernel directly on `x`.
            if out is None:
                if order == "K":
                    out = _empty_like_orderK(x, res_dt)
                else:
                    if order == "A":
                        order = "F" if x.flags.f_contiguous else "C"
                    out = dpt.empty_like(x, dtype=res_dt, order=order)

            dep_evs = _manager.submitted_events
            ht_unary_ev, unary_ev = self.unary_fn_(
                x, out, sycl_queue=exec_q, depends=dep_evs
            )
            _manager.add_event_pair(ht_unary_ev, unary_ev)

            if not (orig_out is None or orig_out is out):
                # Copy the out data from temporary buffer to original memory
                ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
                )
                _manager.add_event_pair(ht_copy_ev, cpy_ev)
                out = orig_out

            return out

        # Cast required: copy `x` into a temporary of dtype `buf_dt`, then
        # run the kernel on that temporary once the copy has completed.
        if order == "K":
            buf = _empty_like_orderK(x, buf_dt)
        else:
            if order == "A":
                order = "F" if x.flags.f_contiguous else "C"
            buf = dpt.empty_like(x, dtype=buf_dt, order=order)

        dep_evs = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        if out is None:
            if order == "K":
                out = _empty_like_orderK(buf, res_dt)
            else:
                out = dpt.empty_like(buf, dtype=res_dt, order=order)

        ht, uf_ev = self.unary_fn_(
            buf, out, sycl_queue=exec_q, depends=[copy_ev]
        )
        _manager.add_event_pair(ht, uf_ev)

        return out
290

291

292
def _get_queue_usm_type(o):
    """Return the SYCL queue and USM type of the allocation backing `o`.

    Returns:
        tuple: ``(sycl_queue, usm_type)`` when `o` is a
        :class:`dpctl.tensor.usm_ndarray`, or when it exposes
        ``__sycl_usm_array_interface__`` and can be wrapped with
        :func:`dpctl.memory.as_usm_memory`; ``(None, None)`` otherwise,
        including when that wrapping raises.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.sycl_queue, o.usm_type
    if hasattr(o, "__sycl_usm_array_interface__"):
        try:
            m = dpm.as_usm_memory(o)
            return m.sycl_queue, m.get_usm_type()
        except Exception:
            # Best-effort: an interface that cannot be interpreted is treated
            # as carrying no placement information.
            return None, None
    return None, None
303

304

305
def _get_dtype(o, dev):
    """Determine the data type operand `o` contributes to type resolution.

    - ``usm_ndarray``: its own ``dtype``.
    - objects exposing ``__sycl_usm_array_interface__``: the ``dtype`` of
      ``dpt.asarray(o)``.
    - buffer-protocol objects and objects with a ``dtype`` attribute: the
      host dtype mapped to one supported by device `dev`.
    - Python ``bool``/``int``/``float``/``complex`` scalars: the matching
      weak type wrapper.
    - anything else: ``np.object_``.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.dtype
    if hasattr(o, "__sycl_usm_array_interface__"):
        return dpt.asarray(o).dtype
    if _is_buffer(o):
        return _to_device_supported_dtype(np.array(o).dtype, dev)
    if hasattr(o, "dtype"):
        return _to_device_supported_dtype(o.dtype, dev)
    # bool must be tested before int since bool is a subclass of int
    scalar_kinds = (
        (bool, WeakBooleanType),
        (int, WeakIntegralType),
        (float, WeakFloatingType),
        (complex, WeakComplexType),
    )
    for py_type, weak_type in scalar_kinds:
        if isinstance(o, py_type):
            return weak_type(o)
    return np.object_
326

327

328
def _validate_dtype(dt) -> bool:
    """Report whether `dt` is acceptable as an operand type.

    Accepts any of the weak scalar types, or a ``dpt.dtype`` that is one of
    the boolean, integer, floating, or complex types listed below.
    """
    weak_kinds = (
        WeakBooleanType,
        WeakIntegralType,
        WeakFloatingType,
        WeakComplexType,
    )
    if isinstance(dt, weak_kinds):
        return True
    if not isinstance(dt, dpt.dtype):
        return False
    supported = (
        dpt.bool,
        dpt.int8,
        dpt.uint8,
        dpt.int16,
        dpt.uint16,
        dpt.int32,
        dpt.uint32,
        dpt.int64,
        dpt.uint64,
        dpt.float16,
        dpt.float32,
        dpt.float64,
        dpt.complex64,
        dpt.complex128,
    )
    return dt in supported
352

353

354
def _get_shape(o):
    """Return the shape of operand `o`.

    ``usm_ndarray`` objects report their own shape, buffer-protocol objects
    report the ``memoryview`` shape, and numbers are treated as 0-d. For
    anything else the ``shape`` attribute is used when present, and an empty
    tuple otherwise.
    """
    if isinstance(o, dpt.usm_ndarray):
        return o.shape
    if _is_buffer(o):
        return memoryview(o).shape
    empty_shape = ()
    if isinstance(o, numbers.Number):
        return empty_shape
    return getattr(o, "shape", empty_shape)
362

363

364
class BinaryElementwiseFunc:
1✔
365
    """
366
    Class that implements binary element-wise functions.
367

368
    Args:
369
        name (str):
370
            Name of the unary function
371
        result_type_resolver_fn (callable):
372
            Function that takes dtypes of the input and
373
            returns the dtype of the result if the
374
            implementation function supports it, or
375
            returns `None` otherwise.
376
        binary_dp_impl_fn (callable):
377
            Data-parallel implementation function with signature
378
            `impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
379
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
380
            where the `src1` and `src2` are the argument arrays, `dst` is the
381
            array to be populated with function values,
382
            i.e. `dst=func(src1, src2)`.
383
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
384
            The first event corresponds to data-management host tasks,
385
            including lifetime management of argument Python objects to ensure
386
            that their associated USM allocation is not freed before offloaded
387
            computational tasks complete execution, while the second event
388
            corresponds to computational tasks associated with function
389
            evaluation.
390
        docs (str):
391
            Documentation string for the binary function.
392
        binary_inplace_fn (callable, optional):
393
            Data-parallel implementation function with signature
394
            `impl_fn(src: usm_ndarray, dst: usm_ndarray,
395
             sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
396
            where the `src` is the argument array, `dst` is the
397
            array to be populated with function values,
398
            i.e. `dst=func(dst, src)`.
399
            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
400
            The first event corresponds to data-management host tasks,
401
            including async lifetime management of Python arguments,
402
            while the second event corresponds to computational tasks
403
            associated with function evaluation.
404
        acceptance_fn (callable, optional):
405
            Function to influence type promotion behavior of this binary
406
            function. The function takes 6 arguments:
407
                arg1_dtype - Data type of the first argument
408
                arg2_dtype - Data type of the second argument
409
                ret_buf1_dtype - Data type the first argument would be cast to
410
                ret_buf2_dtype - Data type the second argument would be cast to
411
                res_dtype - Data type of the output array with function values
412
                sycl_dev - The :class:`dpctl.SyclDevice` where the function
413
                    evaluation is carried out.
414
            The function is only called when both arguments of the binary
415
            function require casting, e.g. both arguments of
416
            `dpctl.tensor.logaddexp` are arrays with integral data type.
417
    """
418

419
    def __init__(
1✔
420
        self,
421
        name,
422
        result_type_resolver_fn,
423
        binary_dp_impl_fn,
424
        docs,
425
        binary_inplace_fn=None,
426
        acceptance_fn=None,
427
        weak_type_resolver=None,
428
    ):
429
        self.__name__ = "BinaryElementwiseFunc"
1✔
430
        self.name_ = name
1✔
431
        self.result_type_resolver_fn_ = result_type_resolver_fn
1✔
432
        self.types_ = None
1✔
433
        self.binary_fn_ = binary_dp_impl_fn
1✔
434
        self.binary_inplace_fn_ = binary_inplace_fn
1✔
435
        self.__doc__ = docs
1✔
436
        if callable(acceptance_fn):
1✔
437
            self.acceptance_fn_ = acceptance_fn
1✔
438
        else:
439
            self.acceptance_fn_ = _acceptance_fn_default_binary
1✔
440
        if callable(weak_type_resolver):
1✔
441
            self.weak_type_resolver_ = weak_type_resolver
1✔
442
        else:
443
            self.weak_type_resolver_ = _resolve_weak_types
1✔
444

445
    def __str__(self):
1✔
446
        return f"<{self.__name__} '{self.name_}'>"
1✔
447

448
    def __repr__(self):
1✔
449
        return f"<{self.__name__} '{self.name_}'>"
1✔
450

451
    def get_implementation_function(self):
1✔
452
        """Returns the out-of-place implementation
453
        function for this elementwise binary function.
454

455
        """
456
        return self.binary_fn_
1✔
457

458
    def get_implementation_inplace_function(self):
1✔
459
        """Returns the in-place implementation
460
        function for this elementwise binary function.
461

462
        """
463
        return self.binary_inplace_fn_
1✔
464

465
    def get_type_result_resolver_function(self):
1✔
466
        """Returns the type resolver function for this
467
        elementwise binary function.
468
        """
469
        return self.result_type_resolver_fn_
1✔
470

471
    def get_type_promotion_path_acceptance_function(self):
1✔
472
        """Returns the acceptance function for this
473
        elementwise binary function.
474

475
        Acceptance function influences the type promotion
476
        behavior of this binary function.
477
        The function takes 6 arguments:
478
            arg1_dtype - Data type of the first argument
479
            arg2_dtype - Data type of the second argument
480
            ret_buf1_dtype - Data type the first argument would be cast to
481
            ret_buf2_dtype - Data type the second argument would be cast to
482
            res_dtype - Data type of the output array with function values
483
            sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
484
                is carried out.
485

486
        The acceptance function is only invoked if both input arrays must be
487
        cast to intermediary data types, as would happen during call of
488
        `dpctl.tensor.hypot` with both arrays being of integral data type.
489
        """
490
        return self.acceptance_fn_
1✔
491

492
    def get_array_dtype_scalar_type_resolver_function(self):
1✔
493
        """Returns the function which determines how to treat
494
        Python scalar types for this elementwise binary function.
495

496
        Resolver influences what type the scalar will be
497
        treated as prior to type promotion behavior.
498
        The function takes 3 arguments:
499

500
        Args:
501
            o1_dtype (object, dtype):
502
                A class representing a Python scalar type or a ``dtype``
503
            o2_dtype (object, dtype):
504
                A class representing a Python scalar type or a ``dtype``
505
            sycl_dev (:class:`dpctl.SyclDevice`):
506
                Device on which function evaluation is carried out.
507

508
        One of ``o1_dtype`` and ``o2_dtype`` must be a ``dtype`` instance.
509
        """
510
        return self.weak_type_resolver_
×
511

512
    @property
1✔
513
    def nin(self):
1✔
514
        """
515
        Returns the number of arguments treated as inputs.
516
        """
517
        return 2
1✔
518

519
    @property
1✔
520
    def nout(self):
1✔
521
        """
522
        Returns the number of arguments treated as outputs.
523
        """
524
        return 1
1✔
525

526
    @property
1✔
527
    def types(self):
1✔
528
        """Returns information about types supported by
529
        implementation function, using NumPy's character
530
        encoding for data types, e.g.
531

532
        :Example:
533
            .. code-block:: python
534

535
                dpctl.tensor.divide.types
536
                # Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
537
                #    'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
538
        """
539
        types = self.types_
1✔
540
        if not types:
1!
541
            types = []
1✔
542
            _all_dtypes = _all_data_types(True, True)
1✔
543
            for dt1 in _all_dtypes:
1✔
544
                for dt2 in _all_dtypes:
1✔
545
                    dt3 = self.result_type_resolver_fn_(dt1, dt2)
1✔
546
                    if dt3:
1✔
547
                        types.append(f"{dt1.char}{dt2.char}->{dt3.char}")
1✔
548
            self.types_ = types
1✔
549
        return types
1✔
550

551
    def __call__(self, o1, o2, /, *, out=None, order="K"):
1✔
552
        if order not in ["K", "C", "F", "A"]:
1✔
553
            order = "K"
1✔
554
        q1, o1_usm_type = _get_queue_usm_type(o1)
1✔
555
        q2, o2_usm_type = _get_queue_usm_type(o2)
1✔
556
        if q1 is None and q2 is None:
1✔
557
            raise ExecutionPlacementError(
1✔
558
                "Execution placement can not be unambiguously inferred "
559
                "from input arguments. "
560
                "One of the arguments must represent USM allocation and "
561
                "expose `__sycl_usm_array_interface__` property"
562
            )
563
        if q1 is None:
1✔
564
            exec_q = q2
1✔
565
            res_usm_type = o2_usm_type
1✔
566
        elif q2 is None:
1✔
567
            exec_q = q1
1✔
568
            res_usm_type = o1_usm_type
1✔
569
        else:
570
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
1✔
571
            if exec_q is None:
1✔
572
                raise ExecutionPlacementError(
1✔
573
                    "Execution placement can not be unambiguously inferred "
574
                    "from input arguments."
575
                )
576
            res_usm_type = dpctl.utils.get_coerced_usm_type(
1✔
577
                (
578
                    o1_usm_type,
579
                    o2_usm_type,
580
                )
581
            )
582
        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
1✔
583
        o1_shape = _get_shape(o1)
1✔
584
        o2_shape = _get_shape(o2)
1✔
585
        if not all(
1!
586
            isinstance(s, (tuple, list))
587
            for s in (
588
                o1_shape,
589
                o2_shape,
590
            )
591
        ):
592
            raise TypeError(
×
593
                "Shape of arguments can not be inferred. "
594
                "Arguments are expected to be "
595
                "lists, tuples, or both"
596
            )
597
        try:
1✔
598
            res_shape = _broadcast_shape_impl(
1✔
599
                [
600
                    o1_shape,
601
                    o2_shape,
602
                ]
603
            )
604
        except ValueError:
1✔
605
            raise ValueError(
1✔
606
                "operands could not be broadcast together with shapes "
607
                f"{o1_shape} and {o2_shape}"
608
            )
609
        sycl_dev = exec_q.sycl_device
1✔
610
        o1_dtype = _get_dtype(o1, sycl_dev)
1✔
611
        o2_dtype = _get_dtype(o2, sycl_dev)
1✔
612
        if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
1✔
613
            raise ValueError("Operands have unsupported data types")
1✔
614

615
        o1_dtype, o2_dtype = self.weak_type_resolver_(
1✔
616
            o1_dtype, o2_dtype, sycl_dev
617
        )
618

619
        buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
1✔
620
            o1_dtype,
621
            o2_dtype,
622
            self.result_type_resolver_fn_,
623
            sycl_dev,
624
            acceptance_fn=self.acceptance_fn_,
625
        )
626

627
        if res_dt is None:
1✔
628
            raise ValueError(
1✔
629
                f"function '{self.name_}' does not support input types "
630
                f"({o1_dtype}, {o2_dtype}), "
631
                "and the inputs could not be safely coerced to any "
632
                "supported types according to the casting rule ''safe''."
633
            )
634

635
        orig_out = out
1✔
636
        _manager = SequentialOrderManager[exec_q]
1✔
637
        if out is not None:
1✔
638
            if not isinstance(out, dpt.usm_ndarray):
1✔
639
                raise TypeError(
1✔
640
                    f"output array must be of usm_ndarray type, got {type(out)}"
641
                )
642

643
            if not out.flags.writable:
1✔
644
                raise ValueError("provided `out` array is read-only")
1✔
645

646
            if out.shape != res_shape:
1✔
647
                raise ValueError(
1✔
648
                    "The shape of input and output arrays are inconsistent. "
649
                    f"Expected output shape is {res_shape}, got {out.shape}"
650
                )
651

652
            if res_dt != out.dtype:
1✔
653
                raise ValueError(
1✔
654
                    f"Output array of type {res_dt} is needed, "
655
                    f"got {out.dtype}"
656
                )
657

658
            if (
1✔
659
                dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
660
                is None
661
            ):
662
                raise ExecutionPlacementError(
1✔
663
                    "Input and output allocation queues are not compatible"
664
                )
665

666
            if isinstance(o1, dpt.usm_ndarray):
1!
667
                if ti._array_overlap(o1, out) and buf1_dt is None:
1✔
668
                    if not ti._same_logical_tensors(o1, out):
1✔
669
                        out = dpt.empty_like(out)
1✔
670
                    elif self.binary_inplace_fn_ is not None:
1!
671
                        # if there is a dedicated in-place kernel
672
                        # it can be called here, otherwise continues
673
                        if isinstance(o2, dpt.usm_ndarray):
1!
674
                            src2 = o2
1✔
675
                            if (
1!
676
                                ti._array_overlap(o2, out)
677
                                and not ti._same_logical_tensors(o2, out)
678
                                and buf2_dt is None
679
                            ):
UNCOV
680
                                buf2_dt = o2_dtype
×
681
                        else:
UNCOV
682
                            src2 = dpt.asarray(
×
683
                                o2, dtype=o2_dtype, sycl_queue=exec_q
684
                            )
685
                        if buf2_dt is None:
1✔
686
                            if src2.shape != res_shape:
1✔
687
                                src2 = dpt.broadcast_to(src2, res_shape)
1✔
688
                            dep_evs = _manager.submitted_events
1✔
689
                            ht_, comp_ev = self.binary_inplace_fn_(
1✔
690
                                lhs=o1,
691
                                rhs=src2,
692
                                sycl_queue=exec_q,
693
                                depends=dep_evs,
694
                            )
695
                            _manager.add_event_pair(ht_, comp_ev)
1✔
696
                        else:
697
                            buf2 = dpt.empty_like(src2, dtype=buf2_dt)
1✔
698
                            dep_evs = _manager.submitted_events
1✔
699
                            (
1✔
700
                                ht_copy_ev,
701
                                copy_ev,
702
                            ) = ti._copy_usm_ndarray_into_usm_ndarray(
703
                                src=src2,
704
                                dst=buf2,
705
                                sycl_queue=exec_q,
706
                                depends=dep_evs,
707
                            )
708
                            _manager.add_event_pair(ht_copy_ev, copy_ev)
1✔
709

710
                            buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
711
                            ht_, bf_ev = self.binary_inplace_fn_(
1✔
712
                                lhs=o1,
713
                                rhs=buf2,
714
                                sycl_queue=exec_q,
715
                                depends=[copy_ev],
716
                            )
717
                            _manager.add_event_pair(ht_, bf_ev)
1✔
718

719
                        return out
1✔
720

721
            if isinstance(o2, dpt.usm_ndarray):
1✔
722
                if (
1!
723
                    ti._array_overlap(o2, out)
724
                    and not ti._same_logical_tensors(o2, out)
725
                    and buf2_dt is None
726
                ):
727
                    # should not reach if out is reallocated
728
                    # after being checked against o1
729
                    out = dpt.empty_like(out)
×
730

731
        if isinstance(o1, dpt.usm_ndarray):
1✔
732
            src1 = o1
1✔
733
        else:
734
            src1 = dpt.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q)
1✔
735
        if isinstance(o2, dpt.usm_ndarray):
1✔
736
            src2 = o2
1✔
737
        else:
738
            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
1✔
739

740
        if order == "A":
1✔
741
            order = (
1✔
742
                "F"
743
                if all(
744
                    arr.flags.f_contiguous
745
                    for arr in (
746
                        src1,
747
                        src2,
748
                    )
749
                )
750
                else "C"
751
            )
752

753
        if buf1_dt is None and buf2_dt is None:
1✔
754
            if out is None:
1✔
755
                if order == "K":
1✔
756
                    out = _empty_like_pair_orderK(
1✔
757
                        src1, src2, res_dt, res_shape, res_usm_type, exec_q
758
                    )
759
                else:
760
                    out = dpt.empty(
1✔
761
                        res_shape,
762
                        dtype=res_dt,
763
                        usm_type=res_usm_type,
764
                        sycl_queue=exec_q,
765
                        order=order,
766
                    )
767
            if src1.shape != res_shape:
1✔
768
                src1 = dpt.broadcast_to(src1, res_shape)
1✔
769
            if src2.shape != res_shape:
1✔
770
                src2 = dpt.broadcast_to(src2, res_shape)
1✔
771
            deps_ev = _manager.submitted_events
1✔
772
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
773
                src1=src1,
774
                src2=src2,
775
                dst=out,
776
                sycl_queue=exec_q,
777
                depends=deps_ev,
778
            )
779
            _manager.add_event_pair(ht_binary_ev, binary_ev)
1✔
780
            if not (orig_out is None or orig_out is out):
1✔
781
                # Copy the out data from temporary buffer to original memory
782
                ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
783
                    src=out,
784
                    dst=orig_out,
785
                    sycl_queue=exec_q,
786
                    depends=[binary_ev],
787
                )
788
                _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
1✔
789
                out = orig_out
1✔
790
            return out
1✔
791
        elif buf1_dt is None:
1✔
792
            if order == "K":
1!
793
                buf2 = _empty_like_orderK(src2, buf2_dt)
1✔
794
            else:
795
                buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
×
796
            dep_evs = _manager.submitted_events
1✔
797
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
798
                src=src2, dst=buf2, sycl_queue=exec_q, depends=dep_evs
799
            )
800
            _manager.add_event_pair(ht_copy_ev, copy_ev)
1✔
801
            if out is None:
1✔
802
                if order == "K":
1!
803
                    out = _empty_like_pair_orderK(
1✔
804
                        src1, buf2, res_dt, res_shape, res_usm_type, exec_q
805
                    )
806
                else:
807
                    out = dpt.empty(
×
808
                        res_shape,
809
                        dtype=res_dt,
810
                        usm_type=res_usm_type,
811
                        sycl_queue=exec_q,
812
                        order=order,
813
                    )
814

815
            if src1.shape != res_shape:
1✔
816
                src1 = dpt.broadcast_to(src1, res_shape)
1✔
817
            buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
818
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
819
                src1=src1,
820
                src2=buf2,
821
                dst=out,
822
                sycl_queue=exec_q,
823
                depends=[copy_ev],
824
            )
825
            _manager.add_event_pair(ht_binary_ev, binary_ev)
1✔
826
            if not (orig_out is None or orig_out is out):
1!
827
                # Copy the out data from temporary buffer to original memory
828
                ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
×
829
                    src=out,
830
                    dst=orig_out,
831
                    sycl_queue=exec_q,
832
                    depends=[binary_ev],
833
                )
834
                _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
×
835
                out = orig_out
×
836
            return out
1✔
837
        elif buf2_dt is None:
1✔
838
            if order == "K":
1!
839
                buf1 = _empty_like_orderK(src1, buf1_dt)
1✔
840
            else:
841
                buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
×
842
            dep_evs = _manager.submitted_events
1✔
843
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
844
                src=src1, dst=buf1, sycl_queue=exec_q, depends=dep_evs
845
            )
846
            _manager.add_event_pair(ht_copy_ev, copy_ev)
1✔
847
            if out is None:
1✔
848
                if order == "K":
1!
849
                    out = _empty_like_pair_orderK(
1✔
850
                        buf1, src2, res_dt, res_shape, res_usm_type, exec_q
851
                    )
852
                else:
853
                    out = dpt.empty(
×
854
                        res_shape,
855
                        dtype=res_dt,
856
                        usm_type=res_usm_type,
857
                        sycl_queue=exec_q,
858
                        order=order,
859
                    )
860

861
            buf1 = dpt.broadcast_to(buf1, res_shape)
1✔
862
            if src2.shape != res_shape:
1✔
863
                src2 = dpt.broadcast_to(src2, res_shape)
1✔
864
            ht_binary_ev, binary_ev = self.binary_fn_(
1✔
865
                src1=buf1,
866
                src2=src2,
867
                dst=out,
868
                sycl_queue=exec_q,
869
                depends=[copy_ev],
870
            )
871
            _manager.add_event_pair(ht_binary_ev, binary_ev)
1✔
872
            if not (orig_out is None or orig_out is out):
1!
873
                # Copy the out data from temporary buffer to original memory
874
                ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
×
875
                    src=out,
876
                    dst=orig_out,
877
                    sycl_queue=exec_q,
878
                    depends=[binary_ev],
879
                )
880
                _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
×
881
                out = orig_out
×
882
            return out
1✔
883

884
        if order == "K":
1✔
885
            if src1.flags.c_contiguous and src2.flags.c_contiguous:
1✔
886
                order = "C"
1✔
887
            elif src1.flags.f_contiguous and src2.flags.f_contiguous:
1✔
888
                order = "F"
1✔
889
        if order == "K":
1✔
890
            buf1 = _empty_like_orderK(src1, buf1_dt)
1✔
891
        else:
892
            buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
1✔
893
        dep_evs = _manager.submitted_events
1✔
894
        ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
895
            src=src1, dst=buf1, sycl_queue=exec_q, depends=dep_evs
896
        )
897
        _manager.add_event_pair(ht_copy1_ev, copy1_ev)
1✔
898
        if order == "K":
1✔
899
            buf2 = _empty_like_orderK(src2, buf2_dt)
1✔
900
        else:
901
            buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
1✔
902
        ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1✔
903
            src=src2, dst=buf2, sycl_queue=exec_q, depends=dep_evs
904
        )
905
        _manager.add_event_pair(ht_copy2_ev, copy2_ev)
1✔
906
        if out is None:
1✔
907
            if order == "K":
1✔
908
                out = _empty_like_pair_orderK(
1✔
909
                    buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
910
                )
911
            else:
912
                out = dpt.empty(
1✔
913
                    res_shape,
914
                    dtype=res_dt,
915
                    usm_type=res_usm_type,
916
                    sycl_queue=exec_q,
917
                    order=order,
918
                )
919

920
        buf1 = dpt.broadcast_to(buf1, res_shape)
1✔
921
        buf2 = dpt.broadcast_to(buf2, res_shape)
1✔
922
        ht_, bf_ev = self.binary_fn_(
1✔
923
            src1=buf1,
924
            src2=buf2,
925
            dst=out,
926
            sycl_queue=exec_q,
927
            depends=[copy1_ev, copy2_ev],
928
        )
929
        _manager.add_event_pair(ht_, bf_ev)
1✔
930
        return out
1✔
931

932
    def _inplace_op(self, o1, o2):
        """Apply this binary function in place, storing the result in ``o1``.

        Computes ``o1 = func(o1, o2)`` using the dedicated in-place
        implementation ``self.binary_inplace_fn_``. Work is submitted to the
        execution queue and tracked through ``SequentialOrderManager``; this
        method does not wait for it to complete.

        Args:
            o1 (dpctl.tensor.usm_ndarray):
                Writable left-hand operand that also receives the result.
                Its shape must equal the broadcast shape of ``o1`` and
                ``o2`` (only ``o2`` may be broadcast).
            o2:
                Right-hand operand. A ``usm_ndarray`` is used directly;
                any other value is converted with ``dpt.asarray`` on the
                execution queue, using the dtype inferred for it.

        Returns:
            dpctl.tensor.usm_ndarray: ``o1``.

        Raises:
            ValueError: if there is no dedicated in-place implementation,
                ``o1`` is read-only, the operands cannot be broadcast, the
                broadcast shape differs from ``o1.shape``, ``o2`` has an
                unsupported dtype, no supported type signature is found,
                or the resolved result dtype differs from ``o1``'s dtype.
            TypeError: if ``o1`` is not a ``usm_ndarray`` or the shape of
                ``o2`` cannot be inferred as a list or tuple.
            ExecutionPlacementError: if the execution queue cannot be
                unambiguously inferred from the operands.
        """
        if self.binary_inplace_fn_ is None:
            raise ValueError(
                "binary function does not have a dedicated in-place "
                "implementation"
            )
        if not isinstance(o1, dpt.usm_ndarray):
            raise TypeError(
                "Expected first argument to be "
                f"dpctl.tensor.usm_ndarray, got {type(o1)}"
            )
        if not o1.flags.writable:
            raise ValueError("provided left-hand side array is read-only")

        # Determine the execution queue and coerced USM type. When ``o2``
        # reports no queue, ``o1``'s queue and USM type are used as-is.
        q1, o1_usm_type = o1.sycl_queue, o1.usm_type
        q2, o2_usm_type = _get_queue_usm_type(o2)
        if q2 is None:
            exec_q = q1
            res_usm_type = o1_usm_type
        else:
            exec_q = dpctl.utils.get_execution_queue((q1, q2))
            if exec_q is None:
                raise ExecutionPlacementError(
                    "Execution placement can not be unambiguously inferred "
                    "from input arguments."
                )
            res_usm_type = dpctl.utils.get_coerced_usm_type(
                (o1_usm_type, o2_usm_type)
            )
        dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)

        # The result is written into ``o1``, so broadcasting may expand
        # ``o2`` but must never change ``o1``'s shape.
        o1_shape = o1.shape
        o2_shape = _get_shape(o2)
        if not isinstance(o2_shape, (tuple, list)):
            raise TypeError(
                "Shape of second argument can not be inferred. "
                "Expected list or tuple."
            )
        try:
            res_shape = _broadcast_shape_impl([o1_shape, o2_shape])
        except ValueError as err:
            # Chain the underlying error so its details stay in the traceback.
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{o1_shape} and {o2_shape}"
            ) from err

        if res_shape != o1_shape:
            raise ValueError(
                "The shape of the non-broadcastable left-hand "
                f"side {o1_shape} is inconsistent with the "
                f"broadcast shape {res_shape}."
            )

        sycl_dev = exec_q.sycl_device
        o1_dtype = o1.dtype
        o2_dtype = _get_dtype(o2, sycl_dev)
        if not _validate_dtype(o2_dtype):
            raise ValueError("Operand has an unsupported data type")

        o1_dtype, o2_dtype = self.weak_type_resolver_(
            o1_dtype, o2_dtype, sycl_dev
        )

        # ``buf_dt``: dtype that ``o2`` must be copied into before the kernel
        # runs (``None`` means ``o2`` is used as-is). ``res_dt``: result
        # dtype (``None`` means no supported signature was found).
        buf_dt, res_dt = _find_buf_dtype_in_place_op(
            o1_dtype,
            o2_dtype,
            self.result_type_resolver_fn_,
            sycl_dev,
        )

        if res_dt is None:
            raise ValueError(
                f"function '{self.name_}' does not support input types "
                f"({o1_dtype}, {o2_dtype}), "
                "and the inputs could not be safely coerced to any "
                "supported types according to the casting rule "
                "''same_kind''."
            )

        # The result is stored in ``o1`` itself, so ``o1`` must already
        # have the result dtype.
        if res_dt != o1_dtype:
            raise ValueError(
                f"Output array of type {res_dt} is needed, got {o1_dtype}"
            )

        _manager = SequentialOrderManager[exec_q]
        if isinstance(o2, dpt.usm_ndarray):
            src2 = o2
            # If ``o2`` shares memory with ``o1`` without being the same
            # logical view, writing into ``o1`` could alter elements of
            # ``o2`` before they are read; stage ``o2`` through a copy.
            if (
                ti._array_overlap(o2, o1)
                and not ti._same_logical_tensors(o2, o1)
                and buf_dt is None
            ):
                buf_dt = o2_dtype
        else:
            src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

        if buf_dt is None:
            # No cast or staging required: feed ``src2`` to the kernel.
            if src2.shape != res_shape:
                src2 = dpt.broadcast_to(src2, res_shape)
            dep_evs = _manager.submitted_events
            ht_, comp_ev = self.binary_inplace_fn_(
                lhs=o1,
                rhs=src2,
                sycl_queue=exec_q,
                depends=dep_evs,
            )
            _manager.add_event_pair(ht_, comp_ev)
        else:
            # Copy ``src2`` into a temporary of dtype ``buf_dt``; the
            # in-place kernel depends on that copy's completion event.
            buf = dpt.empty_like(src2, dtype=buf_dt)
            dep_evs = _manager.submitted_events
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=src2,
                dst=buf,
                sycl_queue=exec_q,
                depends=dep_evs,
            )
            _manager.add_event_pair(ht_copy_ev, copy_ev)

            buf = dpt.broadcast_to(buf, res_shape)
            ht_, bf_ev = self.binary_inplace_fn_(
                lhs=o1,
                rhs=buf,
                sycl_queue=exec_q,
                depends=[copy_ev],
            )
            _manager.add_event_pair(ht_, bf_ev)

        return o1
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc