• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 9359853505

04 Jun 2024 01:42AM UTC coverage: 88.02% (+0.1%) from 87.911%
9359853505

Pull #1705

github

web-flow
Merge dd28026b0 into c6f5f790b
Pull Request #1705: Change memory object USM allocation ownership, and make execution asynchronous

3275 of 3765 branches covered (86.99%)

Branch coverage included in aggregate %.

567 of 634 new or added lines in 23 files covered. (89.43%)

2 existing lines in 2 files now uncovered.

11207 of 12688 relevant lines covered (88.33%)

7552.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.92
/dpctl/tensor/_manipulation_functions.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2024 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17

18
import itertools
1✔
19
import operator
1✔
20

21
import numpy as np
1✔
22
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
1✔
23

24
import dpctl
1✔
25
import dpctl.tensor as dpt
1✔
26
import dpctl.tensor._tensor_impl as ti
1✔
27
import dpctl.utils as dputils
1✔
28

29
from ._copy_utils import _broadcast_strides
1✔
30
from ._type_utils import _supported_dtype, _to_device_supported_dtype
1✔
31

32
# Module docstring, assigned explicitly because it follows the imports.
# ``:mod:`` is the Sphinx role for cross-referencing a module; the previous
# ``:module:`` is not a role and would not render as a link.
__doc__ = (
    "Implementation module for array manipulation "
    "functions in :mod:`dpctl.tensor`"
)
36

37

38
def _broadcast_shape_impl(shapes):
    """
    Compute the shape obtained by broadcasting `shapes` against each other.

    Args:
        shapes (List[Tuple[int, ...]]): shapes to broadcast. The list passed
            by the caller is not modified.

    Returns:
        Tuple[int, ...]: the common broadcast shape. Broadcasting zero
        shapes yields the 0-d shape `()`.

    Raises:
        ValueError: if along some axis two shapes have lengths that differ
            and neither of them is 1.
    """
    # Work on a private copy so that left-padding below never mutates the
    # caller's list.
    shapes = list(shapes)
    if not shapes:
        return tuple()
    if len(set(shapes)) == 1:
        return shapes[0]
    nds = [len(s) for s in shapes]
    biggest = max(nds)
    # Left-pad shorter shapes with ones so that all have `biggest` dims.
    for i, nd in enumerate(nds):
        diff = biggest - nd
        if diff > 0:
            ty = type(shapes[i])
            shapes[i] = ty(
                itertools.chain(itertools.repeat(1, diff), shapes[i])
            )
    common_shape = []
    for axis in range(biggest):
        lengths = [s[axis] for s in shapes]
        # Along each axis every length must be 1 or one shared value; adding
        # 1 to the set means at most two distinct values are allowed.
        unique = set(lengths + [1])
        if len(unique) > 2:
            raise ValueError(
                "Shape mismatch: two or more arrays have "
                f"incompatible dimensions on axis ({axis},)"
            )
        elif len(unique) == 2:
            unique.remove(1)
            common_shape.append(unique.pop())
        else:
            common_shape.append(1)

    return tuple(common_shape)
75

76

77
def _broadcast_shapes(*args):
    """
    Return the common broadcast shape of the given arrays.

    Each positional argument must expose a `shape` attribute; the collected
    shapes are combined by `_broadcast_shape_impl`.
    """
    return _broadcast_shape_impl([arg.shape for arg in args])
84

85

86
def permute_dims(X, /, axes):
    """permute_dims(x, axes)

    Permute the axes (dimensions) of an array; returns the permuted
    array as a view.

    Args:
        x (usm_ndarray): input array.
        axes (Tuple[int, ...]): tuple containing permutation of
           `(0,1,...,N-1)` where `N` is the number of axes (dimensions)
           of `x`.
    Returns:
        usm_ndarray:
            An array with permuted axes.
            The returned array has the same data type as `x`,
            is created on the same device as `x` and has the same USM
            allocation type as `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if `axes` does not have one entry per dimension of `x`.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
    axes = normalize_axis_tuple(axes, X.ndim, "axes")
    if X.ndim != len(axes):
        raise ValueError(
            "The length of the passed axes does not match "
            "to the number of usm_ndarray dimensions."
        )
    src_shape, src_strides = X.shape, X.strides
    # Permuting only reorders (extent, stride) pairs; the view shares the
    # buffer and element offset of `X`.
    return dpt.usm_ndarray(
        shape=tuple(src_shape[ax] for ax in axes),
        dtype=X.dtype,
        buffer=X,
        strides=tuple(src_strides[ax] for ax in axes),
        offset=X._element_offset,
    )
121

122

123
def expand_dims(X, /, *, axis=0):
    """expand_dims(x, axis)

    Expands the shape of an array by inserting a new axis (dimension)
    of size one at the position specified by axis.

    Args:
        x (usm_ndarray):
            input array
        axis (Union[int, Tuple[int]]):
            axis position in the expanded axes (zero-based). If `x` has rank
            (i.e, number of dimensions) `N`, a valid `axis` must reside
            in the closed-interval `[-N-1, N]`. If provided a negative
            `axis`, the `axis` position at which to insert a singleton
            dimension is computed as `N + axis + 1`. Hence, if
            provided `-1`, the resolved axis position is `N` (i.e.,
            a singleton dimension must be appended to the input array `x`).
            If provided `-N-1`, the resolved axis position is `0` (i.e., a
            singleton dimension is prepended to the input array `x`).

    Returns:
        usm_ndarray:
            Returns a view, if possible, and a copy otherwise with the number
            of dimensions increased.
            The expanded array has the same data type as the input array `x`.
            The expanded array is located on the same device as the input
            array, and has the same USM allocation type.

    Raises:
        IndexError: if `axis` value is invalid.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    new_axes = axis if type(axis) in (tuple, list) else (axis,)
    out_ndim = X.ndim + len(new_axes)
    new_axes = normalize_axis_tuple(new_axes, out_ndim)

    # Inserted positions get extent 1; every other output dimension takes
    # the next input extent in order.
    remaining = iter(X.shape)
    shape = tuple(
        1 if dim in new_axes else next(remaining) for dim in range(out_ndim)
    )
    return dpt.reshape(X, shape)
167

168

169
def squeeze(X, /, axis=None):
    """squeeze(x, axis)

    Removes singleton dimensions (axes) from array `x`.

    Args:
        x (usm_ndarray): input array
        axis (Union[int, Tuple[int,...]]): axis (or axes) to squeeze.
            If `None`, every dimension of length one is removed.

    Returns:
        usm_ndarray:
            Output array is a view, if possible,
            and a copy otherwise, but with all or a subset of the
            dimensions of length 1 removed. Output has the same data
            type as the input, is allocated on the same device as the
            input and has the same USM allocation type as the input
            array `x`. If nothing is removed, `x` itself is returned.

    Raises:
        ValueError: if the specified axis has a size greater than one.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
    X_shape = X.shape
    if axis is None:
        new_shape = tuple(extent for extent in X_shape if extent != 1)
    else:
        # A 0-d array is normalized as if it had one axis, so that
        # `axis=0` and `axis=-1` are accepted for it.
        axis = normalize_axis_tuple(axis, max(X.ndim, 1))
        kept = []
        for i, extent in enumerate(X_shape):
            if i in axis:
                if extent != 1:
                    raise ValueError(
                        "Cannot select an axis to squeeze out "
                        "which has size not equal to one."
                    )
            else:
                kept.append(extent)
        new_shape = tuple(kept)
    if new_shape == X.shape:
        return X
    return dpt.reshape(X, new_shape)
212

213

214
def broadcast_to(X, /, shape):
    """broadcast_to(x, shape)

    Broadcast an array to a new `shape`; returns the broadcasted
    :class:`dpctl.tensor.usm_ndarray` as a view.

    Args:
        x (usm_ndarray): input array
        shape (Tuple[int,...]): array shape. The `shape` must be
            compatible with `x` according to broadcasting rules.

    Returns:
        usm_ndarray:
            An array with the specified `shape`.
            The output array is a view of the input array, and
            hence has the same data type, USM allocation type and
            device attributes.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if `x` cannot be broadcast to `shape` (raised by
            `numpy.broadcast_to`).
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    # Let NumPy validate `shape` against X.shape: broadcast a one-byte
    # placeholder to X.shape, then to `shape`. NumPy raises ValueError when
    # the two are incompatible.
    placeholder = np.broadcast_to(np.empty(tuple(), dtype="u1"), X.shape)
    target_shape = np.broadcast_to(placeholder, shape).shape
    return dpt.usm_ndarray(
        shape=target_shape,
        dtype=X.dtype,
        buffer=X,
        strides=_broadcast_strides(X.shape, X.strides, len(target_shape)),
        offset=X._element_offset,
    )
249

250

251
def broadcast_arrays(*args):
    """broadcast_arrays(*arrays)

    Broadcasts one or more :class:`dpctl.tensor.usm_ndarray` objects
    against one another.

    Args:
        arrays (usm_ndarray): an arbitrary number of arrays to be
            broadcasted.

    Returns:
        List[usm_ndarray]:
            A list of broadcasted arrays. Each array
            must have the same shape. Each array must have the same `dtype`,
            `device` and `usm_type` attributes as its corresponding input
            array.

    Raises:
        TypeError: if any argument is not a `usm_ndarray`.
        ValueError: if the input shapes are not broadcast-compatible.
    """
    for X in args:
        if not isinstance(X, dpt.usm_ndarray):
            raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    shape = _broadcast_shapes(*args)

    if all(X.shape == shape for X in args):
        # Nothing to broadcast. Return a list rather than the `args` tuple
        # so the result type matches the documented List[usm_ndarray]
        # regardless of which path is taken.
        return list(args)

    return [broadcast_to(X, shape) for X in args]
278

279

280
def flip(X, /, *, axis=None):
    """flip(x, axis)

    Reverses the order of elements in an array `x` along the given `axis`.
    The shape of the array is preserved, but the elements are reordered.

    Args:
        x (usm_ndarray): input array.
        axis (Optional[Union[int, Tuple[int,...]]]): axis (or axes) along
            which to flip.
            If `axis` is `None`, all input array axes are flipped.
            If `axis` is negative, the flipped axis is counted from the
            last dimension. If provided more than one axis, only the specified
            axes are flipped. Default: `None`.

    Returns:
        usm_ndarray:
            A view of `x` with the entries of `axis` reversed.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
    ndim = X.ndim
    reverse, keep = np.s_[::-1], np.s_[:]
    if axis is None:
        indexer = (reverse,) * ndim
    else:
        flipped = normalize_axis_tuple(axis, ndim)
        # Reverse only the requested dimensions; keep the others intact.
        indexer = tuple(
            reverse if dim in flipped else keep for dim in range(ndim)
        )
    return X[indexer]
310

311

312
def roll(X, /, shift, *, axis=None):
    """
    roll(x, shift, axis)

    Rolls array elements along a specified axis.
    Array elements that roll beyond the last position are re-introduced
    at the first position. Array elements that roll beyond the first position
    are re-introduced at the last position.

    Args:
        x (usm_ndarray): input array
        shift (Union[int, Tuple[int,...]]): number of places by which the
            elements are shifted. When `axis` is given, `shift` and `axis`
            are broadcast against each other (both must be scalars or 1D
            sequences). A single `int` shift is therefore applied to every
            listed axis, and a tuple of shifts pairs element-wise with a
            tuple of axes. Shifts for an axis listed more than once are
            summed. A positive `shift` moves elements toward larger indices;
            a negative `shift` moves them toward smaller indices.
        axis (Optional[Union[int, Tuple[int,...]]]): axis (or axes) along
            which elements are shifted. If `axis` is `None`, the array is
            flattened, shifted, and then restored to its original shape.
            Default: `None`.

    Returns:
        usm_ndarray:
            An array having the same `dtype`, `usm_type` and
            `device` attributes as `x` and whose elements are shifted relative
            to `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if `shift` and `axis` broadcast to more than one
            dimension.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
    q = X.sycl_queue
    order_mgr = dputils.SequentialOrderManager[q]
    if axis is not None:
        axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True)
        shift_axis_pairs = np.broadcast(shift, axis)
        if shift_axis_pairs.ndim > 1:
            raise ValueError(
                "'shift' and 'axis' should be scalars or 1D sequences"
            )
        # Net shift per dimension; repeated axes accumulate.
        shifts = [0] * X.ndim
        for sh, ax in shift_axis_pairs:
            shifts[ax] += sh
        out = dpt.empty(
            X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=q
        )
        deps = order_mgr.submitted_events
        host_ev, copy_ev = ti._copy_usm_ndarray_for_roll_nd(
            src=X, dst=out, shifts=shifts, sycl_queue=q, depends=deps
        )
    else:
        shift = operator.index(shift)
        deps = order_mgr.submitted_events
        out = dpt.empty(
            X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=q
        )
        host_ev, copy_ev = ti._copy_usm_ndarray_for_roll_1d(
            src=X,
            dst=out,
            shift=shift,
            sycl_queue=q,
            depends=deps,
        )
    order_mgr.add_event_pair(host_ev, copy_ev)
    return out
382

383

384
def _arrays_validation(arrays, check_ndim=True):
    """
    Validate inputs of the joining functions and infer output attributes.

    Args:
        arrays (Union[List[usm_ndarray], Tuple[usm_ndarray, ...]]):
            arrays to validate.
        check_ndim (bool): if `True`, additionally require every array to
            have the same number of dimensions. Default: `True`.

    Returns:
        Tuple: `(res_dtype, res_usm_type, exec_q)` -- the promoted result
        data type (passed through `_to_device_supported_dtype` for the
        queue's device), the coerced USM allocation type, and the common
        execution queue.

    Raises:
        TypeError: if `arrays` is not a list/tuple, is empty, or contains
            an element that is not a `usm_ndarray`.
        ValueError: if the arrays do not share an execution queue, their
            USM types cannot be coerced, or (with `check_ndim`) their
            numbers of dimensions differ.
    """
    # Check the container type before calling len(): an unsized input
    # (e.g. a generator or a 0-d array) would otherwise fail inside len()
    # with an unrelated message.
    if not isinstance(arrays, (list, tuple)):
        raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")

    n = len(arrays)
    if n == 0:
        raise TypeError("Missing 1 required positional argument: 'arrays'.")

    for X in arrays:
        if not isinstance(X, dpt.usm_ndarray):
            raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
    if exec_q is None:
        raise ValueError("All the input arrays must have same sycl queue.")

    res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
    if res_usm_type is None:
        raise ValueError("All the input arrays must have usm_type.")

    X0 = arrays[0]
    # NOTE(review): presumably raises for dtypes dpctl does not support;
    # see _type_utils._supported_dtype.
    _supported_dtype(Xi.dtype for Xi in arrays)

    res_dtype = X0.dtype
    dev = exec_q.sycl_device
    for Xi in arrays[1:]:
        # Promote using the element dtype explicitly instead of relying on
        # NumPy coercing the array object through its `.dtype` attribute.
        res_dtype = np.promote_types(res_dtype, Xi.dtype)
        res_dtype = _to_device_supported_dtype(res_dtype, dev)

    if check_ndim:
        for i in range(1, n):
            if X0.ndim != arrays[i].ndim:
                raise ValueError(
                    "All the input arrays must have same number of dimensions, "
                    f"but the array at index 0 has {X0.ndim} dimension(s) and "
                    f"the array at index {i} has {arrays[i].ndim} dimension(s)."
                )
    return res_dtype, res_usm_type, exec_q
422

423

424
def _check_same_shapes(X0_shape, axis, n, arrays):
    """
    Require `arrays[1:n]` to match `X0_shape` on every dimension but `axis`.

    Raises:
        ValueError: naming the first mismatching array index and dimension.
    """
    for i in range(1, n):
        other_shape = arrays[i].shape
        for dim, ref_len in enumerate(X0_shape):
            if ref_len == other_shape[dim] or dim == axis:
                continue
            raise ValueError(
                "All the input array dimensions for the concatenation "
                f"axis must match exactly, but along dimension {dim}, the "
                f"array at index 0 has size {ref_len} and the array "
                f"at index {i} has size {other_shape[dim]}."
            )
435

436

437
def _concat_axis_None(arrays):
    """
    Implementation of concat(arrays, axis=None).

    Every input array is flattened and written, in order, into a new 1-D
    array. The dtype, USM type and queue of that array come from
    `_arrays_validation`. Copies are submitted asynchronously and their
    events are registered with the queue's SequentialOrderManager.
    """
    res_dtype, res_usm_type, exec_q = _arrays_validation(
        arrays, check_ndim=False
    )
    res_shape = sum(array.size for array in arrays)
    res = dpt.empty(
        res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
    )

    fill_start = 0
    _manager = dputils.SequentialOrderManager[exec_q]
    deps = _manager.submitted_events
    for array in arrays:
        fill_end = fill_start + array.size
        if array.flags.c_contiguous:
            hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=dpt.reshape(array, -1),
                dst=res[fill_start:fill_end],
                sycl_queue=exec_q,
                depends=deps,
            )
            _manager.add_event_pair(hev, cpy_ev)
        elif array.dtype != res_dtype:
            # _copy_usm_ndarray_for_reshape requires src and dst to have the
            # same data type: cast into a temporary of res_dtype first, then
            # reshape-copy from that temporary (not from the original array).
            src2_ = dpt.empty_like(array, dtype=res_dtype)
            ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=array, dst=src2_, sycl_queue=exec_q, depends=deps
            )
            _manager.add_event_pair(ht_copy_ev, cpy_ev)
            hev, reshape_copy_ev = ti._copy_usm_ndarray_for_reshape(
                src=src2_,
                dst=res[fill_start:fill_end],
                sycl_queue=exec_q,
                depends=[cpy_ev],
            )
            _manager.add_event_pair(hev, reshape_copy_ev)
        else:
            hev, cpy_ev = ti._copy_usm_ndarray_for_reshape(
                src=array,
                dst=res[fill_start:fill_end],
                sycl_queue=exec_q,
                depends=deps,
            )
            _manager.add_event_pair(hev, cpy_ev)
        fill_start = fill_end

    return res
490

491

492
def concat(arrays, /, *, axis=0):
    """concat(arrays, axis)

    Joins a sequence of arrays along an existing axis.

    Args:
        arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]):
            input arrays to join. The arrays must have the same shape,
            except in the dimension specified by `axis`.
        axis (Optional[int]): axis along which the arrays will be joined.
            If `axis` is `None`, arrays are flattened before
            concatenation. If `axis` is negative, it is understood as
            being counted from the last dimension. Default: `0`.

    Returns:
        usm_ndarray:
            An output array containing the concatenated
            values. The output array data type is determined by Type
            Promotion Rules of array API.

    All input arrays must have the same device attribute. The output array
    is allocated on that same device, and data movement operations are
    scheduled on a queue underlying the device. The USM allocation type
    of the output array is determined by USM allocation type promotion
    rules.
    """
    if axis is None:
        return _concat_axis_None(arrays)

    res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
    n = len(arrays)
    X0 = arrays[0]

    axis = normalize_axis_index(axis, X0.ndim)
    X0_shape = X0.shape
    _check_same_shapes(X0_shape, axis, n, arrays)

    # The result matches X0 except along `axis`, where the extents add up.
    res_shape = list(X0_shape)
    res_shape[axis] = sum(X.shape[axis] for X in arrays)
    res = dpt.empty(
        tuple(res_shape),
        dtype=res_dtype,
        usm_type=res_usm_type,
        sycl_queue=exec_q,
    )

    _manager = dputils.SequentialOrderManager[exec_q]
    deps = _manager.submitted_events
    offset = 0
    for X in arrays:
        extent = X.shape[axis]
        window = tuple(
            np.s_[offset : offset + extent] if dim == axis else np.s_[:]
            for dim in range(X0.ndim)
        )
        hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=X,
            dst=res[window],
            sycl_queue=exec_q,
            depends=deps,
        )
        _manager.add_event_pair(hev, cpy_ev)
        offset += extent

    return res
560

561

562
def stack(arrays, /, *, axis=0):
    """
    stack(arrays, axis)

    Joins a sequence of arrays along a new axis.

    Args:
        arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]):
            input arrays to join. Each array must have the same shape.
        axis (int): axis along which the arrays will be joined. The
            `axis` specifies the index of the new axis in the dimensions
            of the output array. A valid axis must be on the interval
            `[-N-1, N]`, where `N` is the rank (number of dimensions) of
            each input array.
            Default: `0`.

    Returns:
        usm_ndarray:
            An output array having rank `N+1`, where `N` is
            the rank (number of dimensions) of each input array. If the
            input arrays have different data types, array API Type
            Promotion Rules apply.

    Raises:
        ValueError: if not all input arrays have the same shape
        IndexError: if provided an `axis` outside of the required interval.
    """
    res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)

    n = len(arrays)
    X0 = arrays[0]
    res_ndim = X0.ndim + 1
    axis = normalize_axis_index(axis, res_ndim)
    X0_shape = X0.shape

    for Xi in arrays[1:]:
        if Xi.shape != X0_shape:
            raise ValueError("All input arrays must have the same shape")

    # Insert the new axis, of length n, at position `axis`.
    res_shape = tuple(X0_shape[:axis]) + (n,) + tuple(X0_shape[axis:])

    res = dpt.empty(
        res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
    )

    _manager = dputils.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    for i, Xi in enumerate(arrays):
        dst_index = tuple(
            i if dim == axis else np.s_[:] for dim in range(res_ndim)
        )
        hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=Xi, dst=res[dst_index], sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(hev, cpy_ev)

    return res
621

622

623
def unstack(X, /, *, axis=0):
    """unstack(x, axis=0)

    Splits an array in a sequence of arrays along the given axis.

    Args:
        x (usm_ndarray): input array

        axis (int, optional): axis along which `x` is unstacked.
            If `x` has rank (i.e, number of dimensions) `N`,
            a valid `axis` must reside in the half-open interval `[-N, N)`.
            Default: `0`.

    Returns:
        Tuple[usm_ndarray,...]:
            Output sequence of arrays which are views into the input array.

    Raises:
        AxisError: if the `axis` value is invalid.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    axis = normalize_axis_index(axis, X.ndim)
    # Move the split axis to the front; indexing along axis 0 then yields
    # one view per position along the original `axis`.
    front = dpt.moveaxis(X, axis, 0)
    n_slices = front.shape[0]
    return tuple(front[k] for k in range(n_slices))
650

651

652
def moveaxis(X, source, destination, /):
    """moveaxis(x, source, destination)

    Moves axes of an array to new positions.

    Args:
        x (usm_ndarray): input array

        source (int or a sequence of int):
            Original positions of the axes to move.
            These must be unique. If `x` has rank (i.e., number of
            dimensions) `N`, a valid `axis` must be in the
            half-open interval `[-N, N)`.

        destination (int or a sequence of int):
            Destination positions for each of the original axes.
            These must also be unique. If `x` has rank
            (i.e., number of dimensions) `N`, a valid `axis` must be
            in the half-open interval `[-N, N)`.

    Returns:
        usm_ndarray:
            Array with moved axes.
            The returned array has the same data type as `x`,
            is created on the same device as `x` and has the same
            USM allocation type as `x`.

    Raises:
        AxisError: if `axis` value is invalid.
        ValueError: if `source` and `destination` do not have the same
            number of elements.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    source = normalize_axis_tuple(source, X.ndim, "source")
    destination = normalize_axis_tuple(destination, X.ndim, "destination")

    if len(source) != len(destination):
        raise ValueError(
            "`source` and `destination` arguments must have "
            "the same number of elements"
        )

    # Start from the axes that stay put, then insert each moved axis at its
    # destination. Inserting in increasing destination order keeps earlier
    # insertions at their final positions.
    order = [ax for ax in range(X.ndim) if ax not in source]
    for dest_pos, src_ax in sorted(zip(destination, source)):
        order.insert(dest_pos, src_ax)

    return dpt.permute_dims(X, tuple(order))
701

702

703
def swapaxes(X, axis1, axis2):
    """swapaxes(x, axis1, axis2)

    Interchanges two axes of an array.

    Args:
        x (usm_ndarray): input array

        axis1 (int): First axis.
            If `x` has rank (i.e., number of dimensions) `N`,
            a valid `axis` must be in the half-open interval `[-N, N)`.

        axis2 (int): Second axis.
            If `x` has rank (i.e., number of dimensions) `N`,
            a valid `axis` must be in the half-open interval `[-N, N)`.

    Returns:
        usm_ndarray:
            Array with swapped axes.
            The returned array has the same data type as `x`,
            is created on the same device as `x` and has the same USM
            allocation type as `x`.

    Raises:
        AxisError: if `axis` value is invalid.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    axis1 = normalize_axis_index(axis1, X.ndim, "axis1")
    axis2 = normalize_axis_index(axis2, X.ndim, "axis2")

    # Identity permutation with the two requested positions exchanged.
    perm = list(range(X.ndim))
    perm[axis1], perm[axis2] = axis2, axis1
    return dpt.permute_dims(X, tuple(perm))
739

740

741
def repeat(x, repeats, /, *, axis=None):
    """repeat(x, repeats, axis=None)

    Repeat elements of an array on a per-element basis.

    Args:
        x (usm_ndarray): input array

        repeats (Union[int, Sequence[int, ...], usm_ndarray]):
            The number of repetitions for each element.

            `repeats` must be broadcast-compatible with `N` where `N` is
            `prod(x.shape)` if `axis` is `None` and `x.shape[axis]`
            otherwise.

            If `repeats` is an array, it must have an integer data type.
            Otherwise, `repeats` must be a Python integer or sequence of
            Python integers (i.e., a tuple, list, or range).

        axis (Optional[int]):
            The axis along which to repeat values. If `axis` is `None`, the
            function repeats elements of the flattened array. Default: `None`.

    Returns:
        usm_ndarray:
            output array with repeated elements.

            If `axis` is `None`, the returned array is one-dimensional,
            otherwise, it has the same shape as `x`, except for the axis along
            which elements were repeated.

            The returned array will have the same data type as `x`.
            The returned array will be located on the same device as `x` and
            have the same USM allocation type as `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`; if `repeats` is not an
            integer, a tuple/list/range, or a `usm_ndarray`; if a
            single-element `repeats` sequence holds a non-integer; or if a
            `repeats` array's data type cannot be cast to `int64` under the
            "same_kind" casting rule.
        ValueError: if any repetition count is negative, if a `repeats`
            array has more than one dimension, or if the number of
            repetition counts does not match the size of the repeated axis.
        AxisError: if `axis` value is invalid.
        ExecutionPlacementError: if the execution queue cannot be
            unambiguously inferred from the queues of `x` and `repeats`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    x_ndim = x.ndim
    x_shape = x.shape
    if axis is not None:
        axis = normalize_axis_index(operator.index(axis), x_ndim)
        axis_size = x_shape[axis]
    else:
        axis_size = x.size

    # `scalar` selects between the by-scalar and by-sequence kernels below;
    # when True, `repeats` holds a host-side integer
    scalar = False
    if isinstance(repeats, int):
        if repeats < 0:
            raise ValueError("`repeats` must be a non-negative integer")
        usm_type = x.usm_type
        exec_q = x.sycl_queue
        scalar = True
    elif isinstance(repeats, dpt.usm_ndarray):
        if repeats.ndim > 1:
            raise ValueError(
                "`repeats` array must be 0- or 1-dimensional, got "
                f"{repeats.ndim}"
            )
        exec_q = dpctl.utils.get_execution_queue(
            (x.sycl_queue, repeats.sycl_queue)
        )
        if exec_q is None:
            raise dputils.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        usm_type = dpctl.utils.get_coerced_usm_type(
            (
                x.usm_type,
                repeats.usm_type,
            )
        )
        dpctl.utils.validate_usm_type(usm_type, allow_none=False)
        if not dpt.can_cast(repeats.dtype, dpt.int64, casting="same_kind"):
            raise TypeError(
                f"'repeats' data type {repeats.dtype} cannot be cast to "
                "'int64' according to the casting rule 'same_kind'"
            )
        if repeats.size == 1:
            scalar = True
            # bring the single element to the host
            repeats = int(repeats)
            if repeats < 0:
                raise ValueError("`repeats` elements must be non-negative")
        else:
            if repeats.size != axis_size:
                raise ValueError(
                    "'repeats' array must be broadcastable to the size of "
                    "the repeated axis"
                )
            if not dpt.all(repeats >= 0):
                raise ValueError("'repeats' elements must be non-negative")

    elif isinstance(repeats, (tuple, list, range)):
        usm_type = x.usm_type
        exec_q = x.sycl_queue

        len_reps = len(repeats)
        if len_reps == 1:
            # operator.index rejects non-integral elements (e.g. floats),
            # which would otherwise produce a non-integral result shape
            repeats = operator.index(repeats[0])
            if repeats < 0:
                raise ValueError("`repeats` elements must be non-negative")
            scalar = True
        else:
            if len_reps != axis_size:
                raise ValueError(
                    "`repeats` sequence must have the same length as the "
                    "repeated axis"
                )
            repeats = dpt.asarray(
                repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
            )
            if not dpt.all(repeats >= 0):
                raise ValueError("`repeats` elements must be non-negative")
    else:
        raise TypeError(
            "Expected int, sequence, or `usm_ndarray` for second argument, "
            f"got {type(repeats)}"
        )

    _manager = dputils.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    if scalar:
        res_axis_size = repeats * axis_size
    else:
        # the by-sequence kernel consumes int64 counts; copy into an int64
        # buffer only when the input has a different integer type
        if repeats.dtype != dpt.int64:
            rep_buf = dpt.empty(
                repeats.shape,
                dtype=dpt.int64,
                usm_type=usm_type,
                sycl_queue=exec_q,
            )
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=repeats, dst=rep_buf, sycl_queue=exec_q, depends=dep_evs
            )
            _manager.add_event_pair(ht_copy_ev, copy_ev)
            cumsum_deps = [copy_ev]
        else:
            rep_buf = repeats
            cumsum_deps = dep_evs
        cumsum = dpt.empty(
            (axis_size,),
            dtype=dpt.int64,
            usm_type=usm_type,
            sycl_queue=exec_q,
        )
        # _cumsum_1d synchronizes and returns the total as a host value, so
        # `cumsum_deps` are satisfied here and the by-sequence kernel below
        # does not need an explicit `depends`
        res_axis_size = ti._cumsum_1d(
            rep_buf, cumsum, sycl_queue=exec_q, depends=cumsum_deps
        )

    if axis is not None:
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
    else:
        res_shape = (res_axis_size,)
    res = dpt.empty(
        res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
    )
    # nothing to compute for an empty result
    if res_axis_size > 0:
        if scalar:
            ht_rep_ev, rep_ev = ti._repeat_by_scalar(
                src=x,
                dst=res,
                reps=repeats,
                axis=axis,
                sycl_queue=exec_q,
                depends=dep_evs,
            )
        else:
            ht_rep_ev, rep_ev = ti._repeat_by_sequence(
                src=x,
                dst=res,
                reps=rep_buf,
                cumsum=cumsum,
                axis=axis,
                sycl_queue=exec_q,
            )
        _manager.add_event_pair(ht_rep_ev, rep_ev)
    return res
963

964

965
def tile(x, repetitions, /):
    """tile(x, repetitions)

    Repeat an input array `x` along each axis a number of times given by
    `repetitions`.

    For `N` = len(`repetitions`) and `M` = len(`x.shape`):

        * If `M < N`, `x` will have `N - M` new axes prepended to its shape
        * If `M > N`, `repetitions` will have `M - N` ones prepended to it

    Args:
        x (usm_ndarray): input array

        repetitions (Union[int, Tuple[int, ...]]):
            The number of repetitions along each dimension of `x`.
            Every repetition count must be a non-negative integer.

    Returns:
        usm_ndarray:
            tiled output array.

            The returned array will have rank `max(M, N)`. If `S` is the
            shape of `x` after prepending dimensions and `R` is
            `repetitions` after prepending ones, then the shape of the
            result will be `S[i] * R[i]` for each dimension `i`.

            The returned array will have the same data type as `x`.
            The returned array will be located on the same device as `x` and
            have the same USM allocation type as `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`, if `repetitions` is
            neither a tuple nor an integer, or if an element of
            `repetitions` is not an integer.
        ValueError: if any repetition count is negative.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    if not isinstance(repetitions, tuple):
        if isinstance(repetitions, int):
            repetitions = (repetitions,)
        else:
            raise TypeError(
                f"Expected tuple or integer type, got {type(repetitions)}."
            )

    # Validate counts up front: non-integral or negative values would
    # otherwise fail obscurely inside shape/broadcast/reshape logic.
    repetitions = tuple(operator.index(rep) for rep in repetitions)
    if any(rep < 0 for rep in repetitions):
        raise ValueError(
            f"`repetitions` must be non-negative integers, got {repetitions}"
        )

    rep_dims = len(repetitions)
    x_dims = x.ndim
    if rep_dims < x_dims:
        repetitions = (x_dims - rep_dims) * (1,) + repetitions
    elif x_dims < rep_dims:
        x = dpt.reshape(x, (rep_dims - x_dims) * (1,) + x.shape)
    in_sh = x.shape
    res_shape = tuple(sh * rep for sh, rep in zip(in_sh, repetitions))
    # case of empty input
    if x.size == 0:
        return dpt.empty(
            res_shape,
            dtype=x.dtype,
            usm_type=x.usm_type,
            sycl_queue=x.sycl_queue,
        )
    # all repetitions are effectively 1: return a copy, never a view
    if res_shape == in_sh:
        return dpt.copy(x)

    # Build a shape to reshape `x` into (`expanded_sh`) and a shape to
    # broadcast it to (`broadcast_sh`) so that a flat copy of the broadcast
    # view, reshaped to `res_shape`, is the tiled result.
    expanded_sh = []
    broadcast_sh = []
    out_sz = 1
    for out_dim, reps, sh in zip(res_shape, repetitions, in_sh):
        out_sz *= out_dim
        if reps == 1:
            # dimension will be unchanged
            broadcast_sh.append(sh)
            expanded_sh.append(sh)
        elif sh == 1:
            # dimension will be broadcast
            broadcast_sh.append(reps)
            expanded_sh.append(sh)
        else:
            # insert a unit axis before this one and broadcast it to `reps`
            broadcast_sh.extend([reps, sh])
            expanded_sh.extend([1, sh])
    exec_q = x.sycl_queue
    xdt = x.dtype
    xut = x.usm_type
    res = dpt.empty((out_sz,), dtype=xdt, usm_type=xut, sycl_queue=exec_q)
    # no need to copy data for empty output
    if out_sz > 0:
        x = dpt.broadcast_to(
            # this reshape should never copy
            dpt.reshape(x, expanded_sh),
            broadcast_sh,
        )
        # copy broadcast input into flat array
        _manager = dputils.SequentialOrderManager[exec_q]
        dep_evs = _manager.submitted_events
        hev, cp_ev = ti._copy_usm_ndarray_for_reshape(
            src=x, dst=res, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(hev, cp_ev)
    return dpt.reshape(res, res_shape)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc