• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / dpctl / 11580062618

29 Oct 2024 06:19PM UTC coverage: 87.705% (+0.04%) from 87.669%
11580062618

Pull #1867

github

web-flow
Merge d63dd70d2 into 9b83beff0
Pull Request #1867: Feature/radix sort

3103 of 3618 branches covered (85.77%)

Branch coverage included in aggregate %.

37 of 39 new or added lines in 1 file covered. (94.87%)

49 existing lines in 3 files now uncovered.

11756 of 13324 relevant lines covered (88.23%)

7098.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.67
/dpctl/tensor/_manipulation_functions.py
1
#                       Data Parallel Control (dpctl)
2
#
3
#  Copyright 2020-2024 Intel Corporation
4
#
5
#  Licensed under the Apache License, Version 2.0 (the "License");
6
#  you may not use this file except in compliance with the License.
7
#  You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
#  Unless required by applicable law or agreed to in writing, software
12
#  distributed under the License is distributed on an "AS IS" BASIS,
13
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
#  See the License for the specific language governing permissions and
15
#  limitations under the License.
16

17

18
import itertools
1✔
19
import operator
1✔
20

21
import numpy as np
1✔
22

23
import dpctl
1✔
24
import dpctl.tensor as dpt
1✔
25
import dpctl.tensor._tensor_impl as ti
1✔
26
import dpctl.utils as dputils
1✔
27

28
from ._copy_utils import _broadcast_strides
1✔
29
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
1✔
30
from ._type_utils import _supported_dtype, _to_device_supported_dtype
1✔
31

32
# Module docstring. Uses the Sphinx ``:mod:`` role so that the reference to
# ``dpctl.tensor`` resolves to a cross-reference in rendered documentation.
__doc__ = (
    "Implementation module for array manipulation "
    "functions in :mod:`dpctl.tensor`"
)
36

37

38
def _broadcast_shape_impl(shapes):
1✔
39
    if len(set(shapes)) == 1:
1✔
40
        return shapes[0]
1✔
41
    mutable_shapes = False
1✔
42
    nds = [len(s) for s in shapes]
1✔
43
    biggest = max(nds)
1✔
44
    sh_len = len(shapes)
1✔
45
    for i in range(sh_len):
1✔
46
        diff = biggest - nds[i]
1✔
47
        if diff > 0:
1✔
48
            ty = type(shapes[i])
1✔
49
            shapes[i] = ty(
1✔
50
                itertools.chain(itertools.repeat(1, diff), shapes[i])
51
            )
52
    common_shape = []
1✔
53
    for axis in range(biggest):
1✔
54
        lengths = [s[axis] for s in shapes]
1✔
55
        unique = set(lengths + [1])
1✔
56
        if len(unique) > 2:
1✔
57
            raise ValueError(
1✔
58
                "Shape mismatch: two or more arrays have "
59
                f"incompatible dimensions on axis ({axis},)"
60
            )
61
        elif len(unique) == 2:
1✔
62
            unique.remove(1)
1✔
63
            new_length = unique.pop()
1✔
64
            common_shape.append(new_length)
1✔
65
            for i in range(sh_len):
1✔
66
                if shapes[i][axis] == 1:
1✔
67
                    if not mutable_shapes:
1✔
68
                        shapes = [list(s) for s in shapes]
1✔
69
                        mutable_shapes = True
1✔
70
                    shapes[i][axis] = new_length
1✔
71
        else:
72
            common_shape.append(1)
1✔
73

74
    return tuple(common_shape)
1✔
75

76

77
def _broadcast_shapes(*args):
    """
    Collect the `shape` attribute of every argument and broadcast
    these shapes into a single shape; returns the broadcasted shape
    as a tuple.
    """
    return _broadcast_shape_impl([arr.shape for arr in args])
84

85

86
def permute_dims(X, /, axes):
    """permute_dims(x, axes)

    Reorders the axes (dimensions) of an array according to `axes`;
    the result is a view of the input array.

    Args:
        x (usm_ndarray): input array.
        axes (Tuple[int, ...]): tuple containing permutation of
           `(0,1,...,N-1)` where `N` is the number of axes (dimensions)
           of `x`.
    Returns:
        usm_ndarray:
            An array with permuted axes.
            The returned array has the same data type as `x`,
            is created on the same device as `x` and has the same USM
            allocation type as `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if the number of entries in `axes` differs from the
            number of dimensions of `x`.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
    axes = normalize_axis_tuple(axes, X.ndim, "axes")
    if len(axes) != X.ndim:
        raise ValueError(
            "The length of the passed axes does not match "
            "to the number of usm_ndarray dimensions."
        )
    src_strides = X.strides
    src_shape = X.shape
    # The result shares memory with X: only shape and strides are permuted.
    return dpt.usm_ndarray(
        shape=tuple(src_shape[ax] for ax in axes),
        dtype=X.dtype,
        buffer=X,
        strides=tuple(src_strides[ax] for ax in axes),
        offset=X._element_offset,
    )
121

122

123
def expand_dims(X, /, *, axis=0):
    """expand_dims(x, axis)

    Inserts new axes (dimensions) of size one into the shape of an array
    at the position(s) specified by `axis`.

    Args:
        x (usm_ndarray):
            input array
        axis (Union[int, Tuple[int]]):
            axis position in the expanded axes (zero-based). If `x` has rank
            (i.e, number of dimensions) `N`, a valid `axis` must reside
            in the closed-interval `[-N-1, N]`. If provided a negative
            `axis`, the `axis` position at which to insert a singleton
            dimension is computed as `N + axis + 1`. Hence, if
            provided `-1`, the resolved axis position is `N` (i.e.,
            a singleton dimension is appended to the input array `x`).
            If provided `-N-1`, the resolved axis position is `0` (i.e., a
            singleton dimension is prepended to the input array `x`).

    Returns:
        usm_ndarray:
            Returns a view, if possible, and a copy otherwise with the number
            of dimensions increased.
            The expanded array has the same data type as the input array `x`.
            The expanded array is located on the same device as the input
            array, and has the same USM allocation type.

    Raises:
        IndexError: if `axis` value is invalid.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    # Anything that is not exactly a tuple or a list is treated as one axis.
    new_axes = axis if type(axis) in (tuple, list) else (axis,)

    out_ndim = len(new_axes) + X.ndim
    new_axes = normalize_axis_tuple(new_axes, out_ndim)

    # Walk over output positions: positions named in `new_axes` get size
    # one, every other position takes the next size of the input shape.
    remaining = iter(X.shape)
    out_shape = tuple(
        1 if pos in new_axes else next(remaining) for pos in range(out_ndim)
    )

    return dpt.reshape(X, out_shape)
167

168

169
def squeeze(X, /, axis=None):
    """squeeze(x, axis)

    Drops dimensions (axes) of size one from array `x`.

    Args:
        x (usm_ndarray): input array
        axis (Union[int, Tuple[int,...]]): axis (or axes) to squeeze.
            If `None`, every axis of size one is removed.

    Returns:
        usm_ndarray:
            Output array is a view, if possible,
            and a copy otherwise, but with all or a subset of the
            dimensions of length 1 removed. Output has the same data
            type as the input, is allocated on the same device as the
            input and has the same USM allocation type as the input
            array `x`. If no dimension is removed, `x` itself is returned.

    Raises:
        ValueError: if the specified axis has a size greater than one.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
    in_shape = X.shape
    if axis is None:
        out_shape = tuple(dim for dim in in_shape if dim != 1)
    else:
        # A zero-dimensional input is normalized as if it had one axis.
        axis = normalize_axis_tuple(axis, max(X.ndim, 1))
        kept = []
        for pos, dim in enumerate(in_shape):
            if pos not in axis:
                kept.append(dim)
            elif dim != 1:
                raise ValueError(
                    "Cannot select an axis to squeeze out "
                    "which has size not equal to one."
                )
        out_shape = tuple(kept)
    if out_shape == X.shape:
        return X
    return dpt.reshape(X, out_shape)
212

213

214
def broadcast_to(X, /, shape):
    """broadcast_to(x, shape)

    Broadcasts an array to a new `shape`; the broadcasted
    :class:`dpctl.tensor.usm_ndarray` is returned as a view.

    Args:
        x (usm_ndarray): input array
        shape (Tuple[int,...]): array shape. The `shape` must be
            compatible with `x` according to broadcasting rules.

    Returns:
        usm_ndarray:
            An array with the specified `shape`.
            The output array is a view of the input array, and
            hence has the same data type, USM allocation type and
            device attributes.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    # Validation of `shape` is delegated to numpy.broadcast_to: a
    # one-byte scalar is first broadcast to the shape of X, and that
    # stand-in is then broadcast to the requested shape. NumPy raises
    # ValueError when the shapes are incompatible under its rules.
    stand_in = np.broadcast_to(np.empty((), dtype="u1"), X.shape)
    target = np.broadcast_to(stand_in, shape)

    out_strides = _broadcast_strides(X.shape, X.strides, target.ndim)
    return dpt.usm_ndarray(
        shape=target.shape,
        dtype=X.dtype,
        buffer=X,
        strides=out_strides,
        offset=X._element_offset,
    )
249

250

251
def broadcast_arrays(*args):
    """broadcast_arrays(*arrays)

    Broadcasts one or more :class:`dpctl.tensor.usm_ndarrays` against
    one another.

    Args:
        arrays (usm_ndarray): an arbitrary number of arrays to be
            broadcasted.

    Returns:
        List[usm_ndarray]:
            A list of broadcasted arrays. Each array
            has the same shape. Each array has the same `dtype`,
            `device` and `usm_type` attributes as its corresponding input
            array. When every input already has the broadcast shape, the
            list contains the input arrays themselves.

    Raises:
        ValueError: if no arrays are given.
        TypeError: if any argument is not a `usm_ndarray`.
    """
    if len(args) == 0:
        raise ValueError("`broadcast_arrays` requires at least one argument")
    for X in args:
        if not isinstance(X, dpt.usm_ndarray):
            raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    shape = _broadcast_shapes(*args)

    if all(X.shape == shape for X in args):
        # Nothing to broadcast. Return a list (not the `args` tuple) so the
        # return type is the same on both code paths.
        return list(args)

    return [broadcast_to(X, shape) for X in args]
280

281

282
def flip(X, /, *, axis=None):
    """flip(x, axis)

    Reverses the order of elements in an array `x` along the given `axis`.
    The shape of the array is preserved, but the elements are reordered.

    Args:
        x (usm_ndarray): input array.
        axis (Optional[Union[int, Tuple[int,...]]]): axis (or axes) along
            which to flip.
            If `axis` is `None`, all input array axes are flipped.
            If `axis` is negative, the flipped axis is counted from the
            last dimension. If provided more than one axis, only the specified
            axes are flipped. Default: `None`.

    Returns:
        usm_ndarray:
            A view of `x` with the entries of `axis` reversed.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
    ndim = X.ndim
    reverse = np.s_[::-1]
    if axis is None:
        # Reverse along every axis.
        return X[(reverse,) * ndim]
    axis = normalize_axis_tuple(axis, ndim)
    keep = np.s_[:]
    # Reverse only the requested axes and take full slices elsewhere.
    return X[tuple(reverse if pos in axis else keep for pos in range(ndim))]
312

313

314
def roll(x, /, shift, *, axis=None):
    """
    roll(x, shift, axis)

    Rolls array elements along a specified axis.
    Array elements that roll beyond the last position are re-introduced
    at the first position. Array elements that roll beyond the first position
    are re-introduced at the last position.

    Args:
        x (usm_ndarray): input array
        shift (Union[int, Tuple[int,...]]): number of places by which the
            elements are shifted. If `shift` is a tuple, then `axis` must be a
            tuple of the same size, and each of the given axes is shifted
            by the corresponding element in `shift`. If `shift` is an `int`
            and `axis` a tuple, then the same `shift` is used for all
            specified axes. If a `shift` is positive, then array elements are
            shifted positively (toward larger indices) along the dimension of
            `axis`.
            If a `shift` is negative, then array elements are shifted
            negatively (toward smaller indices) along the dimension of `axis`.
        axis (Optional[Union[int, Tuple[int,...]]]): axis (or axes) along which
            elements to shift. If `axis` is `None`, the array is
            flattened, shifted, and then restored to its original shape.
            Default: `None`.

    Returns:
        usm_ndarray:
            An array having the same `dtype`, `usm_type` and
            `device` attributes as `x` and whose elements are shifted relative
            to `x`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")
    queue = x.sycl_queue
    order_manager = dputils.SequentialOrderManager[queue]

    if axis is None:
        flat_shift = operator.index(shift)
        out = dpt.empty(
            x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=queue
        )
        n_elems = operator.index(x.size)
        # Reduce the shift modulo the element count; an empty array gets 0.
        flat_shift = (flat_shift % n_elems) if n_elems > 0 else 0
        host_ev, comp_ev = ti._copy_usm_ndarray_for_roll_1d(
            src=x,
            dst=out,
            shift=flat_shift,
            sycl_queue=queue,
            depends=order_manager.submitted_events,
        )
        order_manager.add_event_pair(host_ev, comp_ev)
        return out

    # Duplicate axes are permitted: their shifts accumulate below.
    axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
    shift_axis_pairs = np.broadcast(shift, axis)
    if shift_axis_pairs.ndim > 1:
        raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
    per_axis_shift = [0] * x.ndim
    extents = x.shape
    for amount, ax in shift_axis_pairs:
        extent = operator.index(extents[ax])
        accumulated = per_axis_shift[ax] + operator.index(amount)
        per_axis_shift[ax] = (accumulated % extent) if extent > 0 else 0
    out = dpt.empty(
        x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=queue
    )
    host_ev, comp_ev = ti._copy_usm_ndarray_for_roll_nd(
        src=x,
        dst=out,
        shifts=per_axis_shift,
        sycl_queue=queue,
        depends=order_manager.submitted_events,
    )
    order_manager.add_event_pair(host_ev, comp_ev)
    return out
388

389

390
def _arrays_validation(arrays, check_ndim=True):
    """Validate a sequence of arrays and derive common result attributes.

    Args:
        arrays (Union[List[usm_ndarray], Tuple[usm_ndarray, ...]]):
            non-empty sequence of input arrays.
        check_ndim (bool): if `True`, additionally require that every
            array has the same number of dimensions as the first one.
            Default: `True`.

    Returns:
        tuple:
            `(res_dtype, res_usm_type, exec_q)`: the data type obtained by
            promoting the data types of all inputs (mapped through
            `_to_device_supported_dtype` for the device of `exec_q` after
            each promotion step), the coerced USM allocation type, and the
            common execution queue.

    Raises:
        TypeError: if `arrays` is empty, is not a list or tuple, or
            contains an element that is not a `usm_ndarray`.
        ValueError: if no common execution queue or USM type can be
            determined, or if `check_ndim` is set and the numbers of
            dimensions differ.
    """
    n = len(arrays)
    if n == 0:
        raise TypeError("Missing 1 required positional argument: 'arrays'.")

    if not isinstance(arrays, (list, tuple)):
        raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")

    for X in arrays:
        if not isinstance(X, dpt.usm_ndarray):
            raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
    if exec_q is None:
        raise ValueError("All the input arrays must have same sycl queue.")

    res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
    if res_usm_type is None:
        raise ValueError("All the input arrays must have usm_type.")

    X0 = arrays[0]
    # NOTE(review): presumably rejects unsupported input dtypes; the helper
    # is defined in `._type_utils` and not visible here -- confirm.
    _supported_dtype(Xi.dtype for Xi in arrays)

    res_dtype = X0.dtype
    dev = exec_q.sycl_device
    for Xi in arrays[1:]:
        # Promote against the array's dtype explicitly rather than passing
        # the array object and relying on NumPy to coerce it to a dtype.
        res_dtype = np.promote_types(res_dtype, Xi.dtype)
        res_dtype = _to_device_supported_dtype(res_dtype, dev)

    if check_ndim:
        for i in range(1, n):
            if X0.ndim != arrays[i].ndim:
                raise ValueError(
                    "All the input arrays must have same number of dimensions, "
                    f"but the array at index 0 has {X0.ndim} dimension(s) and "
                    f"the array at index {i} has {arrays[i].ndim} dimension(s)."
                )
    return res_dtype, res_usm_type, exec_q
428

429

430
def _check_same_shapes(X0_shape, axis, n, arrays):
1✔
431
    for i in range(1, n):
1✔
432
        Xi_shape = arrays[i].shape
1✔
433
        for j, X0j in enumerate(X0_shape):
1✔
434
            if X0j != Xi_shape[j] and j != axis:
1✔
435
                raise ValueError(
1✔
436
                    "All the input array dimensions for the concatenation "
437
                    f"axis must match exactly, but along dimension {j}, the "
438
                    f"array at index 0 has size {X0j} and the array "
439
                    f"at index {i} has size {Xi_shape[j]}."
440
                )
441

442

443
def _concat_axis_None(arrays):
    """Implementation of concat(arrays, axis=None).

    Allocates a one-dimensional result large enough for the elements of
    all inputs and copies each input into its own consecutive segment
    of the result.
    """
    out_dtype, out_usm_type, queue = _arrays_validation(
        arrays, check_ndim=False
    )
    total_size = sum(arr.size for arr in arrays)
    out = dpt.empty(
        total_size, dtype=out_dtype, usm_type=out_usm_type, sycl_queue=queue
    )

    order_manager = dputils.SequentialOrderManager[queue]
    deps = order_manager.submitted_events
    seg_start = 0
    for arr in arrays:
        seg_end = seg_start + arr.size
        segment = out[seg_start:seg_end]
        if arr.flags.c_contiguous:
            # C-contiguous input: flatten with reshape and copy directly.
            host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=dpt.reshape(arr, -1),
                dst=segment,
                sycl_queue=queue,
                depends=deps,
            )
            order_manager.add_event_pair(host_ev, copy_ev)
        elif arr.dtype == out_dtype:
            host_ev, copy_ev = ti._copy_usm_ndarray_for_reshape(
                src=arr,
                dst=segment,
                sycl_queue=queue,
                depends=deps,
            )
            order_manager.add_event_pair(host_ev, copy_ev)
        else:
            # _copy_usm_ndarray_for_reshape requires src and dst to have
            # the same data type, so cast into a temporary first and make
            # the reshaping copy depend on that cast.
            casted = dpt.empty_like(arr, dtype=out_dtype)
            host_cast_ev, cast_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=arr, dst=casted, sycl_queue=queue, depends=deps
            )
            order_manager.add_event_pair(host_cast_ev, cast_ev)
            host_ev, reshape_ev = ti._copy_usm_ndarray_for_reshape(
                src=casted,
                dst=segment,
                sycl_queue=queue,
                depends=[cast_ev],
            )
            order_manager.add_event_pair(host_ev, reshape_ev)
        seg_start = seg_end

    return out
496

497

498
def concat(arrays, /, *, axis=0):
    """concat(arrays, axis)

    Joins a sequence of arrays along an existing axis.

    Args:
        arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]):
            input arrays to join. The arrays must have the same shape,
            except in the dimension specified by `axis`.
        axis (Optional[int]): axis along which the arrays will be joined.
            If `axis` is `None`, arrays are flattened before
            concatenation. If `axis` is negative, it is understood as
            being counted from the last dimension. Default: `0`.

    Returns:
        usm_ndarray:
            An output array containing the concatenated
            values. The output array data type is determined by Type
            Promotion Rules of array API.

    All input arrays must have the same device attribute. The output array
    is allocated on that same device, and data movement operations are
    scheduled on a queue underlying the device. The USM allocation type
    of the output array is determined by USM allocation type promotion
    rules.
    """
    if axis is None:
        return _concat_axis_None(arrays)

    out_dtype, out_usm_type, queue = _arrays_validation(arrays)
    n = len(arrays)
    first = arrays[0]
    ndim = first.ndim

    axis = normalize_axis_index(axis, ndim)
    base_shape = first.shape
    _check_same_shapes(base_shape, axis, n, arrays)

    # The result matches the first array on every axis except `axis`,
    # where the sizes of all inputs add up.
    out_shape = list(base_shape)
    out_shape[axis] = sum(arr.shape[axis] for arr in arrays)
    out_shape = tuple(out_shape)

    out = dpt.empty(
        out_shape, dtype=out_dtype, usm_type=out_usm_type, sycl_queue=queue
    )

    order_manager = dputils.SequentialOrderManager[queue]
    deps = order_manager.submitted_events
    start = 0
    for arr in arrays:
        stop = start + arr.shape[axis]
        # Select the slab [start, stop) along `axis`, everything elsewhere.
        window = [np.s_[:]] * ndim
        window[axis] = np.s_[start:stop]
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr,
            dst=out[tuple(window)],
            sycl_queue=queue,
            depends=deps,
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        start = stop

    return out
566

567

568
def stack(arrays, /, *, axis=0):
    """
    stack(arrays, axis)

    Joins a sequence of arrays along a new axis.

    Args:
        arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]):
            input arrays to join. Each array must have the same shape.
        axis (int): axis along which the arrays will be joined. Providing
            an `axis` specifies the index of the new axis in the dimensions
            of the output array. The value is normalized against the rank
            of the output, i.e. `N + 1`, where `N` is the rank (number of
            dimensions) of the input arrays.
            Default: `0`.

    Returns:
        usm_ndarray:
            An output array having rank `N+1`, where `N` is
            the rank (number of dimensions) of the input arrays. If the
            input arrays have different data types, array API Type
            Promotion Rules apply.

    Raises:
        ValueError: if not all input arrays have the same shape
        IndexError: if provided an `axis` outside of the required interval.
    """
    out_dtype, out_usm_type, queue = _arrays_validation(arrays)

    n = len(arrays)
    first = arrays[0]
    out_ndim = first.ndim + 1
    axis = normalize_axis_index(axis, out_ndim)
    item_shape = first.shape

    if any(item_shape != arr.shape for arr in arrays[1:]):
        raise ValueError("All input arrays must have the same shape")

    # The output shape is the common input shape with a new axis of
    # size `n` inserted at position `axis`.
    out_shape = list(item_shape)
    out_shape.insert(axis, n)
    out_shape = tuple(out_shape)

    out = dpt.empty(
        out_shape, dtype=out_dtype, usm_type=out_usm_type, sycl_queue=queue
    )

    order_manager = dputils.SequentialOrderManager[queue]
    deps = order_manager.submitted_events
    for k in range(n):
        # Integer index along the new axis, full slices everywhere else.
        selector = [np.s_[:]] * out_ndim
        selector[axis] = k
        dst = out[tuple(selector)]
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arrays[k], dst=dst, sycl_queue=queue, depends=deps
        )
        order_manager.add_event_pair(host_ev, copy_ev)

    return out
627

628

629
def unstack(X, /, *, axis=0):
    """unstack(x, axis=0)

    Splits an array into a sequence of arrays along the given axis.

    Args:
        x (usm_ndarray): input array

        axis (int, optional): axis along which `x` is unstacked.
            If `x` has rank (i.e, number of dimensions) `N`,
            a valid `axis` must reside in the half-open interval `[-N, N)`.
            Default: `0`.

    Returns:
        Tuple[usm_ndarray,...]:
            Output sequence of arrays which are views into the input array.

    Raises:
        AxisError: if the `axis` value is invalid.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    axis = normalize_axis_index(axis, X.ndim)
    # Bring the requested axis to the front, then index along it.
    leading = dpt.moveaxis(X, axis, 0)
    count = leading.shape[0]

    return tuple(leading[k] for k in range(count))
656

657

658
def moveaxis(X, source, destination, /):
    """moveaxis(x, source, destination)

    Moves axes of an array to new positions.

    Args:
        x (usm_ndarray): input array

        source (int or a sequence of int):
            Original positions of the axes to move.
            These must be unique. If `x` has rank (i.e., number of
            dimensions) `N`, a valid `axis` must be in the
            half-open interval `[-N, N)`.

        destination (int or a sequence of int):
            Destination positions for each of the original axes.
            These must also be unique. If `x` has rank
            (i.e., number of dimensions) `N`, a valid `axis` must be
            in the half-open interval `[-N, N)`.

    Returns:
        usm_ndarray:
            Array with moved axes.
            The returned array has the same data type as `x`,
            is created on the same device as `x` and has the same
            USM allocation type as `x`.

    Raises:
        AxisError: if `axis` value is invalid.
        ValueError: if `source` and `destination` do not have the same
            number of elements.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    source = normalize_axis_tuple(source, X.ndim, "source")
    destination = normalize_axis_tuple(destination, X.ndim, "destination")

    if len(source) != len(destination):
        raise ValueError(
            "`source` and `destination` arguments must have "
            "the same number of elements"
        )

    # Start from the axes that stay put, in their original order ...
    order = [ax for ax in range(X.ndim) if ax not in source]

    # ... then insert each moved axis at its destination, processing
    # destinations in increasing order.
    for dst_pos, src_ax in sorted(zip(destination, source)):
        order.insert(dst_pos, src_ax)

    return dpt.permute_dims(X, tuple(order))
707

708

709
def swapaxes(X, axis1, axis2):
    """swapaxes(x, axis1, axis2)

    Interchanges two axes of an array.

    Args:
        x (usm_ndarray): input array

        axis1 (int): First axis.
            If `x` has rank (i.e., number of dimensions) `N`,
            a valid `axis` must be in the half-open interval `[-N, N)`.

        axis2 (int): Second axis.
            If `x` has rank (i.e., number of dimensions) `N`,
            a valid `axis` must be in the half-open interval `[-N, N)`.

    Returns:
        usm_ndarray:
            Array with swapped axes.
            The returned array has the same data type as `x`,
            is created on the same device as `x` and has the same USM
            allocation type as `x`.

    Raises:
        AxisError: if `axis` value is invalid.
    """
    if not isinstance(X, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    axis1 = normalize_axis_index(axis1, X.ndim, "axis1")
    axis2 = normalize_axis_index(axis2, X.ndim, "axis2")

    # Identity permutation with the two requested positions exchanged.
    order = list(range(X.ndim))
    order[axis1], order[axis2] = axis2, axis1
    return dpt.permute_dims(X, tuple(order))
745

746

747
def repeat(x, repeats, /, *, axis=None):
    """repeat(x, repeats, axis=None)

    Repeat elements of an array on a per-element basis.

    Args:
        x (usm_ndarray): input array

        repeats (Union[int, Sequence[int, ...], usm_ndarray]):
            The number of repetitions for each element. Every value must
            be non-negative; zero removes the corresponding element.

            `repeats` must be broadcast-compatible with `N` where `N` is
            `prod(x.shape)` if `axis` is `None` and `x.shape[axis]`
            otherwise.

            If `repeats` is an array, it must have an integer data type.
            Otherwise, `repeats` must be a Python integer or sequence of
            Python integers (i.e., a tuple, list, or range).

        axis (Optional[int]):
            The axis along which to repeat values. If `axis` is `None`, the
            function repeats elements of the flattened array. Default: `None`.

    Returns:
        usm_ndarray:
            output array with repeated elements.

            If `axis` is `None`, the returned array is one-dimensional,
            otherwise, it has the same shape as `x`, except for the axis along
            which elements were repeated.

            The returned array will have the same data type as `x`.
            The returned array will be located on the same device as `x` and
            have the same USM allocation type as `x`.

    Raises:
        AxisError: if `axis` value is invalid.
        TypeError: if `x` is not a `usm_ndarray`, if `repeats` is not an
            integer, tuple, list, range or `usm_ndarray`, or if the data
            type of a `repeats` array can not be cast to `int64` under the
            "same_kind" casting rule.
        ValueError: if any repetition count is negative, if a `repeats`
            array has more than one dimension, or if the number of
            repetition counts is neither one nor the size of the repeated
            axis.
        ExecutionPlacementError: if no common execution queue can be
            determined for `x` and a `repeats` array.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    x_ndim = x.ndim
    x_shape = x.shape
    if axis is not None:
        axis = normalize_axis_index(operator.index(axis), x_ndim)
        axis_size = x_shape[axis]
    else:
        axis_size = x.size

    # `scalar` is set when a single repetition count applies to all elements
    scalar = False
    if isinstance(repeats, int):
        if repeats < 0:
            raise ValueError("`repeats` must be a non-negative integer")
        usm_type = x.usm_type
        exec_q = x.sycl_queue
        scalar = True
    elif isinstance(repeats, dpt.usm_ndarray):
        if repeats.ndim > 1:
            raise ValueError(
                "`repeats` array must be 0- or 1-dimensional, got "
                f"{repeats.ndim}"
            )
        exec_q = dputils.get_execution_queue(
            (x.sycl_queue, repeats.sycl_queue)
        )
        if exec_q is None:
            raise dputils.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        usm_type = dputils.get_coerced_usm_type(
            (
                x.usm_type,
                repeats.usm_type,
            )
        )
        dputils.validate_usm_type(usm_type, allow_none=False)
        if not dpt.can_cast(repeats.dtype, dpt.int64, casting="same_kind"):
            raise TypeError(
                f"'repeats' data type {repeats.dtype} cannot be cast to "
                "'int64' according to the casting rule 'same_kind'."
            )
        if repeats.size == 1:
            scalar = True
            # bring the single element to the host
            repeats = int(repeats)
            if repeats < 0:
                raise ValueError("`repeats` elements must be non-negative")
        else:
            if repeats.size != axis_size:
                raise ValueError(
                    "'repeats' array must be broadcastable to the size of "
                    "the repeated axis"
                )
            if not dpt.all(repeats >= 0):
                raise ValueError("'repeats' elements must be non-negative")

    elif isinstance(repeats, (tuple, list, range)):
        usm_type = x.usm_type
        exec_q = x.sycl_queue

        len_reps = len(repeats)
        if len_reps == 1:
            repeats = repeats[0]
            if repeats < 0:
                raise ValueError("`repeats` elements must be non-negative")
            scalar = True
        else:
            if len_reps != axis_size:
                raise ValueError(
                    "`repeats` sequence must have the same length as the "
                    "repeated axis"
                )
            repeats = dpt.asarray(
                repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
            )
            if not dpt.all(repeats >= 0):
                raise ValueError("`repeats` elements must be non-negative")
    else:
        raise TypeError(
            "Expected int, sequence, or `usm_ndarray` for second argument, "
            f"got {type(repeats)}"
        )

    _manager = dputils.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    if scalar:
        res_axis_size = repeats * axis_size
    else:
        if repeats.dtype != dpt.int64:
            # the kernels work on int64 counts: convert into a temporary
            rep_buf = dpt.empty(
                repeats.shape,
                dtype=dpt.int64,
                usm_type=usm_type,
                sycl_queue=exec_q,
            )
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=repeats, dst=rep_buf, sycl_queue=exec_q, depends=dep_evs
            )
            _manager.add_event_pair(ht_copy_ev, copy_ev)
            cumsum_deps = [copy_ev]
        else:
            rep_buf = repeats
            cumsum_deps = dep_evs
        cumsum = dpt.empty(
            (axis_size,),
            dtype=dpt.int64,
            usm_type=usm_type,
            sycl_queue=exec_q,
        )
        # _cumsum_1d synchronizes so `depends` ends here safely
        res_axis_size = ti._cumsum_1d(
            rep_buf, cumsum, sycl_queue=exec_q, depends=cumsum_deps
        )

    if axis is not None:
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
    else:
        res_shape = (res_axis_size,)
    res = dpt.empty(
        res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
    )
    # nothing to copy when the repeated axis of the result is empty
    if res_axis_size > 0:
        if scalar:
            ht_rep_ev, rep_ev = ti._repeat_by_scalar(
                src=x,
                dst=res,
                reps=repeats,
                axis=axis,
                sycl_queue=exec_q,
                depends=dep_evs,
            )
        else:
            ht_rep_ev, rep_ev = ti._repeat_by_sequence(
                src=x,
                dst=res,
                reps=rep_buf,
                cumsum=cumsum,
                axis=axis,
                sycl_queue=exec_q,
            )
        _manager.add_event_pair(ht_rep_ev, rep_ev)
    return res
969

970

971
def tile(x, repetitions, /):
    """tile(x, repetitions)

    Constructs an array by repeating the input array `x` along each axis
    the number of times given by `repetitions`.

    For `N` = len(`repetitions`) and `M` = len(`x.shape`):

        * If `M < N`, `x` will have `N - M` new axes prepended to its shape
        * If `M > N`, `repetitions` will have `M - N` ones prepended to it

    Args:
        x (usm_ndarray): input array

        repetitions (Union[int, Tuple[int, ...]]):
            The number of repetitions along each dimension of `x`.
            A single integer is treated as a one-element tuple.

    Returns:
        usm_ndarray:
            tiled output array.

            The returned array will have rank `max(M, N)`. If `S` is the
            shape of `x` after prepending dimensions and `R` is
            `repetitions` after prepending ones, then the shape of the
            result will be `S[i] * R[i]` for each dimension `i`.

            The returned array will have the same data type as `x`.
            The returned array will be located on the same device as `x` and
            have the same USM allocation type as `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`, or if `repetitions` is
            neither a tuple nor an integer.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    if not isinstance(repetitions, (tuple, int)):
        raise TypeError(
            f"Expected tuple or integer type, got {type(repetitions)}."
        )
    if isinstance(repetitions, int):
        repetitions = (repetitions,)

    # equalize the ranks of `x` and `repetitions` by left-padding with ones
    rank_gap = len(repetitions) - x.ndim
    if rank_gap < 0:
        repetitions = (-rank_gap) * (1,) + repetitions
    elif rank_gap > 0:
        x = dpt.reshape(x, rank_gap * (1,) + x.shape)

    src_shape = x.shape
    out_shape = tuple(dim * rep for dim, rep in zip(src_shape, repetitions))
    queue = x.sycl_queue
    src_dtype = x.dtype
    src_usm_type = x.usm_type

    # case of empty input
    if x.size == 0:
        return dpt.empty(
            out_shape,
            dtype=src_dtype,
            usm_type=src_usm_type,
            sycl_queue=queue,
        )
    if out_shape == src_shape:
        return dpt.copy(x)

    # Build two shapes: `view_shape` is what `x` gets reshaped to, and
    # `bcast_shape` is what that view is broadcast to. A tiled dimension of
    # extent greater than one is split into a (1, extent) pair so that the
    # leading unit axis can be broadcast to the repetition count.
    view_shape = []
    bcast_shape = []
    n_out = 1
    for rep, dim, out_dim in zip(repetitions, src_shape, out_shape):
        n_out *= out_dim
        if rep == 1:
            # dimension will be unchanged
            bcast_shape.append(dim)
            view_shape.append(dim)
        elif dim == 1:
            # dimension will be broadcast
            bcast_shape.append(rep)
            view_shape.append(dim)
        else:
            bcast_shape.extend([rep, dim])
            view_shape.extend([1, dim])

    flat_res = dpt.empty(
        (n_out,), dtype=src_dtype, usm_type=src_usm_type, sycl_queue=queue
    )
    # no need to copy data for empty output
    if n_out > 0:
        # this reshape should never copy
        x_view = dpt.reshape(x, view_shape)
        x_bcast = dpt.broadcast_to(x_view, bcast_shape)
        # copy broadcast input into flat array
        _manager = dputils.SequentialOrderManager[queue]
        pending_evs = _manager.submitted_events
        host_ev, copy_ev = ti._copy_usm_ndarray_for_reshape(
            src=x_bcast, dst=flat_res, sycl_queue=queue, depends=pending_evs
        )
        _manager.add_event_pair(host_ev, copy_ev)
    return dpt.reshape(flat_res, out_shape)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc